diff --git a/test/Project.toml b/test/Project.toml index e536fbfa8..5c26b35de 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -19,6 +19,7 @@ LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" diff --git a/test/test_util.jl b/test/test_util.jl index f1325b729..0611c594f 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -110,3 +110,36 @@ function modify_value_representation(nt::NamedTuple) end return modified_nt end + +""" + make_chain_from_prior([rng,] model, n_iters) + +Construct an MCMCChains.Chains object by sampling from the prior of `model` for +`n_iters` iterations. +""" +function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int) + # Sample from the prior + varinfos = [VarInfo(rng, model) for _ in 1:n_iters] + # Extract all varnames found in any dictionary. Doing it this way guards + # against the possibility of having different varnames in different + # dictionaries, e.g. for models that have dynamic variables / array sizes + varnames = OrderedSet{VarName}() + # Convert each varinfo into an OrderedDict of vns => params. + # We have to use varname_and_value_leaves so that each parameter is a scalar + dicts = map(varinfos) do t + vals = DynamicPPL.values_as(t, OrderedDict) + iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + tuples = mapreduce(collect, vcat, iters) + push!(varnames, map(first, tuples)...) + OrderedDict(tuples) + end + # Convert back to list + varnames = collect(varnames) + # Construct matrix of values + vals = [get(dict, vn, missing) for dict in dicts, vn in varnames] + # Construct and return the Chains object + return Chains(vals, varnames) +end +function make_chain_from_prior(model::Model, n_iters::Int) + return make_chain_from_prior(Random.default_rng(), model, n_iters) +end