Skip to content

Commit

Permalink
Merge pull request #204 from JuliaSymbolics/ale/3.0-finalize-static-m…
Browse files Browse the repository at this point in the history
…atchers

Compiled Pattern Matching
  • Loading branch information
0x0f0f0f authored May 8, 2024
2 parents f1b020d + 414bcbc commit 1e2ae88
Show file tree
Hide file tree
Showing 15 changed files with 413 additions and 194 deletions.
14 changes: 7 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,18 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

[weakdeps]
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"

[extensions]
Plotting = ["GraphViz"]

[compat]
AutoHashEquals = "2.1.0"
DocStringExtensions = "0.8, 0.9"
Reexport = "0.2, 1"
TimerOutputs = "0.5"
TermInterface = "0.4.1"
TimerOutputs = "0.5"
julia = "1.9"

[extras]
Expand All @@ -26,9 +32,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Documenter", "SafeTestsets", "Literate"]

[weakdeps]
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"

[extensions]
Plotting = ["GraphViz"]
4 changes: 3 additions & 1 deletion examples/propositional_logic_theory.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# # Rewriting

using Metatheory, TermInterface

fold = @theory p q begin
(p::Bool == q::Bool) => (p == q)
(p::Bool || q::Bool) => (p || q)
Expand Down Expand Up @@ -74,7 +76,7 @@ function prove(
params.goal = (g::EGraph) -> in_same_class(g, ids...)
saturate!(g, t, params)
ex = extract!(g, astsize)
if !Metatheory.isexpr(ex)
if !TermInterface.isexpr(ex)
return ex
end
if hash(ex) hist
Expand Down
2 changes: 1 addition & 1 deletion src/EGraphs/EGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using Metatheory.Patterns
using Metatheory.Rules
using Metatheory.VecExprModule

using Metatheory: alwaystrue, cleanast, UNDEF_ID_VEC, should_quote_operation, OptBuffer
using Metatheory: alwaystrue, cleanast, UNDEF_ID_VEC, maybe_quote_operation, OptBuffer

import Metatheory: to_expr

Expand Down
2 changes: 1 addition & 1 deletion src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ function addexpr!(g::EGraph, se)::Id
v_set_head!(n, add_constant!(g, h))

# get the signature from op and arity
v_set_signature!(n, hash(should_quote_operation(h) ? nameof(h) : h, hash(ar)))
v_set_signature!(n, hash(maybe_quote_operation(h), hash(ar)))

for i in v_children_range(n)
@inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH])
Expand Down
6 changes: 3 additions & 3 deletions src/EGraphs/saturation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@ function eqsat_search!(

if rule isa BidirRule
for i in ids_left
n_matches += rule.ematcher_new_left!(g, rule_idx, i, rule.ematcher_stack, ematch_buffer)
n_matches += rule.ematcher_new_left!(g, rule_idx, i, rule.stack, ematch_buffer)
end
for i in ids_right
n_matches += rule.ematcher_new_right!(g, rule_idx, i, rule.ematcher_stack, ematch_buffer)
n_matches += rule.ematcher_new_right!(g, rule_idx, i, rule.stack, ematch_buffer)
end
else
for i in ids_left
n_matches += rule.ematcher!(g, rule_idx, i, rule.ematcher_stack, ematch_buffer)
n_matches += rule.ematcher!(g, rule_idx, i, rule.stack, ematch_buffer)
end
end
n_matches - prev_matches > 0 && @debug "Rule $rule_idx: $rule produced $(n_matches - prev_matches) matches"
Expand Down
12 changes: 7 additions & 5 deletions src/Metatheory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ using Reexport
function to_expr end

# TODO: document
function should_quote_operation end
should_quote_operation(::Function) = true
should_quote_operation(x) = false
Base.@inline maybe_quote_operation(x::Union{Function,DataType}) = nameof(x)
Base.@inline maybe_quote_operation(x) = x

include("docstrings.jl")

Expand All @@ -22,8 +21,7 @@ export OptBuffer

const UNDEF_ID_VEC = Vector{Id}(undef, 0)

using TermInterface
using TermInterface: isexpr
@reexport using TermInterface

"""
@matchable struct Foo fields... end [HeadType]
Expand Down Expand Up @@ -64,6 +62,10 @@ export @timer
include("Patterns.jl")
@reexport using .Patterns

include("match_compiler.jl")
export match_compile


include("ematch_compiler.jl")
export ematch_compile

Expand Down
6 changes: 4 additions & 2 deletions src/Patterns.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module Patterns

using Metatheory: cleanast, alwaystrue, should_quote_operation
using Metatheory: cleanast, alwaystrue, maybe_quote_operation
using AutoHashEquals
using TermInterface
using Metatheory.VecExprModule
Expand Down Expand Up @@ -92,7 +92,9 @@ struct PatExpr <: AbstractPat
n::VecExpr
function PatExpr(iscall, op, args::Vector)
op_hash = hash(op)
qop, qop_hash = should_quote_operation(op) ? (nameof(op), hash(nameof(op))) : (op, op_hash)
# Should call `nameof` on op if Function or DataType. Identity otherwise
qop = maybe_quote_operation(op)
qop_hash = hash(qop)
ar = length(args)
signature = hash(qop, hash(ar))

Expand Down
31 changes: 12 additions & 19 deletions src/Rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,26 +53,25 @@ variables.
matcher
patvars::Vector{Symbol}
ematcher!
ematcher_stack::OptBuffer{UInt16}
stack::OptBuffer{UInt16}
end

function RewriteRule(l, r, ematcher!)
function RewriteRule(l, r, matcher!, ematcher!)
pvars = patvars(l) patvars(r)
# sort!(pvars)
setdebrujin!(l, pvars)
setdebrujin!(r, pvars)
RewriteRule(l, r, matcher(l), pvars, ematcher!, OptBuffer{UInt16}(STACK_SIZE))
RewriteRule(l, r, matcher!, pvars, ematcher!, OptBuffer{UInt16}(STACK_SIZE))
end

Base.show(io::IO, r::RewriteRule) = print(io, :($(r.left) --> $(r.right)))


function (r::RewriteRule)(term)
# n == 1 means that exactly one term of the input (term,) was matched
success(bindings, n) = n == 1 ? instantiate(term, r.right, bindings) : nothing

success(pvars...) = instantiate(term, r.right, pvars)
try
r.matcher(success, (term,), EMPTY_DICT)
r.matcher(term, success, r.stack)
catch err
rethrow(err)
throw(RuleRewriteError(r, term, err))
Expand All @@ -98,7 +97,7 @@ with the EGraphs backend.
patvars::Vector{Symbol}
ematcher_new_left!
ematcher_new_right!
ematcher_stack::OptBuffer{UInt16}
stack::OptBuffer{UInt16}
end

function EqualityRule(l, r, ematcher_new_left!, ematcher_new_right!)
Expand Down Expand Up @@ -136,7 +135,7 @@ backend. If two terms, corresponding to the left and right hand side of an
patvars::Vector{Symbol}
ematcher_new_left!
ematcher_new_right!
ematcher_stack::OptBuffer{UInt16}
stack::OptBuffer{UInt16}
end

function UnequalRule(l, r, ematcher_new_left!, ematcher_new_right!)
Expand Down Expand Up @@ -177,30 +176,24 @@ Dynamic rule
matcher
patvars::Vector{Symbol} # useful set of pattern variables
ematcher!
ematcher_stack::OptBuffer{UInt16}
stack::OptBuffer{UInt16}
end

function DynamicRule(l, r::Function, ematcher!, rhs_code = nothing)
function DynamicRule(l, r::Function, matcher, ematcher!, rhs_code = nothing)
pvars = patvars(l)
setdebrujin!(l, pvars)
isnothing(rhs_code) && (rhs_code = repr(rhs_code))

DynamicRule(l, r, rhs_code, matcher(l), pvars, ematcher!, OptBuffer{UInt16}(512))
DynamicRule(l, r, rhs_code, matcher, pvars, ematcher!, OptBuffer{UInt16}(512))
end


Base.show(io::IO, r::DynamicRule) = print(io, :($(r.left) => $(r.rhs_code)))

function (r::DynamicRule)(term)
# n == 1 means that exactly one term of the input (term,) was matched
success(bindings, n) =
if n == 1
bvals = [bindings[i] for i in 1:length(r.patvars)]
return r.rhs_fun(term, nothing, bvals...)
end

success(bindings...) = r.rhs_fun(term, nothing, bindings...)
try
return r.matcher(success, (term,), EMPTY_DICT)
return r.matcher(term, success, r.stack)
catch err
throw(RuleRewriteError(r, term, err))
end
Expand Down
31 changes: 21 additions & 10 deletions src/Syntax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Metatheory.Patterns
using Metatheory.Rules
using TermInterface

using Metatheory: alwaystrue, cleanast, ematch_compile
using Metatheory: alwaystrue, cleanast, ematch_compile, match_compile

export @rule
export @theory
Expand Down Expand Up @@ -373,14 +373,17 @@ macro rule(args...)
end
ematcher_left_expr = esc(ematch_compile(lhs, ppvars, 1))

matcher_left_expr = match_compile(lhs, pvars)


if RuleType == DynamicRule
rhs_rewritten = rewrite_rhs(r)
rhs_consequent = makeconsequent(rhs_rewritten)
params = Expr(:tuple, :_lhs_expr, :_egraph, pvars...)
rhs = :($(esc(params)) -> $(esc(rhs_consequent)))
return quote
$(__source__)
DynamicRule($lhs, $rhs, $ematcher_left_expr, $(QuoteNode(rhs_consequent)))
DynamicRule($lhs, $rhs, $matcher_left_expr, $ematcher_left_expr, $(QuoteNode(rhs_consequent)))
end
end

Expand All @@ -393,7 +396,7 @@ macro rule(args...)

quote
$(__source__)
($RuleType)($lhs, $rhs, $ematcher_left_expr)
($RuleType)($lhs, $rhs, $matcher_left_expr, $ematcher_left_expr)
end
end

Expand Down Expand Up @@ -470,16 +473,24 @@ macro capture(args...)

pvars = Symbol[]
lhs = makepattern(lhs, pvars, slots, __module__)
bind = Expr(
:block,
map(key -> :($(esc(key)) = getindex(__MATCHES__, findfirst((==)($(QuoteNode(key))), $pvars))), pvars)...,
)
bind_exprs = Expr[]

for key in pvars
idx = findfirst((==)(key), pvars)
push!(bind_exprs, :($(esc(key)) = __MATCHES__[$idx]))
end

setdebrujin!(lhs, pvars)

matcher_left_expr = match_compile(lhs, pvars)


ret = quote
$(__source__)
rule = DynamicRule($lhs, (_lhs_expr, _egraph, pvars...) -> pvars, (x...) -> nothing)
rule = DynamicRule($lhs, (_lhs_expr, _egraph, pvars...) -> pvars, $matcher_left_expr, nothing)
__MATCHES__ = rule($(esc(ex)))
if __MATCHES__ !== nothing
$bind
if !isnothing(__MATCHES__)
$(bind_exprs...)
true
else
false
Expand Down
Loading

0 comments on commit 1e2ae88

Please sign in to comment.