diff --git a/src/solve.jl b/src/solve.jl index 7db745b..4ba7b90 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -17,12 +17,17 @@ function TrajectoryGamesBase.solve_trajectory_game!( θ = compose_parameter_vector(; initial_state, context, shared_constraint_premultipliers) + if isnothing(initial_guess) + initial_guess = generate_initial_guess(solver, game, initial_state) + else + initial_guess = (; x₀ = initial_guess.x, y₀ = initial_guess.y) + end + raw_solution = IPMCPs.solve( IPMCPs.InteriorPoint(), solver.mcp_problem_representation, θ; - # initial_guess = isnothing(initial_guess) ? - # generate_initial_guess(solver, game, initial_state) : initial_guess, + initial_guess..., parametric_mcp_solve_options..., ) @@ -52,24 +57,25 @@ function strategy_from_raw_solution(; raw_solution, game, solver) TrajectoryGamesBase.JointStrategy(substrategies, info) end -# function generate_initial_guess(solver, game, initial_state) -# ChainRulesCore.ignore_derivatives() do -# z_initial = zeros(ParametricMCPs.get_problem_size(solver.mcp_problem_representation)) -# -# rollout_strategy = -# map(solver.dimensions.control_blocks) do control_dimension_player_i -# (x, t) -> zeros(control_dimension_player_i) -# end |> TrajectoryGamesBase.JointStrategy -# -# zero_input_trajectory = TrajectoryGamesBase.rollout( -# game.dynamics, -# rollout_strategy, -# initial_state, -# solver.dimensions.horizon, -# ) -# -# copyto!(z_initial, reduce(vcat, flatten_trajetory_per_player(zero_input_trajectory))) -# -# z_initial -# end -# end +function generate_initial_guess(solver, game, initial_state) + ChainRulesCore.ignore_derivatives() do + x_initial = zeros(solver.mcp_problem_representation.unconstrained_dimension) + y_initial = zeros(solver.mcp_problem_representation.constrained_dimension) + + rollout_strategy = + map(solver.dimensions.control_blocks) do control_dimension_player_i + (x, t) -> zeros(control_dimension_player_i) + end |> TrajectoryGamesBase.JointStrategy + + zero_input_trajectory = TrajectoryGamesBase.rollout( + game.dynamics, + rollout_strategy, + initial_state, + solver.dimensions.horizon, + ) + + copyto!(x_initial, reduce(vcat, flatten_trajetory_per_player(zero_input_trajectory))) + + (; x₀ = x_initial, y₀ = y_initial) + end +end diff --git a/test/Demo.jl b/test/Demo.jl index 1057073..b27826e 100644 --- a/test/Demo.jl +++ b/test/Demo.jl @@ -77,14 +77,14 @@ function demo_model_predictive_game_play() # TODO: potentially allow the user to only warm-start the primals and or add noise generate_initial_guess = function (last_strategy, state, time) # only warm-start if the last strategy is converged / feasible - # if !isnothing(last_strategy) && - # # last_strategy.info.raw_solution.status == ParametricMCPs.PATHSolver.MCP_Solved - # last_strategy.info.raw_solution.status === :solved - # initial_guess = last_strategy.info.raw_solution.z - # else - # nothing - # end - nothing + if !isnothing(last_strategy) && + # last_strategy.info.raw_solution.status == ParametricMCPs.PATHSolver.MCP_Solved + last_strategy.info.raw_solution.status === :solved + initial_guess = + (; last_strategy.info.raw_solution.x, last_strategy.info.raw_solution.y) + else + nothing + end end, )