# Reproduction script (extracted from a GitHub issue): compiling an Enzyme
# reverse-mode gradient of a Kalman-filter log-likelihood with Reactant.
using LinearAlgebra
using Enzyme
using GaussianDistributions
using Random
## MODEL DEFINITION ###########################################################

# Linear-Gaussian process: x ↦ ϕ * x + ε with ε ~ N(0, Σ).
# The inner constructor checks that Σ is square and conformable with ϕ.
struct LinearGaussianProcess{
    T,
    ΦT<:AbstractMatrix{T},
    ΣT<:AbstractMatrix{T},
}
    ϕ::ΦT
    Σ::ΣT
    function LinearGaussianProcess(ϕ::ΦT, Σ::ΣT) where {
        T,
        ΦT<:AbstractMatrix{T},
        ΣT<:AbstractMatrix{T},
    }
        # Explicit throw rather than @assert: @assert may be disabled at
        # higher optimization levels and is not meant for input validation.
        size(ϕ, 1) == size(Σ, 1) == size(Σ, 2) ||
            throw(DimensionMismatch("ϕ must have as many rows as the square matrix Σ"))
        return new{T, ΦT, ΣT}(ϕ, Σ)
    end
end

# State-space model: a transition process and an observation process sharing
# the parameter eltype ΘT, plus the (state, observation) dimensions.
struct LinearGaussianModel{
    ΘT,
    TT<:LinearGaussianProcess{ΘT},
    OT<:LinearGaussianProcess{ΘT},
}
    transition::TT
    observation::OT
    dims::Tuple{Int, Int}
end

## KALMAN FILTER ##############################################################

# Run a Kalman filter over `observations`, returning the accumulated
# log-evidence log p(y_{1:T}) (sum of one-step predictive log-densities).
# `init_state` is the Gaussian prior over the initial state.
function kalman_filter(
    model::LinearGaussianModel,
    init_state::Gaussian,
    observations::AbstractVector{T},
) where {T}
    log_evidence = zero(T)   # zero(T) keeps the accumulator type-stable
    particle = init_state

    A = model.transition.ϕ
    Q = model.transition.Σ
    B = model.observation.ϕ
    R = model.observation.Σ

    for obs in observations
        # Predict: propagate mean and covariance through the linear dynamics.
        particle = let μ = particle.μ, Σ = particle.Σ
            Gaussian(A * μ, A * Σ * A' + Q)
        end

        # Update: condition on the observation; `correct` returns the
        # posterior, the innovation residual, and its covariance S.
        particle, residual, S = GaussianDistributions.correct(
            particle, Gaussian([obs], R), B
        )

        # Accumulate the one-step predictive log-density of the residual.
        log_evidence += GaussianDistributions.logpdf(
            Gaussian(zero(residual), Symmetric(S)), residual
        )
    end

    return log_evidence
end

## DEMONSTRATION ##############################################################

# θ should be a single element vector for this demonstration:
# it parameterizes the transition noise covariance.
function build_model(θ::AbstractVector{T}) where {T}
    trans = LinearGaussianProcess(T[1;;], Diagonal(θ))
    obs = LinearGaussianProcess(T[1;;], Diagonal(T[1]))
    return LinearGaussianModel(trans, obs, (1, 1))
end

# Log likelihood of `data` under the model parameterized by θ.
function logℓ(θ::AbstractArray{T}, data) where {T}
    model = build_model(θ)
    init_state = Gaussian(zeros(T, 1), T[1;;])
    return kalman_filter(model, init_state, data)
end

# data generation (with unit covariance)
# Generate a noisy random-walk observation sequence (unit covariance),
# seeded for reproducibility.
rng = MersenneTwister(1234)
data = cumsum(randn(rng, 100)) .+ randn(rng, 100)

# ensure that log likelihood looks stable
logℓ([1.0], data)
using Reactant

# Reactant will automatically upgrade code to use GPU/TPU where available.
# For more apples-to-apples tests, force the CPU backend here.
Reactant.set_default_backend("cpu")

# Move the parameter vector into a Reactant array and compile the
# Enzyme reverse-mode gradient of the log-likelihood w.r.t. θ.
ra = Reactant.to_rarray([1.0])
gradfunc = Reactant.@compile Enzyme.gradient(Reverse, logℓ, ra, Const(data))
# Issue comment: @avik-pal — mind taking a look at this?