Skip to content

Commit

Permalink
Merge pull request #239 from gkronber/fix_enode_memo_2
Browse files Browse the repository at this point in the history
Fix hashing and memoization of enodes (VecExpr)
  • Loading branch information
0x0f0f0f authored Sep 3, 2024
2 parents 1dc53da + 66ea780 commit 0b2e6c9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 54 deletions.
79 changes: 27 additions & 52 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,22 +224,6 @@ Returns the canonical e-class id for a given e-class.

@inline Base.getindex(g::EGraph, i::Id) = g.classes[IdKey(find(g, i))]

# function canonicalize(g::EGraph, n::VecExpr)::VecExpr
# if !v_isexpr(n)
# v_hash!(n)
# return n
# end
# l = v_arity(n)
# new_n = v_new(l)
# v_set_flag!(new_n, v_flags(n))
# v_set_head!(new_n, v_head(n))
# for i in v_children_range(n)
# @inbounds new_n[i] = find(g, n[i])
# end
# v_hash!(new_n)
# new_n
# end

function canonicalize!(g::EGraph, n::VecExpr)
v_isexpr(n) || @goto ret
for i in (VECEXPR_META_LENGTH + 1):length(n)
Expand All @@ -253,19 +237,16 @@ end

function lookup(g::EGraph, n::VecExpr)::Id
canonicalize!(g, n)
h = IdKey(v_hash(n))

haskey(g.memo, n) ? find(g, g.memo[n]) : 0
id = get(g.memo, n, zero(Id))
iszero(id) ? id : find(g, id)
end


function add_class_by_op(g::EGraph, n, eclass_id)
key = IdKey(v_signature(n))
if haskey(g.classes_by_op, key)
push!(g.classes_by_op[key], eclass_id)
else
g.classes_by_op[key] = [eclass_id]
end
vec = get!(g.classes_by_op, key, Vector{Id}())
push!(vec, eclass_id)
end

"""
Expand All @@ -274,7 +255,8 @@ Inserts an e-node in an [`EGraph`](@ref)
function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)::Id where {ExpressionType,Analysis}
canonicalize!(g, n)

haskey(g.memo, n) && return g.memo[n]
id = get(g.memo, n, zero(Id))
iszero(id) || return id

if should_copy
n = copy(n)
Expand All @@ -291,7 +273,7 @@ function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)
g.memo[n] = id

add_class_by_op(g, n, id)
eclass = EClass{Analysis}(id, VecExpr[n], Pair{VecExpr,Id}[], make(g, n))
eclass = EClass{Analysis}(id, VecExpr[copy(n)], Pair{VecExpr,Id}[], make(g, n))
g.classes[IdKey(id)] = eclass
modify!(g, eclass)
push!(g.pending, n => id)
Expand Down Expand Up @@ -320,28 +302,22 @@ function addexpr!(g::EGraph, se)::Id
se isa EClass && return se.id
e = preprocess(se)

n = if isexpr(e)
args = iscall(e) ? arguments(e) : children(e)
ar = length(args)
n = v_new(ar)
v_set_flag!(n, VECEXPR_FLAG_ISTREE)
iscall(e) && v_set_flag!(n, VECEXPR_FLAG_ISCALL)

h = iscall(e) ? operation(e) : head(e)
v_set_head!(n, add_constant!(g, h))

# get the signature from op and arity
v_set_signature!(n, hash(maybe_quote_operation(h), hash(ar)))

for i in v_children_range(n)
@inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH])
end
n
else # constant enode
VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)])
isexpr(e) || return add!(g, VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)]), false)

args = iscall(e) ? arguments(e) : children(e)
ar = length(args)
n = v_new(ar)
v_set_flag!(n, VECEXPR_FLAG_ISTREE)
iscall(e) && v_set_flag!(n, VECEXPR_FLAG_ISCALL)
h = iscall(e) ? operation(e) : head(e)
v_set_head!(n, add_constant!(g, h))
# get the signature from op and arity
v_set_signature!(n, hash(maybe_quote_operation(h), hash(ar)))
for i in v_children_range(n)
@inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH])
end
id = add!(g, n, false)
return id

add!(g, n, false)
end

"""
Expand Down Expand Up @@ -431,10 +407,10 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp
while !isempty(g.pending) || !isempty(g.analysis_pending)
while !isempty(g.pending)
(node::VecExpr, eclass_id::Id) = pop!(g.pending)
node = copy(node)
canonicalize!(g, node)
if haskey(g.memo, node)
old_class_id = g.memo[node]
g.memo[node] = eclass_id
old_class_id = get!(g.memo, node, eclass_id)
if old_class_id != eclass_id
did_something = union!(g, old_class_id, eclass_id)
# TODO unique! can node dedup be moved here? compare performance
# did_something && unique!(g[eclass_id].nodes)
Expand Down Expand Up @@ -474,9 +450,8 @@ function check_memo(g::EGraph)::Bool
for (id, class) in g.classes
@assert id.val == class.id
for node in class.nodes
if haskey(test_memo, node)
old_id = test_memo[node]
test_memo[node] = id.val
old_id = get!(test_memo, node, id.val)
if old_id != id.val
@assert find(g, old_id) == find(g, id.val) "Unexpected equivalence $node $(g[find(g, id.val)].nodes) $(g[find(g, old_id)].nodes)"
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/EGraphs/uniquequeue.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ function Base.pop!(uq::UniqueQueue{T}) where {T}
v
end

Base.isempty(uq::UniqueQueue) = isempty(uq.vec)
Base.isempty(uq::UniqueQueue) = isempty(uq.vec)
2 changes: 1 addition & 1 deletion src/vecexpr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ end

"""The hash of the e-node."""
@inline v_hash(n::VecExpr)::Id = @inbounds n.data[1]
Base.hash(n::VecExpr) = v_hash(n) # IdKey not necessary here
Base.hash(n::VecExpr, h::UInt) = hash(v_hash(n), h) # IdKey not necessary here
Base.:(==)(a::VecExpr, b::VecExpr) = (@view a.data[2:end]) == (@view b.data[2:end])

"""Set e-node hash to zero."""
Expand Down

0 comments on commit 0b2e6c9

Please sign in to comment.