Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make CI + tests more efficient #749

Merged
merged 7 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,17 @@ permissions:
actions: write
contents: read

# Cancel existing tests on the same PR if a new commit is added to a pull request
concurrency:
group: ${{ github.workflow }}-${{ github.ref || github.run_id }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
test:
runs-on: ${{ matrix.runner.os }}
strategy:
fail-fast: false

matrix:
runner:
# Current stable version
Expand Down Expand Up @@ -58,6 +65,9 @@ jobs:
os: macos-latest
arch: aarch64
num_threads: 2
test_group:
- Group1
- Group2

steps:
- uses: actions/checkout@v4
Expand All @@ -73,6 +83,7 @@ jobs:

- uses: julia-actions/julia-runtest@v1
env:
GROUP: ${{ matrix.test_group }}
JULIA_NUM_THREADS: ${{ matrix.runner.num_threads }}

- uses: julia-actions/julia-processcoverage@v1
Expand Down
2 changes: 1 addition & 1 deletion test/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@
@test retype <: Tuple

# Just make sure the following is runnable.
@test (DynamicPPL.DebugUtils.model_warntype(model); true)
@test DynamicPPL.DebugUtils.model_warntype(model) isa Any
end
end
end
8 changes: 4 additions & 4 deletions test/lkj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ _lkj_atol = 0.05
model = lkj_prior_demo()
# `SampleFromPrior` will sample in constrained space.
@testset "SampleFromPrior" begin
samples = sample(model, SampleFromPrior(), 1_000)
samples = sample(model, SampleFromPrior(), 1_000; progress=false)
@test mean(map(Base.Fix2(getindex, Colon()), samples)) target_mean atol =
_lkj_atol
end

# `SampleFromUniform` will sample in unconstrained space.
@testset "SampleFromUniform" begin
samples = sample(model, SampleFromUniform(), 1_000)
samples = sample(model, SampleFromUniform(), 1_000; progress=false)
@test mean(map(Base.Fix2(getindex, Colon()), samples)) target_mean atol =
_lkj_atol
end
Expand All @@ -39,7 +39,7 @@ end
model = lkj_chol_prior_demo(uplo)
# `SampleFromPrior` will sample in unconstrained space.
@testset "SampleFromPrior" begin
samples = sample(model, SampleFromPrior(), 1_000)
samples = sample(model, SampleFromPrior(), 1_000; progress=false)
# Build correlation matrix from factor
corr_matrices = map(samples) do s
M = reshape(s.metadata.vals, (2, 2))
Expand All @@ -50,7 +50,7 @@ end

# `SampleFromUniform` will sample in unconstrained space.
@testset "SampleFromUniform" begin
samples = sample(model, SampleFromUniform(), 1_000)
samples = sample(model, SampleFromUniform(), 1_000; progress=false)
# Build correlation matrix from factor
corr_matrices = map(samples) do s
M = reshape(s.metadata.vals, (2, 2))
Expand Down
59 changes: 54 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,16 @@ using OrderedCollections: OrderedSet

using DynamicPPL: getargs_dottilde, getargs_tilde, Selector

const GROUP = get(ENV, "GROUP", "All")
Random.seed!(100)

include("test_util.jl")

@testset "DynamicPPL.jl" begin
@testset "interface" begin
@testset verbose = true "DynamicPPL.jl" begin
# The tests are split into two groups so that CI can run in parallel. The
# groups are chosen to make both groups take roughly the same amount of
# time, but beyond that there is no particular reason for the split.
if GROUP == "All" || GROUP == "Group1"
penelopeysm marked this conversation as resolved.
Show resolved Hide resolved
include("utils.jl")
include("compiler.jl")
include("varnamedvector.jl")
Expand All @@ -50,15 +54,60 @@ include("test_util.jl")
include("sampler.jl")
include("independence.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("context_implementations.jl")
include("logdensityfunction.jl")
include("linking.jl")
include("threadsafe.jl")
include("serialization.jl")
include("pointwise_logdensities.jl")
include("lkj.jl")
end

if GROUP == "All" || GROUP == "Group2"
include("contexts.jl")
include("context_implementations.jl")
include("threadsafe.jl")
include("debug_utils.jl")
@testset "compat" begin
include(joinpath("compat", "ad.jl"))
end
@testset "extensions" begin
include("ext/DynamicPPLMCMCChainsExt.jl")
include("ext/DynamicPPLJETExt.jl")
end
@testset "ad" begin
include("ext/DynamicPPLForwardDiffExt.jl")
include("ext/DynamicPPLMooncakeExt.jl")
include("ad.jl")
end
@testset "prob and logprob macro" begin
@test_throws ErrorException prob"..."
@test_throws ErrorException logprob"..."
end
@testset "doctests" begin
DocMeta.setdocmeta!(
DynamicPPL,
:DocTestSetup,
:(using DynamicPPL, Distributions);
recursive=true,
)
doctestfilters = [
# Older versions will show "0 element Array" instead of "Type[]".
r"(Any\[\]|0-element Array{.+,[0-9]+})",
# Older versions will show "Array{...,1}" instead of "Vector{...}".
r"(Array{.+,\s?1}|Vector{.+})",
# Older versions will show "Array{...,2}" instead of "Matrix{...}".
r"(Array{.+,\s?2}|Matrix{.+})",
# Errors from macros sometimes result in `LoadError: LoadError:`
# rather than `LoadError:`, depending on Julia version.
r"ERROR: (LoadError:\s)+",
# Older versions do not have `;;]` but instead just `]` at end of the line
# => need to treat `;;]` and `]` as the same, i.e. ignore them if at the end of a line
r"(;;){0,1}\]$"m,
# Ignore the source of a warning in the doctest output, since this is dependent on host.
# This is a line that starts with "└ @ " and ends with the line number.
r"└ @ .+:[0-9]+",
]
doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters)
end
end

@testset "compat" begin
Expand Down
30 changes: 14 additions & 16 deletions test/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,20 @@

@testset "init" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
penelopeysm marked this conversation as resolved.
Show resolved Hide resolved
N = 1000
chain_init = sample(model, SampleFromUniform(), N; progress=false)

for vn in keys(first(chain_init))
if AbstractPPL.subsumes(@varname(s), vn)
# `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2].
dist = InverseGamma(2, 3)
b = DynamicPPL.link_transform(dist)
@test mean(mean(b(vi[vn])) for vi in chain_init) ≈ 0 atol = 0.11
elseif AbstractPPL.subsumes(@varname(m), vn)
# `m ~ Normal(0, sqrt(s))` and its constrained value is the same.
@test mean(mean(vi[vn]) for vi in chain_init) ≈ 0 atol = 0.11
else
error("Unknown variable name: $vn")
end
N = 1000
chain_init = sample(model, SampleFromUniform(), N; progress=false)

for vn in keys(first(chain_init))
if AbstractPPL.subsumes(@varname(s), vn)
# `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2].
dist = InverseGamma(2, 3)
b = DynamicPPL.link_transform(dist)
@test mean(mean(b(vi[vn])) for vi in chain_init) ≈ 0 atol = 0.11
elseif AbstractPPL.subsumes(@varname(m), vn)
# `m ~ Normal(0, sqrt(s))` and its constrained value is the same.
@test mean(mean(vi[vn]) for vi in chain_init) ≈ 0 atol = 0.11
else
error("Unknown variable name: $vn")
end
end
end
Expand Down
Loading