Skip to content

Commit

Permalink
Flexible symbolic backend support via ParametricMCPs.SymbolicUtils
Browse files Browse the repository at this point in the history
  • Loading branch information
lassepe committed Apr 17, 2024
1 parent d9f6acc commit af21a75
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 40 deletions.
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@ version = "0.1.1"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ParametricMCPs = "9b992ff8-05bb-4ea1-b9d2-5ef72d82f7ad"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
TrajectoryGamesBase = "ac1ac542-73eb-4349-ae1b-660ab3609574"

[compat]
BlockArrays = "0.16"
ChainRulesCore = "1"
ParametricMCPs = "0.1.5"
Symbolics = "4,5"
ParametricMCPs = "0.1.14"
TrajectoryGamesBase = "0.3.6"
julia = "1.7"
4 changes: 2 additions & 2 deletions src/MCPTrajectoryGameSolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ using TrajectoryGamesBase:
unflatten_trajectory,
unstack_trajectory

using Symbolics: Symbolics
using ParametricMCPs: ParametricMCPs
using ParametricMCPs: ParametricMCPs, SymbolicUtils

using BlockArrays: BlockArrays, mortar, blocks, eachblock
using ChainRulesCore: ChainRulesCore

Expand Down
64 changes: 29 additions & 35 deletions src/solver_setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,19 @@ function Solver(
(; state_blocks, state, control_blocks, control, context = context_dimension, horizon)
end

initial_state_symbolic = let
Symbolics.@variables(x0[1:(dimensions.state)]) |>
only |>
scalarize |>
initial_state_symbolic =
SymbolicUtils.make_variables(symbolic_backend, :x0, dimensions.state) |>
to_blockvector(dimensions.state_blocks)
end

xs_symbolic = let
Symbolics.@variables(X[1:(dimensions.state * horizon)]) |>
only |>
scalarize |>
xs_symbolic =
SymbolicUtils.make_variables(symbolic_backend, :X, dimensions.state * horizon) |>
to_vector_of_blockvectors(dimensions.state_blocks)
end

us_symbolic = let
Symbolics.@variables(U[1:(dimensions.control * horizon)]) |>
only |>
scalarize |>
us_symbolic =
SymbolicUtils.make_variables(symbolic_backend, :U, dimensions.control * horizon) |>
to_vector_of_blockvectors(dimensions.control_blocks)
end

context_symbolic = Symbolics.@variables(context[1:context_dimension]) |> only |> scalarize
context_symbolic = SymbolicUtils.make_variables(symbolic_backend, :context, context_dimension)

cost_per_player_symbolic = game.cost(xs_symbolic, us_symbolic, context_symbolic)

Expand Down Expand Up @@ -90,7 +81,8 @@ function Solver(
end

if isnothing(game.coupling_constraints)
coupling_constraints_symbolic = Symbolics.Num[]
coupling_constraints_symbolic =
SymbolicUtils.make_variables(symbolic_backend, :coupling_constraints, 0)
else
# Note: we don't constraint the first state because we have no control authority over that anyway
coupling_constraints_symbolic =
Expand All @@ -100,19 +92,30 @@ function Solver(
# set up the duals for all constraints
# private constraints
μ_private_symbolic =
Symbolics.@variables(μ[1:length(equality_constraints_symbolic)]) |> only |> scalarize
λ_private_symbolic =
Symbolics.@variables(λ_private[1:length(inequality_constraints_symoblic)]) |>
only |>
scalarize
SymbolicUtils.make_variables(symbolic_backend, , length(equality_constraints_symbolic))

#λ_private_symbolic =
# Symbolics.@variables(λ_private[1:length(inequality_constraints_symoblic)]) |>
# only |>
# scalarize
λ_private_symbolic = SymbolicUtils.make_variables(
symbolic_backend,
:λ_private,
length(inequality_constraints_symoblic),
)

# shared constraints
λ_shared_symbolic =
Symbolics.@variables(λ_shared[1:length(coupling_constraints_symbolic)]) |> only |> scalarize
λ_shared_symbolic = SymbolicUtils.make_variables(
symbolic_backend,
:λ_shared,
length(coupling_constraints_symbolic),
)

# multiplier scaling per player as a runtime parameter
# TODO: technically, we could have this scaling for *every* element of the constraint and
# actually every constraint but for now let's keep it simple
shared_constraint_premultipliers_symbolic =
Symbolics.@variables(γ_scaling[1:num_players(game)]) |> only |> scalarize
SymbolicUtils.make_variables(symbolic_backend, :γ_scaling, num_players(game))

private_variables_per_player_symbolic =
flatten_trajetory_per_player((; xs = xs_symbolic, us = us_symbolic))
Expand All @@ -129,7 +132,7 @@ function Solver(
λ_private_symbolic' * inequality_constraints_symoblic -
λ_shared_symbolic' * coupling_constraints_symbolic * γ_ii

Symbolics.gradient(L_ii, τ_ii)
SymbolicUtils.gradient(L_ii, τ_ii)
end

# set up the full KKT system as an MCP
Expand Down Expand Up @@ -182,12 +185,3 @@ function compose_parameter_vector(; initial_state, context, shared_constraint_pr
[initial_state; context; shared_constraint_premultipliers]
end

"""
Like Symbolics.scalarize but robusutly handle empty arrays.
"""
function scalarize(num)
if length(num) == 0
return Symbolics.Num[]
end
Symbolics.scalarize(num)
end

0 comments on commit af21a75

Please sign in to comment.