Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multisampler #148

Open
wants to merge 38 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
df6f389
started working on multisampler
mwien Feb 29, 2024
e785e56
minor fixes
mwien Feb 29, 2024
f4cb6c5
add full multisampler
Mar 1, 2024
82ed06c
remove comment
Mar 1, 2024
ef04dea
remove printlns
Mar 1, 2024
c33bd49
Fixes
mschauer Mar 1, 2024
113866d
define penalty
mschauer Mar 1, 2024
030c087
Tuning
mschauer Mar 1, 2024
37ca082
More extensive tests
mschauer Mar 1, 2024
e69b300
rename for clarity
mwien Mar 1, 2024
27fa658
fix copy
mwien Mar 1, 2024
6698330
minor rename
mwien Mar 1, 2024
f8e7b20
remove comment
mwien Mar 1, 2024
00fea92
Fix test
mschauer Mar 1, 2024
953bbb2
Merge pull request #1 from mwien/multisampler2
mwien Mar 1, 2024
0d8f92e
I don't think we can save all graphs
mschauer Mar 2, 2024
5b12aac
fix comment
mwien Mar 2, 2024
0009661
Update src/multisampler.jl
mschauer Mar 2, 2024
35ad2e6
Ups
mschauer Mar 3, 2024
df05f9c
Make sure running init first
mschauer Mar 3, 2024
3f44570
Tricky business involving assuring that log(Pi) < 0
mschauer Mar 3, 2024
750f225
Count kills
mschauer Mar 3, 2024
aea3578
Fixes
mschauer Mar 3, 2024
c2ec4d0
Once more
mschauer Mar 3, 2024
112246d
Update src/multisampler.jl
mschauer Mar 3, 2024
a289505
Update src/multisampler.jl
mschauer Mar 3, 2024
bc4a3f0
Older versions do not support typed globals
mschauer Mar 3, 2024
6256f22
Factoring
mschauer Mar 6, 2024
cc38750
Cosmetics
mschauer Mar 6, 2024
0c5b3fe
Return a bit later
mschauer Mar 6, 2024
6fd67cf
Cosmetics
mschauer Mar 7, 2024
92d8396
Show temp
mschauer Mar 7, 2024
698dbde
Bug fix (never copy dead states)
mschauer Mar 7, 2024
7e3db9e
Need to monitor time.
mschauer Mar 7, 2024
2cec3f4
Score
mschauer Mar 7, 2024
110f294
Move opposite direction
mschauer Mar 7, 2024
67c730b
No need for baseline
mschauer Mar 7, 2024
0e11569
Allow stopping at target
mschauer Mar 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.15.1"

[deps]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
Expand All @@ -26,6 +27,17 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
TabularDisplay = "3eeacb1d-13c2-54cc-9b18-30c86af3cadb"
ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"

[weakdeps]
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"

[extensions]
GraphMakieExt = "GraphMakie"
GraphRecipesExt = ["GraphRecipes", "Plots"]
TikzGraphsExt = "TikzGraphs"

[compat]
Combinatorics = "1.0"
DelimitedFiles = "1.6, 1.7, 1.8, 1.9"
Expand Down Expand Up @@ -55,11 +67,6 @@ ThreadsX = "0.1"
TikzGraphs = "1.3, 1.4"
julia = "1.6, 1.7, 1.8, 1.9, 1.10"

[extensions]
GraphMakieExt = "GraphMakie"
GraphRecipesExt = ["GraphRecipes", "Plots"]
TikzGraphsExt = "TikzGraphs"

[extras]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
Expand All @@ -71,9 +78,3 @@ TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"

[targets]
test = ["Test", "StatsBase", "DelimitedFiles"]

[weakdeps]
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"
3 changes: 3 additions & 0 deletions src/CausalInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ using Base.Iterators
using Memoization, LRUCache
using ThreadsX
using LinkedLists
using DataStructures

import Base: ==, show

export multisampler
export exactscorebased
export ancestors, descendants, alt_test_dsep, test_covariate_adjustment, alt_test_backdoor, find_dsep, find_min_dsep, find_covariate_adjustment, find_backdoor_adjustment, find_frontdoor_adjustment, find_min_covariate_adjustment, find_min_backdoor_adjustment, find_min_frontdoor_adjustment, list_dseps, list_covariate_adjustment, list_backdoor_adjustment, list_frontdoor_adjustment
export dsep, skeleton, gausscitest, dseporacle, partialcor
Expand Down Expand Up @@ -56,6 +58,7 @@ include("dag_sampler.jl")
include("misc2.jl")
include("exact.jl")
#include("mcs.jl")
include("multisampler.jl")

