Skip to content

Commit

Permalink
Flexible symbolic backend support via ParametricMCPs.SymbolicUtils (#21)
Browse files Browse the repository at this point in the history
* Flexible symbolic backend support via ParametricMCPs.SymbolicUtils

* Add tests
  • Loading branch information
lassepe authored Apr 17, 2024
1 parent d9f6acc commit 19cac86
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 77 deletions.
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@ version = "0.1.1"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ParametricMCPs = "9b992ff8-05bb-4ea1-b9d2-5ef72d82f7ad"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
TrajectoryGamesBase = "ac1ac542-73eb-4349-ae1b-660ab3609574"

[compat]
BlockArrays = "0.16"
ChainRulesCore = "1"
ParametricMCPs = "0.1.5"
Symbolics = "4,5"
ParametricMCPs = "0.1.14"
TrajectoryGamesBase = "0.3.6"
julia = "1.7"
4 changes: 2 additions & 2 deletions src/MCPTrajectoryGameSolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ using TrajectoryGamesBase:
unflatten_trajectory,
unstack_trajectory

using Symbolics: Symbolics
using ParametricMCPs: ParametricMCPs
using ParametricMCPs: ParametricMCPs, SymbolicUtils

using BlockArrays: BlockArrays, mortar, blocks, eachblock
using ChainRulesCore: ChainRulesCore

Expand Down
66 changes: 30 additions & 36 deletions src/solver_setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ function Solver(
context_dimension = 0,
compute_sensitivities = true,
parametric_mcp_options = (;),
symbolic_backend = SymbolicUtils.SymbolicsBackend(),
)
dimensions = let
state_blocks =
Expand All @@ -22,28 +23,19 @@ function Solver(
(; state_blocks, state, control_blocks, control, context = context_dimension, horizon)
end

initial_state_symbolic = let
Symbolics.@variables(x0[1:(dimensions.state)]) |>
only |>
scalarize |>
initial_state_symbolic =
SymbolicUtils.make_variables(symbolic_backend, :x0, dimensions.state) |>
to_blockvector(dimensions.state_blocks)
end

xs_symbolic = let
Symbolics.@variables(X[1:(dimensions.state * horizon)]) |>
only |>
scalarize |>
xs_symbolic =
SymbolicUtils.make_variables(symbolic_backend, :X, dimensions.state * horizon) |>
to_vector_of_blockvectors(dimensions.state_blocks)
end

us_symbolic = let
Symbolics.@variables(U[1:(dimensions.control * horizon)]) |>
only |>
scalarize |>
us_symbolic =
SymbolicUtils.make_variables(symbolic_backend, :U, dimensions.control * horizon) |>
to_vector_of_blockvectors(dimensions.control_blocks)
end

context_symbolic = Symbolics.@variables(context[1:context_dimension]) |> only |> scalarize
context_symbolic = SymbolicUtils.make_variables(symbolic_backend, :context, context_dimension)

cost_per_player_symbolic = game.cost(xs_symbolic, us_symbolic, context_symbolic)

Expand Down Expand Up @@ -90,7 +82,8 @@ function Solver(
end

if isnothing(game.coupling_constraints)
coupling_constraints_symbolic = Symbolics.Num[]
coupling_constraints_symbolic =
SymbolicUtils.make_variables(symbolic_backend, :coupling_constraints, 0)
else
# Note: we don't constrain the first state because we have no control authority over that anyway
coupling_constraints_symbolic =
Expand All @@ -100,19 +93,30 @@ function Solver(
# set up the duals for all constraints
# private constraints
μ_private_symbolic =
Symbolics.@variables(μ[1:length(equality_constraints_symbolic)]) |> only |> scalarize
λ_private_symbolic =
Symbolics.@variables(λ_private[1:length(inequality_constraints_symoblic)]) |>
only |>
scalarize
SymbolicUtils.make_variables(symbolic_backend, :μ, length(equality_constraints_symbolic))

#λ_private_symbolic =
# Symbolics.@variables(λ_private[1:length(inequality_constraints_symoblic)]) |>
# only |>
# scalarize
λ_private_symbolic = SymbolicUtils.make_variables(
symbolic_backend,
:λ_private,
length(inequality_constraints_symoblic),
)

# shared constraints
λ_shared_symbolic =
Symbolics.@variables(λ_shared[1:length(coupling_constraints_symbolic)]) |> only |> scalarize
λ_shared_symbolic = SymbolicUtils.make_variables(
symbolic_backend,
:λ_shared,
length(coupling_constraints_symbolic),
)

# multiplier scaling per player as a runtime parameter
# TODO: technically, we could have this scaling for *every* element of the constraint and
# actually every constraint but for now let's keep it simple
shared_constraint_premultipliers_symbolic =
Symbolics.@variables(γ_scaling[1:num_players(game)]) |> only |> scalarize
SymbolicUtils.make_variables(symbolic_backend, :γ_scaling, num_players(game))

private_variables_per_player_symbolic =
flatten_trajetory_per_player((; xs = xs_symbolic, us = us_symbolic))
Expand All @@ -129,7 +133,7 @@ function Solver(
λ_private_symbolic' * inequality_constraints_symoblic -
λ_shared_symbolic' * coupling_constraints_symbolic * γ_ii

Symbolics.gradient(L_ii, τ_ii)
SymbolicUtils.gradient(L_ii, τ_ii)
end

# set up the full KKT system as an MCP
Expand Down Expand Up @@ -181,13 +185,3 @@ end
"""
    compose_parameter_vector(; initial_state, context, shared_constraint_premultipliers)

Stack the three keyword arguments into one vector by vertical concatenation, in the fixed
order `initial_state`, then `context`, then `shared_constraint_premultipliers`.
"""
function compose_parameter_vector(; initial_state, context, shared_constraint_premultipliers)
    # `[a; b; c]` is syntax for `vcat(a, b, c)`; spell the call out explicitly.
    return vcat(initial_state, context, shared_constraint_premultipliers)
end

"""
    scalarize(num)

Like `Symbolics.scalarize`, but robustly handles empty arrays: when `num` has length zero,
an empty `Symbolics.Num[]` is returned instead of calling `Symbolics.scalarize`.
"""
function scalarize(num)
    # Zero-length input is special-cased and never forwarded to `Symbolics.scalarize`.
    return length(num) == 0 ? Symbolics.Num[] : Symbolics.scalarize(num)
end
87 changes: 51 additions & 36 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@ using Symbolics: Symbolics

include("Demo.jl")

function isfeasible(game::TrajectoryGamesBase.TrajectoryGame, trajectory; tol=1e-4)
function isfeasible(game::TrajectoryGamesBase.TrajectoryGame, trajectory; tol = 1e-4)
isfeasible(game.dynamics, trajectory; tol) &&
isfeasible(game.env, trajectory; tol) &&
all(game.coupling_constraints(trajectory.xs, trajectory.us) .>= 0 - tol)
end

function isfeasible(dynamics::TrajectoryGamesBase.AbstractDynamics, trajectory; tol=1e-4)
function isfeasible(dynamics::TrajectoryGamesBase.AbstractDynamics, trajectory; tol = 1e-4)
dynamics_steps_consistent = all(
map(2:length(trajectory.xs)) do t
residual =
trajectory.xs[t] - dynamics(trajectory.xs[t-1], trajectory.us[t-1], t - 1)
trajectory.xs[t] - dynamics(trajectory.xs[t - 1], trajectory.us[t - 1], t - 1)
sum(abs, residual) < tol
end,
)
Expand All @@ -45,7 +45,7 @@ function isfeasible(dynamics::TrajectoryGamesBase.AbstractDynamics, trajectory;
dynamics_steps_consistent && state_bounds_feasible && control_bounds_feasible
end

function isfeasible(env::TrajectoryGamesBase.PolygonEnvironment, trajectory; tol=1e-4)
function isfeasible(env::TrajectoryGamesBase.PolygonEnvironment, trajectory; tol = 1e-4)
trajectory_per_player = MCPTrajectoryGameSolver.unstack_trajectory(trajectory)

map(enumerate(trajectory_per_player)) do (ii, trajectory)
Expand All @@ -68,20 +68,20 @@ function input_sanity(; solver, game, initial_state, context)
solver,
game,
initial_state;
context=context_with_wrong_size,
context = context_with_wrong_size,
)
multipliers_despite_no_shared_constraints = [1]
@test_throws ArgumentError TrajectoryGamesBase.solve_trajectory_game!(
solver,
game,
initial_state;
context,
shared_constraint_premultipliers=multipliers_despite_no_shared_constraints,
shared_constraint_premultipliers = multipliers_despite_no_shared_constraints,
)
end
end

function forward_pass_sanity(; solver, game, initial_state, context, horizon, strategy, tol=1e-4)
function forward_pass_sanity(; solver, game, initial_state, context, horizon, strategy, tol = 1e-4)
@testset "forwardpass sanity" begin
nash_trajectory =
TrajectoryGamesBase.rollout(game.dynamics, strategy, initial_state, horizon)
Expand Down Expand Up @@ -112,8 +112,8 @@ function backward_pass_sanity(;
solver,
game,
initial_state,
rng=Random.MersenneTwister(1),
θs=[randn(rng, 4) for _ in 1:10],
rng = Random.MersenneTwister(1),
θs = [randn(rng, 4) for _ in 1:10],
)
@testset "backward pass sanity" begin
function loss(θ)
Expand All @@ -122,7 +122,7 @@ function backward_pass_sanity(;
solver,
game,
initial_state;
context=θ,
context = θ,
)

sum(strategy.substrategies) do substrategy
Expand All @@ -136,7 +136,7 @@ function backward_pass_sanity(;
for θ in θs
∇_zygote = Zygote.gradient(loss, θ) |> only
∇_finitediff = FiniteDiff.finite_difference_gradient(loss, θ)
@test isapprox(∇_zygote, ∇_finitediff; atol=1e-4)
@test isapprox(∇_zygote, ∇_finitediff; atol = 1e-4)
end
end
end
Expand All @@ -147,34 +147,49 @@ function main()
context = [0.0, 1.0, 0.0, 1.0]
initial_state = mortar([[1.0, 0, 0, 0], [-1.0, 0, 0, 0]])

local solver, solver_parallel

@testset "Tests" begin
@testset "solver setup" begin
solver =
MCPTrajectoryGameSolver.Solver(game, horizon; context_dimension=length(context))
# exercise some inner solver options...
solver_parallel = MCPTrajectoryGameSolver.Solver(
game,
horizon;
context_dimension=length(context),
parametric_mcp_options=(; parallel=Symbolics.ShardedForm()),
)
end
for options in [
(; symbolic_backend = MCPTrajectoryGameSolver.SymbolicUtils.SymbolicsBackend(),),
(;
symbolic_backend = MCPTrajectoryGameSolver.SymbolicUtils.SymbolicsBackend(),
parametric_mcp_options = (;
backend_options = (; parallel = Symbolics.ShardedForm())
),
),
(;
symbolic_backend = MCPTrajectoryGameSolver.SymbolicUtils.FastDifferentiationBackend(),
),
]
local solver

@testset "$options" begin
@testset "solver setup" begin
solver = nothing
solver = MCPTrajectoryGameSolver.Solver(
game,
horizon;
context_dimension = length(context),
options...,
)
end

@testset "solve" begin
for solver in [solver, solver_parallel]
input_sanity(; solver, game, initial_state, context)
strategy =
TrajectoryGamesBase.solve_trajectory_game!(solver, game, initial_state; context)
forward_pass_sanity(; solver, game, initial_state, context, horizon, strategy)
backward_pass_sanity(; solver, game, initial_state)
end
end
@testset "solve" begin
input_sanity(; solver, game, initial_state, context)
strategy = TrajectoryGamesBase.solve_trajectory_game!(
solver,
game,
initial_state;
context,
)
forward_pass_sanity(; solver, game, initial_state, context, horizon, strategy)
backward_pass_sanity(; solver, game, initial_state)
end

@testset "integration test" begin
Demo.demo_model_predictive_game_play()
Demo.demo_inverse_game()
@testset "integration test" begin
Demo.demo_model_predictive_game_play()
Demo.demo_inverse_game()
end
end
end
end
end
Expand Down

0 comments on commit 19cac86

Please sign in to comment.