Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multisampler #148

Open
wants to merge 38 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
df6f389
started working on multisampler
mwien Feb 29, 2024
e785e56
minor fixes
mwien Feb 29, 2024
f4cb6c5
add full multisampler
Mar 1, 2024
82ed06c
remove comment
Mar 1, 2024
ef04dea
remove printlns
Mar 1, 2024
c33bd49
Fixes
mschauer Mar 1, 2024
113866d
define penalty
mschauer Mar 1, 2024
030c087
Tuning
mschauer Mar 1, 2024
37ca082
More extensive tests
mschauer Mar 1, 2024
e69b300
rename for clarity
mwien Mar 1, 2024
27fa658
fix copy
mwien Mar 1, 2024
6698330
minor rename
mwien Mar 1, 2024
f8e7b20
remove comment
mwien Mar 1, 2024
00fea92
Fix test
mschauer Mar 1, 2024
953bbb2
Merge pull request #1 from mwien/multisampler2
mwien Mar 1, 2024
0d8f92e
I don't think we can save all graphs
mschauer Mar 2, 2024
5b12aac
fix comment
mwien Mar 2, 2024
0009661
Update src/multisampler.jl
mschauer Mar 2, 2024
35ad2e6
Ups
mschauer Mar 3, 2024
df05f9c
Make sure running init first
mschauer Mar 3, 2024
3f44570
Tricky business involving assuring that log(Pi) < 0
mschauer Mar 3, 2024
750f225
Count kills
mschauer Mar 3, 2024
aea3578
Fixes
mschauer Mar 3, 2024
c2ec4d0
Once more
mschauer Mar 3, 2024
112246d
Update src/multisampler.jl
mschauer Mar 3, 2024
a289505
Update src/multisampler.jl
mschauer Mar 3, 2024
bc4a3f0
Older versions do not support typed globals
mschauer Mar 3, 2024
6256f22
Factoring
mschauer Mar 6, 2024
cc38750
Cosmetics
mschauer Mar 6, 2024
0c5b3fe
Return a bit later
mschauer Mar 6, 2024
6fd67cf
Cosmetics
mschauer Mar 7, 2024
92d8396
Show temp
mschauer Mar 7, 2024
698dbde
Bug fix (never copy dead states)
mschauer Mar 7, 2024
7e3db9e
Need to monitor time.
mschauer Mar 7, 2024
2cec3f4
Score
mschauer Mar 7, 2024
110f294
Move opposite direction
mschauer Mar 7, 2024
67c730b
No need for baseline
mschauer Mar 7, 2024
0e11569
Allow stopping at target
mschauer Mar 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.15.1"

[deps]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
Expand All @@ -26,6 +27,17 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
TabularDisplay = "3eeacb1d-13c2-54cc-9b18-30c86af3cadb"
ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"

[weakdeps]
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"

[extensions]
GraphMakieExt = "GraphMakie"
GraphRecipesExt = ["GraphRecipes", "Plots"]
TikzGraphsExt = "TikzGraphs"

[compat]
Combinatorics = "1.0"
DelimitedFiles = "1.6, 1.7, 1.8, 1.9"
Expand Down Expand Up @@ -55,11 +67,6 @@ ThreadsX = "0.1"
TikzGraphs = "1.3, 1.4"
julia = "1.6, 1.7, 1.8, 1.9, 1.10"

[extensions]
GraphMakieExt = "GraphMakie"
GraphRecipesExt = ["GraphRecipes", "Plots"]
TikzGraphsExt = "TikzGraphs"

[extras]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
Expand All @@ -71,9 +78,3 @@ TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"

[targets]
test = ["Test", "StatsBase", "DelimitedFiles"]

[weakdeps]
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"
3 changes: 3 additions & 0 deletions src/CausalInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ using Base.Iterators
using Memoization, LRUCache
using ThreadsX
using LinkedLists
using DataStructures

import Base: ==, show

