Skip to content

Commit

Permalink
Add logratio_proposal_density and remove is_symmetric_proposal (#54)
Browse files Browse the repository at this point in the history
* Add `logratio_proposal_density` and remove `is_symmetric_proposal`

* Bump version

* Add `issymmetric` for `StaticProposal` and add aliases

* Fix type inference problems

* Remove accidentally included code
  • Loading branch information
devmotion authored May 13, 2021
1 parent 7a02255 commit 941c046
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 101 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedMH"
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
version = "0.5.9"
version = "0.6.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
14 changes: 12 additions & 2 deletions src/AdvancedMH.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,18 @@ using Requires
import Random

# Exports
export MetropolisHastings, DensityModel, RWMH, StaticMH, StaticProposal,
RandomWalkProposal, Ensemble, StretchProposal, MALA
export
MetropolisHastings,
DensityModel,
RWMH,
StaticMH,
StaticProposal,
SymmetricStaticProposal,
RandomWalkProposal,
SymmetricRandomWalkProposal,
Ensemble,
StretchProposal,
MALA

# Reexports
export sample, MCMCThreads, MCMCDistributed
Expand Down
5 changes: 5 additions & 0 deletions src/MALA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ function q(
return q(spl.proposal(-t_cond.gradient), t.params, t_cond.params)
end

function logratio_proposal_density(
sampler::MALA{<:Proposal}, state::GradientTransition, candidate::GradientTransition
)
return q(sampler, state, candidate) - q(sampler, candidate, state)
end

"""
logdensity_and_gradient(model::DensityModel, params)
Expand Down
70 changes: 8 additions & 62 deletions src/mh-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,37 +138,6 @@ end
return expr
end

# Evaluate the likelihood of t conditional on t_cond.
function q(
spl::MetropolisHastings{<:AbstractArray},
t::Transition,
t_cond::Transition
)
# mapreduce with multiple iterators requires Julia 1.2 or later
return mapreduce(+, 1:length(spl.proposal)) do i
q(spl.proposal[i], t.params[i], t_cond.params[i])
end
end

function q(
spl::MetropolisHastings{<:Proposal},
t::Transition,
t_cond::Transition
)
return q(spl.proposal, t.params, t_cond.params)
end

function q(
spl::MetropolisHastings{<:NamedTuple},
t::Transition,
t_cond::Transition
)
# mapreduce with multiple iterators requires Julia 1.2 or later
return mapreduce(+, keys(t.params)) do k
q(spl.proposal[k], t.params[k], t_cond.params[k])
end
end

transition(sampler, model, params) = transition(model, params)
transition(model, params) = Transition(model, params)

Expand All @@ -191,31 +160,6 @@ function AbstractMCMC.step(
return transition, transition
end

"""
is_symmetric_proposal(proposal)::Bool
Implementing this for a custom proposal will allow `AbstractMCMC.step` to avoid
computing the "Hastings" part of the Metropolis-Hasting log acceptance
probability (if the proposal is indeed symmetric). By default,
`is_symmetric_proposal(proposal)` returns `false`. The user is responsible for
determining whether a custom proposal distribution is indeed symmetric. As
noted in `MetropolisHastings`, `proposal` is a `Proposal`, `NamedTuple` of
`Proposal`, or `Array{Proposal}` in the shape of your data.
"""
is_symmetric_proposal(proposal) = false

# The following univariate random walk proposals are symmetric.
is_symmetric_proposal(::RandomWalkProposal{<:Normal}) = true
is_symmetric_proposal(::RandomWalkProposal{<:MvNormal}) = true
is_symmetric_proposal(::RandomWalkProposal{<:TDist}) = true
is_symmetric_proposal(::RandomWalkProposal{<:Cauchy}) = true

# The following multivariate random walk proposals are symmetric.
is_symmetric_proposal(::RandomWalkProposal{<:AbstractArray{<:Normal}}) = true
is_symmetric_proposal(::RandomWalkProposal{<:AbstractArray{<:MvNormal}}) = true
is_symmetric_proposal(::RandomWalkProposal{<:AbstractArray{<:TDist}}) = true
is_symmetric_proposal(::RandomWalkProposal{<:AbstractArray{<:Cauchy}}) = true

# Define the other sampling steps.
# Return a 2-tuple consisting of the next sample and the the next state.
# In this case they are identical, and either a new proposal (if accepted)
Expand All @@ -231,12 +175,8 @@ function AbstractMCMC.step(
params = propose(rng, spl, model, params_prev)

# Calculate the log acceptance probability.
logα = logdensity(model, params) - logdensity(model, params_prev)

# Compute Hastings portion of ratio if proposal is not symmetric.
if !is_symmetric_proposal(spl.proposal)
logα += q(spl, params_prev, params) - q(spl, params, params_prev)
end
logα = logdensity(model, params) - logdensity(model, params_prev) +
logratio_proposal_density(spl, params_prev, params)

# Decide whether to return the previous params or the new one.
if -Random.randexp(rng) < logα
Expand All @@ -245,3 +185,9 @@ function AbstractMCMC.step(
return params_prev, params_prev
end
end

function logratio_proposal_density(
sampler::MetropolisHastings, params_prev::Transition, params::Transition
)
return logratio_proposal_density(sampler.proposal, params_prev.params, params.params)
end
119 changes: 102 additions & 17 deletions src/proposal.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
abstract type Proposal{P} end

struct StaticProposal{P} <: Proposal{P}
struct StaticProposal{issymmetric,P} <: Proposal{P}
proposal::P
end
const SymmetricStaticProposal{P} = StaticProposal{true,P}

struct RandomWalkProposal{P} <: Proposal{P}
StaticProposal(proposal) = StaticProposal{false}(proposal)
function StaticProposal{issymmetric}(proposal) where {issymmetric}
return StaticProposal{issymmetric,typeof(proposal)}(proposal)
end

struct RandomWalkProposal{issymmetric,P} <: Proposal{P}
proposal::P
end
const SymmetricRandomWalkProposal{P} = RandomWalkProposal{true,P}

RandomWalkProposal(proposal) = RandomWalkProposal{false}(proposal)
function RandomWalkProposal{issymmetric}(proposal) where {issymmetric}
return RandomWalkProposal{issymmetric,typeof(proposal)}(proposal)
end

# Random draws
Base.rand(p::Proposal, args...) = rand(Random.GLOBAL_RNG, p, args...)
Expand All @@ -26,24 +38,28 @@ end
# Random Walk #
###############

function propose(rng::Random.AbstractRNG, p::RandomWalkProposal, m::DensityModel)
return propose(rng, StaticProposal(p.proposal), m)
function propose(
rng::Random.AbstractRNG,
proposal::RandomWalkProposal{issymmetric,<:Union{Distribution,AbstractArray}},
::DensityModel
) where {issymmetric}
return rand(rng, proposal)
end

function propose(
rng::Random.AbstractRNG,
proposal::RandomWalkProposal{<:Union{Distribution,AbstractArray}},
model::DensityModel,
proposal::RandomWalkProposal{issymmetric,<:Union{Distribution,AbstractArray}},
model::DensityModel,
t
)
) where {issymmetric}
return t + rand(rng, proposal)
end

function q(
proposal::RandomWalkProposal{<:Union{Distribution,AbstractArray}},
proposal::RandomWalkProposal{issymmetric,<:Union{Distribution,AbstractArray}},
t,
t_cond
)
) where {issymmetric}
return logpdf(proposal, t - t_cond)
end

Expand All @@ -53,18 +69,18 @@ end

function propose(
rng::Random.AbstractRNG,
proposal::StaticProposal{<:Union{Distribution,AbstractArray}},
proposal::StaticProposal{issymmetric,<:Union{Distribution,AbstractArray}},
model::DensityModel,
t=nothing
)
) where {issymmetric}
return rand(rng, proposal)
end

function q(
proposal::StaticProposal{<:Union{Distribution,AbstractArray}},
proposal::StaticProposal{issymmetric,<:Union{Distribution,AbstractArray}},
t,
t_cond
)
) where {issymmetric}
return logpdf(proposal, t)
end

Expand All @@ -73,10 +89,14 @@ end
############

# function definition with abstract types requires Julia 1.3 or later
for T in (StaticProposal, RandomWalkProposal)
for T in (:StaticProposal, :RandomWalkProposal)
@eval begin
(p::$T{<:Function})() = $T(p.proposal())
(p::$T{<:Function})(t) = $T(p.proposal(t))
function (p::$T{issymmetric,<:Function})() where {issymmetric}
return $T{issymmetric}(p.proposal())
end
function (p::$T{issymmetric,<:Function})(t) where {issymmetric}
return $T{issymmetric}(p.proposal(t))
end
end
end

Expand All @@ -103,4 +123,69 @@ function q(
t_cond
)
return q(proposal(t_cond), t, t_cond)
end
end

"""
logratio_proposal_density(proposal, state, candidate)
Compute the log-ratio of the proposal densities in the Metropolis-Hastings algorithm.
The log-ratio of the proposal densities is defined as
```math
\\log \\frac{g(x | x')}{g(x' | x)},
```
where ``x`` is the current state, ``x'`` is the proposed candidate for the next state,
and ``g(y' | y)`` is the conditional probability of proposing state ``y'`` given state
``y`` (proposal density).
"""
function logratio_proposal_density(proposal::Proposal, state, candidate)
return q(proposal, state, candidate) - q(proposal, candidate, state)
end

# ratio is always 0 for symmetric proposals
logratio_proposal_density(::RandomWalkProposal{true}, state, candidate) = 0
logratio_proposal_density(::StaticProposal{true}, state, candidate) = 0

# type stable implementation for `NamedTuple`s
function logratio_proposal_density(
proposals::NamedTuple{names}, states::NamedTuple, candidates::NamedTuple
) where {names}
if @generated
args = map(names) do name
:(logratio_proposal_density(
proposals[$(QuoteNode(name))],
states[$(QuoteNode(name))],
candidates[$(QuoteNode(name))],
))
end
return :(+($(args...)))
else
return sum(names) do name
return logratio_proposal_density(
proposals[name], states[name], candidates[name]
)
end
end
end

# use recursion for `Tuple`s to ensure type stability
logratio_proposal_density(proposals::Tuple{}, states::Tuple, candidates::Tuple) = 0
function logratio_proposal_density(
proposals::Tuple{<:Proposal}, states::Tuple, candidates::Tuple
)
return logratio_proposal_density(first(proposals), first(states), first(candidates))
end
function logratio_proposal_density(proposals::Tuple, states::Tuple, candidates::Tuple)
valfirst = logratio_proposal_density(first(proposals), first(states), first(candidates))
valtail = logratio_proposal_density(
Base.tail(proposals), Base.tail(states), Base.tail(candidates)
)
return valfirst + valtail
end

# fallback for general iterators (arrays etc.) - possibly not type stable!
function logratio_proposal_density(proposals, states, candidates)
return sum(zip(proposals, states, candidates)) do (proposal, state, candidate)
return logratio_proposal_density(proposal, state, candidate)
end
end
Loading

2 comments on commit 941c046

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/36678

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.0 -m "<description of version>" 941c0464be4c5119c306c8873d1805c23dc4c94c
git push origin v0.6.0

Please sign in to comment.