
Implementation of Robust Adaptive Metropolis #106

Merged Dec 11, 2024 · 34 commits
Changes from 10 commits

Commits
85ec534
added an initial implementation of `RAM`
torfjelde Dec 4, 2024
de519a4
added proper docs for RAM
torfjelde Dec 4, 2024
40ebb7e
fixed doctest for `RAM` + added impls of `getparams` and `setparams!!`
torfjelde Dec 4, 2024
2dec18a
added DocStringExtensions as a dep
torfjelde Dec 4, 2024
045f8c5
bump patch version
torfjelde Dec 4, 2024
755a180
attempt at making the doctest a bit more consistent
torfjelde Dec 4, 2024
5c1c6f5
a
torfjelde Dec 4, 2024
cddf8d1
added checks for eigenvalues according to p. 13 in Vihola (2012) (in
torfjelde Dec 4, 2024
29c9078
fixed default value for `eigenvalue_lower_bound`
torfjelde Dec 5, 2024
78a5f51
applied suggestions from @mhauru
torfjelde Dec 6, 2024
652a227
more doctesting of RAM + improved docstrings
torfjelde Dec 6, 2024
5eaff52
added docstring for `RAMState`
torfjelde Dec 6, 2024
d8688fa
added proper testing of RAM
torfjelde Dec 6, 2024
f5fc301
Update src/RobustAdaptiveMetropolis.jl
torfjelde Dec 6, 2024
56ec717
added compat entries to docs
torfjelde Dec 6, 2024
da431b4
apply suggestions from @devmotion
torfjelde Dec 6, 2024
f2889a0
Merge remote-tracking branch 'origin/torfjelde/RAM' into torfjelde/RAM
torfjelde Dec 6, 2024
9247281
renamed `RAM` to `RobostMetropolisHastings` + removed the separate mo…
Dec 10, 2024
4764120
formatting
Dec 10, 2024
11f3b64
made the docstring for RAM a bit nicer
Dec 10, 2024
df4feb1
fixed doctest
Dec 10, 2024
f784492
formatting
Dec 10, 2024
45820d2
minor improvement to docstring of RAM
Dec 10, 2024
7405a19
fused scalar operations
Dec 10, 2024
5dce265
added dimensionality check of the provided `S` matrix
Dec 10, 2024
5ee44e3
fixed typo
Dec 10, 2024
37a2189
Update docs/src/api.md
torfjelde Dec 10, 2024
5193119
use `randn` instead of `rand` for initialisation
Dec 10, 2024
d4a144e
added an explanation of the `min`
Dec 10, 2024
6295e78
Update test/RobustAdaptiveMetropolis.jl
torfjelde Dec 10, 2024
6f8fda4
use explicit `Cholesky` constructor for backwards compat
Dec 10, 2024
5815a9b
Fix typo: ```` -> ```
mhauru Dec 10, 2024
1b38ca6
formatted according to `blue`
Dec 10, 2024
f426d0d
Update src/RobustAdaptiveMetropolis.jl
torfjelde Dec 10, 2024
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,10 +1,11 @@
name = "AdvancedMH"
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
version = "0.8.4"
version = "0.8.5"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
@@ -26,6 +27,7 @@ AdvancedMHStructArraysExt = "StructArrays"
AbstractMCMC = "5.6"
DiffResults = "1"
Distributions = "0.25"
DocStringExtensions = "0.9"
FillArrays = "1"
ForwardDiff = "0.10"
LinearAlgebra = "1.6"
5 changes: 5 additions & 0 deletions docs/Project.toml
@@ -1,5 +1,10 @@
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
Documenter = "1"
6 changes: 6 additions & 0 deletions docs/src/api.md
@@ -12,3 +12,9 @@ MetropolisHastings
```@docs
DensityModel
```

## Samplers

```@docs
RAM
```
4 changes: 4 additions & 0 deletions src/AdvancedMH.jl
@@ -160,4 +160,8 @@ include("mh-core.jl")
include("emcee.jl")
include("MALA.jl")

include("RobustAdaptiveMetropolis.jl")
using .RobustAdaptiveMetropolis
export RAM

end # module AdvancedMH
205 changes: 205 additions & 0 deletions src/RobustAdaptiveMetropolis.jl
@@ -0,0 +1,205 @@
module RobustAdaptiveMetropolis

using Random, LogDensityProblems, LinearAlgebra, AbstractMCMC
using DocStringExtensions: FIELDS

using AdvancedMH: AdvancedMH

export RAM

# TODO: Should we generalise this to arbitrary symmetric proposals?
"""
RAM

Robust Adaptive Metropolis-Hastings (RAM).

This is a simple implementation of the RAM algorithm described in [^VIH12].

# Fields

$(FIELDS)

# Examples

The following demonstrates how to implement a simple Gaussian model and sample from it using the RAM algorithm.

```jldoctest
julia> using AdvancedMH, Random, Distributions, MCMCChains, LogDensityProblems, LinearAlgebra

julia> # Define a Gaussian with zero mean and some covariance.
struct Gaussian{A}
Σ::A
end

julia> # Implement the LogDensityProblems interface.
LogDensityProblems.dimension(model::Gaussian) = size(model.Σ, 1)

julia> function LogDensityProblems.logdensity(model::Gaussian, x)
d = LogDensityProblems.dimension(model)
       return logpdf(MvNormal(zeros(d), model.Σ), x)
end

julia> LogDensityProblems.capabilities(::Gaussian) = LogDensityProblems.LogDensityOrder{0}()

julia> # Construct the model. We'll use a correlation of 0.5.
model = Gaussian([1.0 0.5; 0.5 1.0]);

julia> # Number of samples we want in the resulting chain.
num_samples = 10_000;

julia> # Number of warmup steps, i.e. the number of steps to adapt the covariance of the proposal.
# Note that these are not included in the resulting chain, as `discard_initial=num_warmup`
# by default in the `sample` call. To include them, pass `discard_initial=0` to `sample`.
num_warmup = 10_000;

julia> # Set the seed so we get some consistency.
Random.seed!(1234);

julia> # Sample!
chain = sample(model, RAM(), 10_000; chain_type=Chains, num_warmup=10_000, progress=false, initial_params=zeros(2));

julia> norm(cov(Array(chain)) - [1.0 0.5; 0.5 1.0]) < 0.2
true
```

# References
[^VIH12]: Vihola (2012) Robust adaptive Metropolis algorithm with coerced acceptance rate, Statistics and computing.
"""
Base.@kwdef struct RAM{T,A<:Union{Nothing,AbstractMatrix{T}}} <: AdvancedMH.MHSampler
"target acceptance rate"
α::T=0.234
"negative exponent of the adaptation decay rate"
γ::T=0.6
"initial lower-triangular Cholesky factor"
S::A=nothing
"lower bound on eigenvalues of the adapted covariance matrix"
eigenvalue_lower_bound::T=0.0
"upper bound on eigenvalues of the adapted covariance matrix"
eigenvalue_upper_bound::T=Inf
end
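As a quick sketch of the keyword interface above (`α` and `γ` are shown at their default values; the initial factor `S0` below is a purely hypothetical choice for a 2-dimensional target):

```julia
using LinearAlgebra

# Hypothetical initial lower-triangular Cholesky factor for a 2-d target;
# leaving `S = nothing` (the default) falls back to the identity.
S0 = LowerTriangular(Matrix(0.1I, 2, 2))

# sampler = RAM(α = 0.234, γ = 0.6, S = S0)  # keyword construction; α and γ are the defaults
```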

# TODO: Should we record anything like the acceptance rates?
struct RAMState{T1,L,A,T2,T3}
x::T1
logprob::L
S::A
logα::T2
η::T3
iteration::Int
isaccept::Bool
end

AbstractMCMC.getparams(state::RAMState) = state.x
AbstractMCMC.setparams!!(state::RAMState, x) = RAMState(x, state.logprob, state.S, state.logα, state.η, state.iteration, state.isaccept)

function step_inner(
rng::Random.AbstractRNG,
model::AbstractMCMC.LogDensityModel,
sampler::RAM,
state::RAMState
)
# This is the initial state.
f = model.logdensity
d = LogDensityProblems.dimension(f)

# Sample the proposal.
x = state.x
U = randn(rng, d)
x_new = x + state.S * U

# Compute the acceptance probability.
lp = state.logprob
lp_new = LogDensityProblems.logdensity(f, x_new)
logα = min(lp_new - lp, zero(lp)) # `min` because we'll use this for updating
Member: Maybe only bound it in the update of `S`? It seems at least easier to read if the bounding is kept together with the part of the algorithm where it's actually needed.

Member Author: But IMO this makes it a bit strange if we then put the unbounded `logα` in the resulting state, since this is not the quantity used to update the `S` 😕 And for the purpose it is used here, it doesn't actually matter whether it's bounded or not, right? As in, it's equivalent here, but not equivalent in the `S` update, hence it seems somewhat natural to me to just do it once and for all.

Member: > since this is not the quantity used to update the S

It is, isn't it? The update is just slightly different. Otherwise, by the same reasoning, you could argue that one should only store α (or maybe the difference to the targeted α?), since only these are used to update `S`. In the end I guess it doesn't matter, as it's only used in these two places. It just felt strange conceptually to bound it here, in particular since you already felt the need to explain the decision with a comment.

Member Author: But α represents the acceptance probability, no? So clamping it like this is technically always what you should do; most of the time we don't because it's unnecessary for sampling according to this probability. However, if the user wants to actually look at the resulting acceptance probabilities, it's a question of whether they should write

mean(exp, getproperty.(states, :logα))

or

mean(exp, min.(1, getproperty.(states, :logα)))

In my head, the user expects to do the former, not the latter 🤷

Member: I can see your point. Even though neither of the two alternatives is particularly user-friendly, IMO a separate API for acceptance probabilities would be better. In any case, I wouldn't even have commented on this line if the comment `# min because we'll use this for updating` had not been there. So maybe just remove the comment?

Member Author: Fair, but then I'm worried someone might come along later and go "wait, that's not needed; let's just remove this unnecessary `min`", not realising that we use it for adaptation 😬

Member: Maybe add an `@assert logα <= zero(logα)` at the top of `ram_adapt`?

# TODO: use `randexp` instead.
isaccept = log(rand(rng)) < logα

return x_new, lp_new, U, logα, isaccept
end
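A small sanity check of the `min` discussed in the thread above: since `log(rand(rng))` is never positive, capping `logα` at zero cannot change the accept decision; only the adaptation of `S` is sensitive to the bound. The numbers below are hypothetical.

```julia
# Hypothetical proposal with lp_new - lp = 0.7, and a uniform draw of 0.3.
logα_raw = 0.7               # unbounded log acceptance ratio
logα = min(logα_raw, 0.0)    # bounded version stored in the state
u = log(0.3)                 # stand-in for log(rand(rng)); always ≤ 0
# The accept decision agrees either way: if logα_raw < 0 the bound is inactive,
# and if logα_raw ≥ 0 both comparisons hold because u ≤ 0.
(u < logα_raw) == (u < logα)
```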

function adapt(sampler::RAM, state::RAMState, logα::Real, U::AbstractVector)
# Update the Cholesky factor `S` based on the acceptance probability.
Δα = exp(logα) - sampler.α
S = state.S
# TODO: Make this configurable by defining a more general path.
η = state.iteration^(-sampler.γ)
ΔS = η * abs(Δα) * S * U / norm(U)
# TODO: Maybe do in-place and then have the user extract it with a callback if they really want it.
S_new = if sign(Δα) == 1
# One rank update.
LinearAlgebra.lowrankupdate(Cholesky(S), ΔS).L
else
# One rank downdate.
LinearAlgebra.lowrankdowndate(Cholesky(S), ΔS).L
end
return S_new, η
end
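For reference, the rank-one update/downdate above realises the covariance adaptation of Vihola (2012), where `η` is the step size, `α_n = exp(logα)` the realised acceptance probability, `α_*` the target rate `sampler.α`, and `U` the proposal noise vector:

```math
S_n S_n^\top = S_{n-1}\left(I + \eta_n \,(\alpha_n - \alpha_*)\, \frac{U_n U_n^\top}{\lVert U_n \rVert^2}\right) S_{n-1}^\top
```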

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::AbstractMCMC.LogDensityModel,
sampler::RAM;
initial_params=nothing,
kwargs...
)
# This is the initial state.
f = model.logdensity
d = LogDensityProblems.dimension(f)

# Initial parameter state.
x = initial_params === nothing ? rand(rng, d) : initial_params
Member: Maybe ensure consistent types here as well?

Suggested change:
x = initial_params === nothing ? rand(rng, d) : initial_params
x = initial_params === nothing ? rand(rng, eltype(sampler.γ), d) : initial_params

By the way, `rand(rng, d)` doesn't seem a good choice in general? The algorithm requires that you start at a point in the support of the target distribution, and it's not clear whether the target density is zero at this point. I wonder if it requires something like https://github.com/TuringLang/EllipticalSliceSampling.jl/blob/3296ae3566d329207875216837e65eeec3b809b2/src/interface.jl#L20-L29 in EllipticalSliceSampling.

Member Author: But this is using RWMH as the main kernel, so IMO we're already assuming unconstrained support for this to be valid.

Member Author: And more generally, I'm happy to deal with better initialisation, but I'd prefer to do that in a separate PR, as I imagine it will require some discussion.

Member: > so IMO we're already assuming unconstrained support

But why prefer `rand` over `randn` in that case? That said, I agree this question (`rand`/`randn`/a dedicated API) should probably be addressed separately.

Member Author: > But why prefer rand over randn in that case?

Just did it quickly because I have some vague memory that it's generally preferred to initialise in a box near 0 for most of the linking transformations (I believe this is the motivation behind `SampleFromUniform` in DPPL, though I think it technically initialises from a cube centred on 0?). But it's all the same to me here; we need a better general approach anyway, so I'll just change it to `randn` 👍
# Initialize the Cholesky factor of the covariance matrix.
S = LowerTriangular(sampler.S === nothing ? diagm(0 => ones(eltype(sampler.γ), d)) : sampler.S)

# Construct the initial state.
lp = LogDensityProblems.logdensity(f, x)
state = RAMState(x, lp, S, 0.0, 0, 1, true)

return AdvancedMH.Transition(x, lp, true), state
end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::AbstractMCMC.LogDensityModel,
sampler::RAM,
state::RAMState;
kwargs...
)
# Take the inner step.
x_new, lp_new, U, logα, isaccept = step_inner(rng, model, sampler, state)
# Accept / reject the proposal.
state_new = RAMState(isaccept ? x_new : state.x, isaccept ? lp_new : state.logprob, state.S, logα, state.η, state.iteration + 1, isaccept)
return AdvancedMH.Transition(state_new.x, state_new.logprob, state_new.isaccept), state_new
end

function valid_eigenvalues(S, lower_bound, upper_bound)
# Short-circuit if the bounds are the default.
(lower_bound == 0 && upper_bound == Inf) && return true
# Note that this is just the diagonal when `S` is triangular.
eigenvals = LinearAlgebra.eigvals(S)
return all(lower_bound .<= eigenvals .<= upper_bound)
end
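The comment about the diagonal can be verified with a tiny (hypothetical) example: for a triangular matrix the eigenvalues are exactly its diagonal entries, so no dense eigendecomposition is needed here.

```julia
using LinearAlgebra

# Hypothetical 2×2 lower-triangular factor.
S = LowerTriangular([2.0 0.0; 0.5 3.0])
# Eigenvalues of a triangular matrix coincide with its diagonal.
sort(eigvals(S)) ≈ sort(diag(S))
```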

function AbstractMCMC.step_warmup(
rng::Random.AbstractRNG,
model::AbstractMCMC.LogDensityModel,
sampler::RAM,
state::RAMState;
kwargs...
)
# Take the inner step.
x_new, lp_new, U, logα, isaccept = step_inner(rng, model, sampler, state)
# Adapt the proposal.
S_new, η = adapt(sampler, state, logα, U)
# Check that `S_new` has eigenvalues in the desired range.
if !valid_eigenvalues(S_new, sampler.eigenvalue_lower_bound, sampler.eigenvalue_upper_bound)
# In this case, we just keep the old `S` (p. 13 in Vihola, 2012).
S_new = state.S
end

# Update state.
state_new = RAMState(isaccept ? x_new : state.x, isaccept ? lp_new : state.logprob, S_new, logα, η, state.iteration + 1, isaccept)
return AdvancedMH.Transition(state_new.x, state_new.logprob, state_new.isaccept), state_new
end

end