export multisampler
export exactscorebased
export ancestors, descendants, alt_test_dsep, test_covariate_adjustment, alt_test_backdoor, find_dsep, find_min_dsep, find_covariate_adjustment, find_backdoor_adjustment, find_frontdoor_adjustment, find_min_covariate_adjustment, find_min_backdoor_adjustment, find_min_frontdoor_adjustment, list_dseps, list_covariate_adjustment, list_backdoor_adjustment, list_frontdoor_adjustment
export dsep, skeleton, gausscitest, dseporacle, partialcor
Expand Down Expand Up @@ -56,6 +58,7 @@ include("dag_sampler.jl")
include("misc2.jl")
include("exact.jl")
#include("mcs.jl")
include("multisampler.jl")

# Compatibility with the new "Package Extensions" (https://github.com/JuliaLang/julia/pull/47695)
const EXTENSIONS_SUPPORTED = isdefined(Base, :get_extension)
Expand Down
181 changes: 181 additions & 0 deletions src/multisampler.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
struct Sample
g::DiGraph
τ::Float64
dir::Int8
total::Int32
scoreval::Float64
alive::Bool
end
Sample(g, nextτ, dir, total, scoreval) = Sample(g, nextτ, dir, total, scoreval, true)

struct Action
i::Int
τ::Float64
apply!::Function
args::Tuple{Vararg{Any}}
end

function expcoldness(τ, k=0.0005)
return exp(k*τ)
end

function Dexpcoldness(τ, k=0.0005)
return k*exp(k*τ)
end

function init(_, _, nextτ, g, dir, total, scoreval)
return Sample(g, nextτ, dir, total, scoreval)
end

function applyup(samplers, i, nextτ, x, y, T, Δscoreval)
prevsample = samplers[i]
g = next_CPDAG(prevsample.g, :up, x, y, T)
return samplers[i] = Sample(g, nextτ, prevsample.dir, prevsample.total+1, prevsample.scoreval + Δscoreval)
end

function applydown(samplers, i, nextτ, x, y, H, Δscoreval)
prevsample = samplers[i]
g = next_CPDAG(prevsample.g, :down, x, y, H)
return samplers[i] = Sample(g, nextτ, prevsample.dir, prevsample.total-1, prevsample.scoreval + Δscoreval)
end

function applyflip(samplers, i, nextτ)
prevsample = samplers[i]
return samplers[i] = Sample(prevsample.g, nextτ, -1*prevsample.dir, prevsample.total, prevsample.scoreval)
end

function applycopy(samplers, i, nextτ, j)
copysample = samplers[j]
s = (i == j) ? 1 : -1 # move opposite direction
return samplers[i] = Sample(copysample.g, nextτ, s*copysample.dir, copysample.total, copysample.scoreval)
end

function applykill(samplers, i, nextτ)
prevsample = samplers[i]
return samplers[i] = Sample(prevsample.g, nextτ, prevsample.dir, prevsample.total, prevsample.scoreval, false)
end

function applynothing(samplers, i, nextτ)
@assert false
sample = samplers[i]
return samplers[i] = Sample(sample.g, nextτ, sample.dir, sample.total, sample.scoreval, sample.alive)
end

# for starters without turn move

function sampleaction(samplers, i, M, balance, prior, score, maxscoreval, σ, ρ, κ, coldness, Dcoldness, threshold, keep, force)
# preprocess
prevsample = samplers[i]
prevsample.alive || return Action(i, Inf, applynothing, ())

sup, sdown, Δscorevalup, Δscorevaldown, argsup, argsdown = count_moves_new(prevsample.g, κ, balance, prior, score, coldness(prevsample.τ), prevsample.total)
# propose moves
λdir = prevsample.dir == 1 ? sup : sdown
λupdown = sup + sdown
λflip = max(prevsample.dir*(-sup + sdown), 0.0)
λterm = force*exp(ULogarithmic, 0.0)*Dcoldness(prevsample.τ) * clamp(maxscoreval - prevsample.scoreval, eps(), threshold) # TODO: prior
Δτdir = randexp()/(ρ*λdir)
Δτupdown = randexp()/(σ*λupdown)
Δτflip = randexp()/(ρ*λflip)
Δτterm = randexp()/abs(λterm)
Δτmin, a = findmin((Δτdir, Δτupdown, Δτflip, Δτterm))
A = (:dir, :updown, :flip, :term)[a]
@assert Δτmin >= 0
if :dir == A
if prevsample.dir == 1
return Action(i, prevsample.τ + Δτdir, applyup, (argsup..., Δscorevalup))
else
return Action(i, prevsample.τ + Δτdir, applydown, (argsdown..., Δscorevaldown))
end
end