# Compatibility with the new "Package Extensions" (https://github.com/JuliaLang/julia/pull/47695)
const EXTENSIONS_SUPPORTED = isdefined(Base, :get_extension)
Expand Down
142 changes: 142 additions & 0 deletions src/multisampler.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
struct Sample
g::DiGraph
τ::Float64
dir::Int8
total::Int32
scoreval::Float64
end

struct Action
τ::Float64
apply::Function
args::Tuple{Vararg{Any}}
end

function expcoldness(τ, k=0.0005)
return exp(k*τ)
end

function Dexpcoldness(τ, k=0.0005)
return k*exp(k*τ)
end

function init(_, _, nextτ, g, dir, total, scoreval)
return Sample(g, nextτ, dir, total, scoreval)
end

function applyup(samplers, i, nextτ, x, y, T, Δscoreval)
prevsample = last(samplers[i])
g = next_CPDAG(prevsample.g, :up, x, y, T)
return Sample(g, nextτ, prevsample.dir, prevsample.total+1, prevsample.scoreval + Δscoreval)
end

function applydown(samplers, i, nextτ, x, y, H, Δscoreval)
prevsample = last(samplers[i])
g = next_CPDAG(prevsample.g, :down, x, y, H)
return Sample(g, nextτ, prevsample.dir, prevsample.total-1, prevsample.scoreval + Δscoreval)
end

function applyflip(samplers, i, nextτ)
prevsample = last(samplers[i])
return Sample(prevsample.g, nextτ, -1*prevsample.dir, prevsample.total, prevsample.scoreval)
end

function applycopy(samplers, _, nextτ, j)
copysample = last(samplers[j])
return Sample(copysample.g, nextτ, copysample.dir, copysample.total, copysample.scoreval)
end

# for starters without turn move
function sampleaction(samplers, i, M, balance, prior, score, σ, ρ, κ, coldness, Dcoldness)
# preprocess
prevsample = last(samplers[i])
sup, sdown, Δscorevalup, Δscorevaldown, argsup, argsdown = count_moves_new(prevsample.g, κ, balance, prior, score, coldness(prevsample.τ), prevsample.total)

# propose moves
λdir = prevsample.dir == 1 ? sup : sdown
λupdown = sup + sdown
λflip = max(prevsample.dir*(-sup + sdown), 0.0)
λterm = exp(ULogarithmic, 0.0)*Dcoldness(prevsample.τ) * coldness(prevsample.τ) * prevsample.scoreval # TODO: prior
mschauer marked this conversation as resolved.
Show resolved Hide resolved
Δτdir = randexp()/(ρ*λdir)
Δτupdown = randexp()/(σ*λupdown)
Δτflip = randexp()/(ρ*λflip)
Δτterm = randexp()/(λterm)

Δτmin = min(Δτdir, Δτupdown, Δτflip, Δτterm)

if Δτdir == Δτmin
if prevsample.dir == 1
return Action(prevsample.τ + Δτdir, applyup, (argsup..., Δscorevalup))
else
return Action(prevsample.τ + Δτdir, applydown, (argsdown..., Δscorevaldown))
end
end

if Δτupdown == Δτmin
λup = sup
if rand() < λup/λupdown
return Action(prevsample.τ + Δτupdown, applyup, (argsup..., Δscorevalup))
else
return Action(prevsample.τ + Δτupdown, applydown, (argsdown..., Δscorevaldown))
end
end

if Δτflip == Δτmin
return Action(prevsample.τ + Δτflip, applyflip, ())
end

if Δτterm == Δτmin
return Action(prevsample.τ + Δτterm, applycopy, (rand(1:M),))
mschauer marked this conversation as resolved.
Show resolved Hide resolved
end

@assert false
end

function save!(as, a)
if length(as) == 0
push!(as, a)
else
as[end] = a
end
end

# remark: chose κ = n-1 as default
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, better

function multisampler(n, G = (DiGraph(n), 0); M = 10, balance = metropolis_balance, prior = (_,_) -> 1.0, score=UniformScore(), σ = 0.0, ρ = 1.0, κ = n - 1, iterations = min(3*n^2, 50000), schedule=(expcoldness, Dexpcoldness)) #, verbose = false, save = true)
if κ >= n
κ = n - 1
@warn "Truncate κ to $κ"
end
coldness, Dcoldness = schedule

