Skip to content

Commit

Permalink
Merge pull request #133 from mschauer/incscore
Browse files Browse the repository at this point in the history
Incrementally compute score + test
  • Loading branch information
mschauer authored Jan 11, 2024
2 parents 7459a96 + 1ee2e30 commit dbf7643
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 19 deletions.
49 changes: 32 additions & 17 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ ndown(g, total) = ne(g)
Return
"""
function count_moves(g, κ, balance, prior, score, coldness, total, dir=:both)
s1 = s2 = 0.0
s1 = s2 = Δscorevalue1 = Δscorevalue2 = 0.0
x1 = y1 = x2 = y2 = 0
T1 = Int[]
H2 = Int[]
Expand All @@ -85,14 +85,16 @@ function count_moves(g, κ, balance, prior, score, coldness, total, dir=:both)
valid = (isclique(g, NAyxT) && isblocked(g, y, x, NAyxT))
if valid
PAy = parents(g, y)
s = balance(prior(total, total+1)*exp(coldness*Δscoreinsert(score, NAyxT PAy, x, y, T)))
Δscorevalue = Δscoreinsert(score, NAyxT PAy, x, y, T)
s = balance(prior(total, total+1)*exp(coldness*Δscorevalue))
else
s = 0.0
end
@assert s >= 0
if valid && rand() > s1/(s1 + s) # sequentially draw sample
x1, y1 = x, y
T1 = T
Δscorevalue1 = Δscorevalue
end
s1 = s1 + s
end
Expand All @@ -108,25 +110,27 @@ function count_moves(g, κ, balance, prior, score, coldness, total, dir=:both)
if valid
PAy = parents(g, y)
PAy⁻ = setdiff(PAy, x)
s = balance(prior(total, total-1)*exp(coldness*Δscoredelete(score, NAyx_H PAy⁻, x, y, H)))
Δscorevalue = Δscoredelete(score, NAyx_H PAy⁻, x, y, H)
s = balance(prior(total, total-1)*exp(coldness*Δscorevalue))
else
s = 0.0
end
@assert s >= 0
if valid && rand() > s2/(s2 + s)
x2, y2 = x, y
H2 = H
Δscorevalue2 = Δscorevalue
end
s2 = s2 + s
end
end
end
end
s1, s2, (x1, y1, T1), (x2, y2, H2)
s1, s2, Δscorevalue1, Δscorevalue2, (x1, y1, T1), (x2, y2, H2)
end

function count_moves_new(g, κ, balance, prior, score, coldness, total, dir=:both)
s1 = s2 = 0.0
s1 = s2 = Δscorevalue1 = Δscorevalue2 = 0.0
x1 = y1 = x2 = y2 = 0
T1 = Int[]
H2 = Int[]
Expand All @@ -147,10 +151,12 @@ function count_moves_new(g, κ, balance, prior, score, coldness, total, dir=:bot
# to hide complexity
# or just Δscoreinsert(score, g, op)
# and op contains all necessary stuff e.g. NAyxT and so on
s = balance(prior(total, total+1)*exp(coldness*Δscoreinsert(score, NAyxT PAy, x, y, T)))
Δscorevalue = Δscoreinsert(score, NAyxT PAy, x, y, T)
s = balance(prior(total, total+1)*exp(coldness*Δscorevalue))
if rand() > s1/(s1 + s) # sequentially draw sample
x1, y1 = x, y
T1 = T
Δscorevalue1 = Δscorevalue
end
s1 = s1 + s
end
Expand All @@ -161,17 +167,19 @@ function count_moves_new(g, κ, balance, prior, score, coldness, total, dir=:bot
PAy⁻ = setdiff(PAy, x)
# I would prefer Δscoredelete(score, g, x, y, H) as above
NAyx_H = setdiff(adj_neighbors(g, x, y), H)
s = balance(prior(total, total-1)*exp(coldness*Δscoredelete(score, NAyx_H PAy⁻, x, y, H)))
Δscorevalue = Δscoredelete(score, NAyx_H PAy⁻, x, y, H)
s = balance(prior(total, total-1)*exp(coldness*Δscorevalue))
if rand() > s2/(s2 + s)
x2, y2 = x, y
H2 = H
Δscorevalue2 = Δscorevalue
end
s2 = s2 + s
end
end
end
end
s1, s2, (x1, y1, T1), (x2, y2, H2)
s1, s2, Δscorevalue1, Δscorevalue2, (x1, y1, T1), (x2, y2, H2)
end

"""
Expand All @@ -182,6 +190,8 @@ end
Run the causal zigzag algorithm starting in a cpdag `(G, t)` with `t` oriented or unoriented edges,
the balance function `balance ∈ {metropolis_balance, barker_balance, sqrt}`, `score` function (see `ges` algorithm)
coldness parameter for iterations. `σ = 1.0, ρ = 0.0` gives purely diffusive behaviour, `σ = 0.0, ρ = 1.0` gives Zig-Zag behaviour.
Returns a vector of tuples with information, each containing a graph, spent time, current direction, number of edges and the score.
"""
function causalzigzag(n, G = (DiGraph(n), 0); balance = metropolis_balance, prior = (_,_)->1.0, score=UniformScore(),
coldness = 1.0, σ = 0.0, ρ = 1.0, naive=false,
Expand All @@ -191,13 +201,14 @@ function causalzigzag(n, G = (DiGraph(n), 0); balance = metropolis_balance, prio
κ = n - 1
@warn "Truncate κ to "
end
gs = Vector{Tuple{typeof(g),Float64,Int,Int}}()
gs = Vector{Tuple{typeof(g),Float64,Int,Int,Float64}}()
dir = 1
global traversals = 0
global tempty = 0.0
τ = 0.0
secs = 0.0
emax = n*κ÷2
scorevalue = 0.0
@showprogress for iter in 1:iterations
τ = 0.0
total_old = total
Expand All @@ -208,17 +219,18 @@ function causalzigzag(n, G = (DiGraph(n), 0); balance = metropolis_balance, prio
traversals += 1
end

Δscorevalue1 = Δscorevalue2 = 0.0
if !naive
if score isa UniformScore
s1, s2, up1, down1 = count_moves_uniform(g, κ)
total < emax && (s1 *= balance(prior(total, total+1)))
total > 0 && (s2 *= balance(prior(total, total-1)))

else
s1, s2, up1, down1 = count_moves_new(g, κ, balance, prior, score, coldness, total)
s1, s2, Δscorevalue1, Δscorevalue2, up1, down1 = count_moves_new(g, κ, balance, prior, score, coldness, total)
end
else
s1, s2, up1, down1 = count_moves(g, κ, balance, prior, score, coldness, total)
s1, s2, Δscorevalue1, Δscorevalue2, up1, down1 = count_moves(g, κ, balance, prior, score, coldness, total)
end
λbar = max(dir*(-s1 + s2), 0.0)
λrw = (s1 + s2)
Expand Down Expand Up @@ -247,7 +259,8 @@ function causalzigzag(n, G = (DiGraph(n), 0); balance = metropolis_balance, prio
x, y, T = up1
@assert x != y
total == 0 && (tempty += τ)
save && push!(gs, (g, τ, dir, total))
save && push!(gs, (g, τ, dir, total, scorevalue))
scorevalue += Δscorevalue1
total += 1
secs += @elapsed begin
if !naive
Expand All @@ -264,7 +277,8 @@ function causalzigzag(n, G = (DiGraph(n), 0); balance = metropolis_balance, prio
x, y, H = down1
@assert x != y
total == 0 && (tempty += τ)
save && push!(gs, (g, τ, dir, total))
save && push!(gs, (g, τ, dir, total, scorevalue))
scorevalue += Δscorevalue2
total -= 1
secs += @elapsed begin
if !naive
Expand All @@ -282,7 +296,7 @@ function causalzigzag(n, G = (DiGraph(n), 0); balance = metropolis_balance, prio
x = y = 0
dir *= -1
total == 0 && (tempty += τ)
save && push!(gs, (g, τ, dir, total))
save && push!(gs, (g, τ, dir, total, scorevalue))
break
end # break
verbose && println(total_old, dir_old == 1 ? "" : "", total, " $x => $y ", round(τ, digits=8))
Expand All @@ -299,11 +313,12 @@ end
function unzipgs(gs)
graphs = first.(gs)
graph_pairs = vpairs.(graphs)
hs = map(last, gs)
scs = map(last, gs)
hs = map(x->getindex(x, 4), gs)
τs = map(x->getindex(x, 2), gs)
ws = normalize(τs, 1)
ts = cumsum(ws)
(;graphs, graph_pairs, hs, τs, ws, ts)
(;graphs, graph_pairs, hs, τs, ws, ts, scs)
end

const randcpdag = causalzigzag
const randcpdag = causalzigzag
4 changes: 2 additions & 2 deletions test/gesvsR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ using Random
@test s sb
#g2c, sc, (t1c, t2c) = ges(X; penalty, parallel=true)
@test score_R score_dag(DiGraph(d), GaussianScore(C, n, penalty)) + s
@show score_R score_dag(pdag2dag!(copy(g2)), GaussianScore(C, n, penalty))
@show score_R score_dag(pdag2dag!(copy(g3)), GaussianScore(C, n, penalty))
@test score_R score_dag(pdag2dag!(copy(g2)), GaussianScore(C, n, penalty))
@test score_R score_dag(pdag2dag!(copy(g3)), GaussianScore(C, n, penalty))

@test isempty(symdiff(vpairs(g2), vpairs(g2b)))

Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ include("exact.jl")
include("operators.jl")
include("ges.jl")
include("gesvsR.jl")
include("sampler.jl")
include("gensearch.jl")
include("cpdag.jl")
include("skeleton.jl")
Expand Down
31 changes: 31 additions & 0 deletions test/sampler.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using Random, CausalInference, Statistics, Test, Graphs
@testset "Zig-Zag" begin
Random.seed!(1)

N = 2000 # number of data points

# define simple linear model with added noise
x = randn(N)
v = x + randn(N)*0.25
w = x + randn(N)*0.25
z = v + w + randn(N)*0.25
s = z + randn(N)*0.25

df = (x=x, v=v, w=w, z=z, s=s)
iterations = 5_000
n = length(df) # vertices
κ = n - 1 # max degree
penalty = 2.0 # increase to get more edges in truth
Random.seed!(101)
C = cor(CausalInference.Tables.matrix(df))
score = GaussianScore(C, N, penalty)
gs = @time causalzigzag(n; score, κ, iterations)
graphs, graph_pairs, hs, τs, ws, ts, scores = CausalInference.unzipgs(gs)
posterior = sort(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)

# maximum aposteriori estimate
@test first(posterior).first == [1=>2, 1=>3, 2=>1, 2=>4, 3=>1, 3=>4, 4=>5]
# score of last sample
@test score_dag(pdag2dag!(copy(graphs[end])), score) scores[end] + score_dag(DiGraph(n), score)

end #testset

0 comments on commit dbf7643

Please sign in to comment.