Skip to content

Commit

Permalink
Surely we don't need the double loop
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Dec 13, 2024
1 parent 9157d06 commit 60f4d33
Showing 1 changed file with 14 additions and 16 deletions.
30 changes: 14 additions & 16 deletions test/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,20 @@

@testset "init" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
N = 1000
chain_init = sample(model, SampleFromUniform(), N; progress=false)

for vn in keys(first(chain_init))
if AbstractPPL.subsumes(@varname(s), vn)
# `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2].
dist = InverseGamma(2, 3)
b = DynamicPPL.link_transform(dist)
@test mean(mean(b(vi[vn])) for vi in chain_init) 0 atol = 0.11
elseif AbstractPPL.subsumes(@varname(m), vn)
# `m ~ Normal(0, sqrt(s))` and its constrained value is the same.
@test mean(mean(vi[vn]) for vi in chain_init) 0 atol = 0.11
else
error("Unknown variable name: $vn")
end
N = 1000
chain_init = sample(model, SampleFromUniform(), N; progress=false)

for vn in keys(first(chain_init))
if AbstractPPL.subsumes(@varname(s), vn)
# `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2].
dist = InverseGamma(2, 3)
b = DynamicPPL.link_transform(dist)
@test mean(mean(b(vi[vn])) for vi in chain_init) 0 atol = 0.11
elseif AbstractPPL.subsumes(@varname(m), vn)
# `m ~ Normal(0, sqrt(s))` and its constrained value is the same.
@test mean(mean(vi[vn]) for vi in chain_init) 0 atol = 0.11
else
error("Unknown variable name: $vn")
end
end
end
Expand Down

0 comments on commit 60f4d33

Please sign in to comment.