# init M samplers
samplers = [Vector{Sample}() for _ = 1:M]
nextaction = Vector{Action}(undef, M)
queue = PriorityQueue{Int32, Float64}()

for i = 1:M
nextaction[i] = Action(0.0, init, (first(G), 1, last(G), 0.0)) # pass correct initial score?!
enqueue!(queue, i, 0.0)
end

# todo: multiply iterations by M to keep passed iteration number indep of M?
# could also stop if one sampler has more than iterations many samples
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One can make the argument that the right stopping criterion is the total number of samples of all samplers

# but then @showprogress does not work so nicely?!
iterations *= M
bestgraph = DiGraph(n)
bestscore = 0.0 # fix if correct initial score is given above

@showprogress for _ in 1:iterations
i = dequeue!(queue)
nextsample = nextaction[i].apply(samplers, i, nextaction[i].τ, nextaction[i].args...)
if nextsample.scoreval > bestscore
bestgraph = nextsample.g
bestscore = nextsample.scoreval
end
save!(samplers[i], nextsample)
nextaction[i] = sampleaction(samplers, i, M, balance, prior, score, σ, ρ, κ, coldness, Dcoldness)
enqueue!(queue, i, nextaction[i].τ)
end


return bestgraph, samplers
end
77 changes: 77 additions & 0 deletions test/multisampler.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
using Random, CausalInference, StatsBase, Statistics, Test, Graphs, LinearAlgebra
@testset "MultiSampler" begin
Random.seed!(1)

N = 500 # number of data points

# define simple linear model with added noise
x = randn(N)
v = x + randn(N)*0.25
w = x + randn(N)*0.25
z = v + w + randn(N)*0.25
s = z + randn(N)*0.25

df = (x=x, v=v, w=w, z=z, s=s)
iterations = 5_000
penalty = 2.0 # increase to get more edges in truth
n = length(df) # vertices
Random.seed!(101)
C = cor(CausalInference.Tables.matrix(df))
score = GaussianScore(C, N, penalty)
bestgraph, samplers = multisampler(n; score, σ=2.0, iterations)
#posterior = sort(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)

# maximum aposteriori estimate
MAP = [1=>2, 1=>3, 2=>1, 2=>4, 3=>1, 3=>4, 4=>5]
@test bestgraph == digraph(MAP, n)
cm = sort(countmap(vpairs.(getfield.(last.(samplers), :g))), byvalue=true, rev=true)
@test first(cm).first == MAP
end #testset

@testset "MultiSampler" begin
Random.seed!(1)

N = 200 # number of data points

# define simple linear model with added noise
x = randn(N)
v = x + randn(N)*0.25
w = x + randn(N)*0.25
z = v + w + randn(N)*0.25
s = z + randn(N)*0.25

df = (x=x, v=v, w=w, z=z, s=s)
iterations = 500
penalty = 2.0 # increase to get more edges in truth
n = length(df) # vertices
Random.seed!(101)
C = cor(CausalInference.Tables.matrix(df))
score = GaussianScore(C, N, penalty)
M = 2000
bestgraph, samplers = multisampler(n; M, σ=2.0, score, iterations)
coldness = CausalInference.expcoldness(minimum(getfield.(last.(samplers), :τ)))

gs = causalzigzag(n; score, κ=n-1, coldness, iterations)
graphs, graph_pairs, hs, τs, ws, ts, scores = CausalInference.unzipgs(gs)
posterior = sort(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)


# maximum aposteriori estimate
MAP = [1=>2, 1=>3, 2=>1, 2=>4, 3=>1, 3=>4, 4=>5]
@test bestgraph == digraph(MAP, n)
cm = sort((proportionmap(vpairs.(getfield.(last.(samplers), :g)))), byvalue=true, rev=true)
@test first(cm).first == MAP
logΠ = map(g->score_dag(pdag2dag!(digraph(g, n)), score), collect(keys(cm)))
Π = normalize(exp.(coldness*(logΠ .- maximum(logΠ))), 1)
Πhat = normalize(collect(values(cm)), 1)
@show coldness
display([Π Πhat])
s = 0.0
for (i, k) in enumerate(keys(cm))
s += posterior[k]
#@show cm[k] Π[i]
end
@show s
@test s > 0.98
@test norm(collect(values(cm)) - Π) < 0.02
end #testset
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ include("witness.jl")
include("fci.jl")
include("klentropy.jl")
include("backdoor.jl")
include("multisampler.jl")
Loading