if :updown == A
λup = sup
if rand() < λup/λupdown
return Action(i, prevsample.τ + Δτupdown, applyup, (argsup..., Δscorevalup))
else
return Action(i, prevsample.τ + Δτupdown, applydown, (argsdown..., Δscorevaldown))
end
end

if :flip == A
return Action(i, prevsample.τ + Δτflip, applyflip, ())
end

if :term == A
if rand() < keep
if keep < 1
j = rand(findall(s.alive for s in samplers))
else
j = rand(1:M)
end
return Action(i, prevsample.τ + Δτterm, applycopy, (j,))
else
return Action(i, prevsample.τ + Δτterm, applykill, ())
end
end

@assert false
end
# remark: chose κ = n-1 as default
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, better

function multisampler(n, G = (DiGraph(n), 0); M = 10, balance = metropolis_balance, prior = (_,_) -> 1.0, score=UniformScore(), σ = 0.0, ρ = 1.0, κ = n - 1, baseline = 0.0, iterations = min(3*n^2, 50000), schedule=(expcoldness, Dexpcoldness), target=1e10, threshold=Inf, keep=1.0, force=1.0) #, verbose = false, save = true)
if κ >= n
κ = n - 1
@warn "Truncate κ to $κ"
end
coldness, Dcoldness = schedule

initscoreval = score_dag(SimpleDiGraph(n), score)
bestgraph = DiGraph(n)
bestscore = initscoreval

# init M samplers
samplers = [Sample(G[1], 0.0, 1, G[2], initscoreval) for _ = 1:M] # pass correct initial score?!
queue = PriorityQueue{Action, Float64}()

for i = 1:M
action = sampleaction(samplers, i, M, balance, prior, score, bestscore, σ, ρ, κ, coldness, Dcoldness, threshold, keep, force)
enqueue!(queue, action, action.τ)
end

# todo: multiply iterations by M to keep passed iteration number indep of M?
iterations *= M

count = 0
particles = M
t = 0.0
β = schedule[1](t)
pr = Progress(iterations)
iter = 1
while iter <= iterations
action = dequeue!(queue)
t = action.τ
β = schedule[1](t)
β > target && break
next!(pr; showvalues = [(:M,particles), (:t, round(t, sigdigits=6)), (:score, bestscore), (:temp, round(β, sigdigits=6))])

count += (action.apply! == applycopy) || (action.apply! == applykill)
if action.apply! == applykill
particles -= 1
end
if action.apply! != applyflip # flips are free
iter += 1
end

nextsample = action.apply!(samplers, action.i, action.τ, action.args...)
particles == 0 && break

if nextsample.alive && nextsample.scoreval > bestscore
bestgraph = nextsample.g
bestscore = nextsample.scoreval
end
action = sampleaction(samplers, action.i, M, balance, prior, score, bestscore, σ, ρ, κ, coldness, Dcoldness, threshold, keep, force)
enqueue!(queue, action, action.τ)
end
finish!(pr)
killratio = count/iterations

@show particles killratio t β

return bestgraph, bestscore, [sample for sample in samplers if sample.alive]
end
123 changes: 123 additions & 0 deletions test/multisampler.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
using Random, CausalInference, StatsBase, Statistics, Test, Graphs, LinearAlgebra
@testset "MultiSampler" begin
Random.seed!(1)

N = 400 # number of data points

# define simple linear model with added noise
x = randn(N)
v = x + randn(N)*0.25
w = x + randn(N)*0.25
z = v + w + randn(N)*0.25
s = z + randn(N)*0.25

df = (x=x, v=v, w=w, z=z, s=s)
iterations = 1_000
penalty = 2.0 # increase to get more edges in truth
n = length(df) # vertices
Random.seed!(101)
C = cor(CausalInference.Tables.matrix(df))
score = GaussianScore(C, N, penalty)
decay = 5.0e-5
schedule = (τ -> 1.0 + τ*decay, τ -> decay) # linear
M = 20
baseline = 0.0
bestgraph, bestscore, samplers = CausalInference.multisampler(n; M, score, baseline, schedule, iterations, keep=0.5, force=0.1)
#posterior = sort(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)

