Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Turing integration tests #733

Merged
merged 15 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ jobs:

- uses: julia-actions/julia-runtest@v1
env:
GROUP: All
JULIA_NUM_THREADS: ${{ matrix.runner.num_threads }}

- uses: julia-actions/julia-processcoverage@v1
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/CompatHelper.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "docs", "test", "test/turing"])'
run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "docs", "test"])'
2 changes: 0 additions & 2 deletions .github/workflows/JuliaPre.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,3 @@ jobs:
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
GROUP: DynamicPPL
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.31.5"
version = "0.32.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
5 changes: 0 additions & 5 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,6 @@ TypedVarInfo

One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form.

```@docs
mhauru marked this conversation as resolved.
Show resolved Hide resolved
link!
invlink!
```

```@docs
set_flag!
unset_flag!
Expand Down
62 changes: 46 additions & 16 deletions src/test_utils/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,28 +323,30 @@ function varnames(model::Model{typeof(demo_assume_dot_observe)})
return [@varname(s), @varname(m)]
end

@model function demo_assume_observe_literal()
# `assume` and literal `observe`
@model function demo_assume_multivariate_observe_literal()
# multivariate `assume` and literal `observe`
s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
m ~ MvNormal(zeros(2), Diagonal(s))
[1.5, 2.0] ~ MvNormal(m, Diagonal(s))

return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
end
function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
function logprior_true(model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m)
s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
m_dist = MvNormal(zeros(2), Diagonal(s))
return logpdf(s_dist, s) + logpdf(m_dist, m)
end
function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
function loglikelihood_true(
model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m
)
return logpdf(MvNormal(m, Diagonal(s)), [1.5, 2.0])
end
function logprior_true_with_logabsdet_jacobian(
model::Model{typeof(demo_assume_observe_literal)}, s, m
model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m
)
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
end
function varnames(model::Model{typeof(demo_assume_observe_literal)})
function varnames(model::Model{typeof(demo_assume_multivariate_observe_literal)})
return [@varname(s), @varname(m)]
end

Expand Down Expand Up @@ -377,26 +379,50 @@ function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)})
return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])]
end

@model function demo_assume_literal_dot_observe()
@model function demo_assume_observe_literal()
# univariate `assume` and literal `observe`
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
1.5 ~ Normal(m, sqrt(s))
2.0 ~ Normal(m, sqrt(s))

return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
end
function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m)
end
function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
return logpdf(Normal(m, sqrt(s)), 1.5) + logpdf(Normal(m, sqrt(s)), 2.0)
end
function logprior_true_with_logabsdet_jacobian(
model::Model{typeof(demo_assume_observe_literal)}, s, m
)
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
end
function varnames(model::Model{typeof(demo_assume_observe_literal)})
return [@varname(s), @varname(m)]
end

@model function demo_assume_dot_observe_literal()
# `assume` and literal `dot_observe`
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
[1.5, 2.0] .~ Normal(m, sqrt(s))

return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
end
function logprior_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m)
function logprior_true(model::Model{typeof(demo_assume_dot_observe_literal)}, s, m)
return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m)
end
function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m)
function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe_literal)}, s, m)
return loglikelihood(Normal(m, sqrt(s)), [1.5, 2.0])
end
function logprior_true_with_logabsdet_jacobian(
model::Model{typeof(demo_assume_literal_dot_observe)}, s, m
model::Model{typeof(demo_assume_dot_observe_literal)}, s, m
)
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
end
function varnames(model::Model{typeof(demo_assume_literal_dot_observe)})
function varnames(model::Model{typeof(demo_assume_dot_observe_literal)})
return [@varname(s), @varname(m)]
end

Expand Down Expand Up @@ -574,8 +600,9 @@ const DemoModels = Union{
Model{typeof(demo_assume_multivariate_observe)},
Model{typeof(demo_dot_assume_observe_index)},
Model{typeof(demo_assume_dot_observe)},
Model{typeof(demo_assume_literal_dot_observe)},
Model{typeof(demo_assume_dot_observe_literal)},
Model{typeof(demo_assume_observe_literal)},
Model{typeof(demo_assume_multivariate_observe_literal)},
Model{typeof(demo_dot_assume_observe_index_literal)},
Model{typeof(demo_assume_submodel_observe_index_literal)},
Model{typeof(demo_dot_assume_observe_submodel)},
Expand All @@ -585,7 +612,9 @@ const DemoModels = Union{
}

const UnivariateAssumeDemoModels = Union{
Model{typeof(demo_assume_dot_observe)},Model{typeof(demo_assume_literal_dot_observe)}
Model{typeof(demo_assume_dot_observe)},
Model{typeof(demo_assume_dot_observe_literal)},
Model{typeof(demo_assume_observe_literal)},
}
function posterior_mean(model::UnivariateAssumeDemoModels)
return (s=49 / 24, m=7 / 6)
Expand All @@ -609,7 +638,7 @@ const MultivariateAssumeDemoModels = Union{
Model{typeof(demo_assume_index_observe)},
Model{typeof(demo_assume_multivariate_observe)},
Model{typeof(demo_dot_assume_observe_index)},
Model{typeof(demo_assume_observe_literal)},
Model{typeof(demo_assume_multivariate_observe_literal)},
Model{typeof(demo_dot_assume_observe_index_literal)},
Model{typeof(demo_assume_submodel_observe_index_literal)},
Model{typeof(demo_dot_assume_observe_submodel)},
Expand Down Expand Up @@ -759,9 +788,10 @@ const DEMO_MODELS = (
demo_assume_multivariate_observe(),
demo_dot_assume_observe_index(),
demo_assume_dot_observe(),
demo_assume_observe_literal(),
demo_assume_multivariate_observe_literal(),
demo_dot_assume_observe_index_literal(),
demo_assume_literal_dot_observe(),
demo_assume_dot_observe_literal(),
demo_assume_observe_literal(),
demo_assume_submodel_observe_index_literal(),
demo_dot_assume_observe_submodel(),
demo_dot_assume_dot_observe_matrix(),
Expand Down
44 changes: 0 additions & 44 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1221,27 +1221,6 @@ function link!!(
return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, spl, model)
end

"""
link!(vi::VarInfo, spl::Sampler)

Transform the values of the random variables sampled by `spl` in `vi` from the support
of their distributions to the Euclidean space and set their corresponding `"trans"`
flag values to `true`.
"""
function link!(vi::VarInfo, spl::AbstractSampler)
Base.depwarn(
"`link!(varinfo, sampler)` is deprecated, use `link!!(varinfo, sampler, model)` instead.",
:link!,
)
return _link!(vi, spl)
end
function link!(vi::VarInfo, spl::AbstractSampler, spaceval::Val)
mhauru marked this conversation as resolved.
Show resolved Hide resolved
Base.depwarn(
"`link!(varinfo, sampler, spaceval)` is deprecated, use `link!!(varinfo, sampler, model)` instead.",
:link!,
)
return _link!(vi, spl, spaceval)
end
function _link!(vi::UntypedVarInfo, spl::AbstractSampler)
# TODO: Change to a lazy iterator over `vns`
vns = _getvns(vi, spl)
Expand Down Expand Up @@ -1319,29 +1298,6 @@ function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, mode
return maybe_invlink_before_eval!!(t, vi, context, model)
end

"""
invlink!(vi::VarInfo, spl::AbstractSampler)

Transform the values of the random variables sampled by `spl` in `vi` from the
Euclidean space back to the support of their distributions and sets their corresponding
`"trans"` flag values to `false`.
"""
function invlink!(vi::VarInfo, spl::AbstractSampler)
Base.depwarn(
"`invlink!(varinfo, sampler)` is deprecated, use `invlink!!(varinfo, sampler, model)` instead.",
:invlink!,
)
return _invlink!(vi, spl)
end

function invlink!(vi::VarInfo, spl::AbstractSampler, spaceval::Val)
Base.depwarn(
"`invlink!(varinfo, sampler, spaceval)` is deprecated, use `invlink!!(varinfo, sampler, model)` instead.",
:invlink!,
)
return _invlink!(vi, spl, spaceval)
end

function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler)
vns = _getvns(vi, spl)
if istrans(vi, vns[1])
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Expand Down
39 changes: 39 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,43 @@
end
end
end

@testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin
# Failing model
t = 1:0.05:8
σ = 0.3
y = @. rand(sin(t) + Normal(0, σ))
@model function state_space(y, TT, ::Type{T}=Float64) where {T}
# Priors
α ~ Normal(y[1], 0.001)
τ ~ Exponential(1)
η ~ filldist(Normal(0, 1), TT - 1)
σ ~ Exponential(1)
# create latent variable
x = Vector{T}(undef, TT)
x[1] = α
for t in 2:TT
x[t] = x[t - 1] + η[t - 1] * τ
end
# measurement model
y ~ MvNormal(x, σ^2 * I)
return x
end
model = state_space(y, length(t))

# Dummy sampling algorithm for testing. The test case can only be replicated
# with a custom sampler, it doesn't work with SampleFromPrior(). We need to
# overload assume so that model evaluation doesn't fail due to a lack
# of implementation
struct MyEmptyAlg end
DynamicPPL.getspace(::DynamicPPL.Sampler{MyEmptyAlg}) = ()
DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi) =
DynamicPPL.assume(dist, vn, vi)

# Compiling the ReverseDiff tape used to fail here
spl = Sampler(MyEmptyAlg())
vi = VarInfo(model)
ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl))
@test LogDensityProblemsAD.ADgradient(AutoReverseDiff(; compile=true), ldf) isa Any
end
end
Loading
Loading