Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lassepe committed Apr 16, 2024
1 parent cef014f commit d9f6acc
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 14 deletions.
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
ParametricMCPs = "9b992ff8-05bb-4ea1-b9d2-5ef72d82f7ad"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TrajectoryGamesBase = "ac1ac542-73eb-4349-ae1b-660ab3609574"
TrajectoryGamesExamples = "ff3fa34c-8d8f-519c-b5bc-31760c52507a"
Expand Down
38 changes: 24 additions & 14 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using StatsBase: mean
using Zygote: Zygote
using FiniteDiff: FiniteDiff
using Random: Random
using Symbolics: Symbolics

include("Demo.jl")

Expand Down Expand Up @@ -55,25 +56,27 @@ function isfeasible(env::TrajectoryGamesBase.PolygonEnvironment, trajectory; tol
end |> all
end

function input_sanity(; solver, solver_wrong_context, game, initial_state, context)
function input_sanity(; solver, game, initial_state, context)
@testset "input sanity" begin
@test_throws ArgumentError TrajectoryGamesBase.solve_trajectory_game!(
solver,
game,
initial_state,
)
context_with_wrong_size = [context; 0.0]
@test_throws ArgumentError TrajectoryGamesBase.solve_trajectory_game!(
solver_wrong_context,
solver,
game,
initial_state;
context
context=context_with_wrong_size,
)
multipliers_despite_no_shared_constraints = [1]
@test_throws ArgumentError TrajectoryGamesBase.solve_trajectory_game!(
solver,
game,
initial_state;
context,
shared_constraint_premultipliers=[1]
shared_constraint_premultipliers=multipliers_despite_no_shared_constraints,
)
end
end
Expand Down Expand Up @@ -110,7 +113,7 @@ function backward_pass_sanity(;
game,
initial_state,
rng=Random.MersenneTwister(1),
θs=[randn(rng, 4) for _ in 1:10]
θs=[randn(rng, 4) for _ in 1:10],
)
@testset "backward pass sanity" begin
function loss(θ)
Expand All @@ -119,7 +122,7 @@ function backward_pass_sanity(;
solver,
game,
initial_state;
context=θ
context=θ,
)

sum(strategy.substrategies) do substrategy
Expand All @@ -144,22 +147,29 @@ function main()
context = [0.0, 1.0, 0.0, 1.0]
initial_state = mortar([[1.0, 0, 0, 0], [-1.0, 0, 0, 0]])

local solver, solver_wrong_context
local solver, solver_parallel

@testset "Tests" begin
@testset "solver setup" begin
solver =
MCPTrajectoryGameSolver.Solver(game, horizon; context_dimension=length(context))
solver_wrong_context =
MCPTrajectoryGameSolver.Solver(game, horizon; context_dimension=(length(context) + 1))
# exercise some inner solver options...
solver_parallel = MCPTrajectoryGameSolver.Solver(
game,
horizon;
context_dimension=length(context),
parametric_mcp_options=(; parallel=Symbolics.ShardedForm()),
)
end

@testset "solve" begin
input_sanity(; solver, solver_wrong_context, game, initial_state, context)
strategy =
TrajectoryGamesBase.solve_trajectory_game!(solver, game, initial_state; context)
forward_pass_sanity(; solver, game, initial_state, context, horizon, strategy)
backward_pass_sanity(; solver, game, initial_state)
for solver in [solver, solver_parallel]
input_sanity(; solver, game, initial_state, context)
strategy =
TrajectoryGamesBase.solve_trajectory_game!(solver, game, initial_state; context)
forward_pass_sanity(; solver, game, initial_state, context, horizon, strategy)
backward_pass_sanity(; solver, game, initial_state)
end
end

@testset "integration test" begin
Expand Down

0 comments on commit d9f6acc

Please sign in to comment.