# maximum aposteriori estimate
MAP = [1=>2, 1=>3, 2=>1, 2=>4, 3=>1, 3=>4, 4=>5]
@test bestgraph == digraph(MAP, n)
#cm = sort(countmap(vpairs.(getfield.(samplers, :g))), byvalue=true, rev=true)
#@test first(cm).first == MAP
end #testset

@testset "MultiSampler" begin
Random.seed!(1)

N = 400 # number of data points

# define simple linear model with added noise
x = randn(N)
v = x + randn(N)*0.25
w = x + randn(N)*0.25
z = v + w + randn(N)*0.25
s = z + randn(N)*0.25

df = (x=x, v=v, w=w, z=z, s=s)
iterations = 1_000
penalty = 2.0 # increase to get more edges in truth
n = length(df) # vertices
Random.seed!(101)
C = cor(CausalInference.Tables.matrix(df))
score = GaussianScore(C, N, penalty)
decay = 3e-4

schedule = (τ -> 1.0 + τ*decay, τ -> decay) # linear
M = 20
baseline = 0.0
balance = CausalInference.sqrt_balance
threshold = Inf
bestgraph, bestscore, samplers = multisampler(n; M, ρ = 1.0, score, balance, baseline, schedule, iterations, threshold)
#posterior = sort(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)

# maximum aposteriori estimate
MAP = [1=>2, 1=>3, 2=>1, 2=>4, 3=>1, 3=>4, 4=>5]
@test bestgraph == digraph(MAP, n)
cm = sort(countmap(vpairs.(getfield.(samplers, :g))), byvalue=true, rev=true)
Tmin, T = extrema(getfield.(samplers, :τ))
@show Tmin T schedule[1](T)
@test first(cm).first == MAP
end

@testset "MultiSampler" begin
Random.seed!(1)
decay = 2e-5
schedule = (τ -> 0.8 + τ*decay, τ -> decay) # linear

N = 200 # number of data points

# define simple linear model with added noise
x = randn(N)
v = x + randn(N)*0.25
w = x + randn(N)*0.25
z = v + w + randn(N)*0.25
s = z + randn(N)*0.25

df = (x=x, v=v, w=w, z=z, s=s)
iterations = 880
penalty = 2.0 # increase to get more edges in truth
n = length(df) # vertices
Random.seed!(101)
C = cor(CausalInference.Tables.matrix(df))
score = GaussianScore(C, N, penalty)
M = 100
bestgraph, bestscore, samplers = multisampler(n; M, score, schedule, iterations, target=1.2)
Tmin, T = extrema(getfield.(samplers, :τ))
coldness = schedule[1](T)
@show Tmin T coldness

gs = causalzigzag(n; score, κ=n-1, ρ=10.0, coldness, iterations=iterations*100)
graphs, graph_pairs, hs, τs, ws, ts, scores = CausalInference.unzipgs(gs)
posterior = sort(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)


# maximum aposteriori estimate
MAP = [1=>2, 1=>3, 2=>1, 2=>4, 3=>1, 3=>4, 4=>5]
@test bestgraph == digraph(MAP, n)
cm = sort((proportionmap(vpairs.(getfield.(samplers, :g)))), byvalue=true, rev=true)
@test first(cm).first == MAP
logΠ = map(g->score_dag(pdag2dag!(digraph(g, n)), score), collect(keys(cm)))
Π = normalize(exp.(coldness*(logΠ .- maximum(logΠ))), 1)
Πhat = normalize(collect(values(cm)), 1)

display([Π Πhat])
s = 0.0
for (i, k) in enumerate(keys(cm))
s += get(posterior, k, 0.0)
#@show cm[k] Π[i]
end
@show s
@test s > 0.99
@test norm(collect(values(cm)) - Π) < 0.02
end #testset
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ include("witness.jl")
include("fci.jl")
include("klentropy.jl")
include("backdoor.jl")
include("multisampler.jl")
Loading