From 84485577a0ed8e3f7b1461df6d5f26d7b3cfbbf9 Mon Sep 17 00:00:00 2001 From: Tim Whiting Date: Tue, 18 Jun 2024 13:53:02 -0600 Subject: [PATCH] match merging optimization --- out.kk | 183 +++++++++++++++++++++++++++++++++++++++ src/Compile/TypeCheck.hs | 8 +- src/Core/MatchMerge.hs | 115 +++++++++++++++--------- src/Core/Pretty.hs | 4 +- test/cgen/match-merge.kk | 14 ++- 5 files changed, 276 insertions(+), 48 deletions(-) create mode 100644 out.kk diff --git a/out.kk b/out.kk new file mode 100644 index 000000000..c5cef5567 --- /dev/null +++ b/out.kk @@ -0,0 +1,183 @@ +load : std/core +load : std/core/types +load : std/core/undiv +load : std/core/order +load : std/core/exn +load : std/core/unsafe +load : std/core/bool +load : std/core/hnd +load : std/core/int +load : std/core/char +load : std/core/string +load : std/core/sslice +load : std/core/vector +load : std/core/list +load : std/core/tuple +load : std/core/maybe +load : std/core/either +load : std/core/show +load : std/core/delayed +load : std/core/debug +load : std/core/console +parse : .../test/cgen/match-merge.kk +check : test/cgen/match-merge +module test/cgen/match-merge +import std/core/types = std/core/types pub = ""; +import std/core/hnd = std/core/hnd pub = ""; +import std/core/exn = std/core/exn pub = ""; +import std/core/bool = std/core/bool pub = ""; +import std/core/order = std/core/order pub = ""; +import std/core/char = std/core/char pub = ""; +import std/core/int = std/core/int pub = ""; +import std/core/vector = std/core/vector pub = ""; +import std/core/string = std/core/string pub = ""; +import std/core/sslice = std/core/sslice pub = ""; +import std/core/list = std/core/list pub = ""; +import std/core/maybe = std/core/maybe pub = ""; +import std/core/either = std/core/either pub = ""; +import std/core/tuple = std/core/tuple pub = ""; +import std/core/show = std/core/show pub = ""; +import std/core/debug = std/core/debug pub = ""; +import std/core/delayed = std/core/delayed pub = ""; +import std/core/console = std/core/console pub = ""; +import std/core = std/core = ""; +pub fun implicit-error-fallthrough : (e : list) -> () + = fn<>(e: list){ + match (e) { + (std/core/types/Cons(((@skip std/core/types/@Box((a: int)) : @Box ) as @box-x0: @Box), ((std/core/types/Cons(((@skip std/core/types/@Box((b: int)) : @Box + ) as @box-x1: @Box), (@case306: list)) : list ) as @pat@0: list)) : list ) + -> val _ : () + = (match ((std/core/types/@is-unique(e))) { + (std/core/types/True() : bool ) + -> val _ : () + = val _ : () + = (match ((std/core/types/@is-unique(@pat@0))) { + (std/core/types/True() : bool ) + -> val _ : () + = std/core/types/Unit; + std/core/types/@free(@pat@0); + _ + -> val _ : () + = val _ : list + = std/core/types/@dup(@case306); + val _ : int + = std/core/types/@dup(b); + std/core/types/Unit; + val _ : () + = std/core/types/@dec-ref(@pat@0); + std/core/types/Unit; + }); + std/core/types/Unit; + std/core/types/@free(e); + _ + -> val _ : () + = val _ : list + = std/core/types/@dup(@case306); + val _ : int + = std/core/types/@dup(a); + val _ : int + = std/core/types/@dup(b); + std/core/types/Unit; + val _ : () + = std/core/types/@dec-ref(e); + std/core/types/Unit; + }); + std/core/console/printsln((std/core/int/show((match (b, @case306) { + 1, (std/core/types/Cons(((@skip std/core/types/@Box((c: int)) : @Box ) as @box-x2: @Box), ((std/core/types/Nil() : list ) as @pat@3: list)) + : list ) + -> val _ : () + = std/core/types/@drop(b); + val _ : () + = (match ((std/core/types/@is-unique(@case306))) { + (std/core/types/True() : bool ) + -> val _ : () + = std/core/types/Unit; + std/core/types/@free(@case306); + _ + -> val _ : () + = val _ : int + = std/core/types/@dup(c); + std/core/types/Unit; + val _ : () + = std/core/types/@dec-ref(@case306); + std/core/types/Unit; + }); + std/core/int/int-add(a, c); + _, _ + -> val _ : () + = std/core/types/@drop(a); + val _ : () + = std/core/types/@drop(@case306); + b; + })))); + (std/core/types/Nil() : (list) ) + -> std/core/console/printsln("Nothing"); + _ + -> val _ : () + = std/core/types/@drop(e); + std/core/types/@unbox((std/core/exn/error-pattern("test/cgen/match-merge(2, 3)", "implicit-error-fallthrough"))); + }; + }; +pub fun main : () -> console/console () + = fn(){ + std/core/types/@unbox(val @b-x11@22 : ((m : hnd/marker, hnd/ev, x : exception) -> console/console a) + = (fn(m: hnd/marker, @_wildcard-x638_16: hnd/ev, x: exception){ + val _ : () + = (std/core/types/@drop(@_wildcard-x638_16, (std/core/types/@make-int32(3)))); + val _ : () + = (std/core/types/@drop(x)); + (std/core/hnd/yield-to-final(m, (fn(@b-x9: (hnd/resume-result<3004,3007>) -> 3006 3007){ + val @_wildcard-x638_45@25 : ((hnd/resume-result) -> console/console ()) + = (fn(@b-x10: hnd/resume-result<1004,()>){ + (std/core/types/@unbox((@b-x9(@b-x10)))); + }); + val _ : () + = (std/core/types/@drop(@_wildcard-x638_45@25)); + (std/core/types/@box((std/core/console/printsln("Error")))); + }))); + }); + (std/core/exn/@handle-exn((std/core/exn/@Hnd-exn(0, (std/core/hnd/Clause1((fn(@b-x12: hnd/marker<1018,1019>, @b-x13: hnd/ev<1017>, @b-x14: 1015){ + (@b-x11@22(@b-x12, @b-x13, (std/core/types/@unbox(@b-x14)))); + }))))), (fn(@b-x18: 3002){ + val @x@26 : () + = (std/core/types/@unbox(@b-x18)); + (std/core/types/@box(@x@26)); + }), (fn<>(){ + (std/core/types/@box((std/core/console/printsln((std/core/int/show(2)))))); + })))); + }; +module test/cgen/match-merge/@main +import std/core/types = std/core/types pub = ""; +import std/core/hnd = std/core/hnd pub = ""; +import std/core/exn = std/core/exn pub = ""; +import std/core/bool = std/core/bool pub = ""; +import std/core/order = std/core/order pub = ""; +import std/core/char = std/core/char pub = ""; +import std/core/int = std/core/int pub = ""; +import std/core/vector = std/core/vector pub = ""; +import std/core/string = std/core/string pub = ""; +import std/core/sslice = std/core/sslice pub = ""; +import std/core/list = std/core/list pub = ""; +import std/core/maybe = std/core/maybe pub = ""; +import std/core/either = std/core/either pub = ""; +import std/core/tuple = std/core/tuple pub = ""; +import std/core/show = std/core/show pub = ""; +import std/core/debug = std/core/debug pub = ""; +import std/core/delayed = std/core/delayed pub = ""; +import std/core/console = std/core/console pub = ""; +import std/core = std/core = ""; +import test/cgen/match-merge = test/cgen/match-merge = ""; +// Stateful functions can manipulate heap `:h` using allocations, reads and writes. +local alias st<(h :: H)> :: H -> E = <(read :: H -> X)<(h :: H)>,(write :: H -> X)<(h :: H)>,(alloc :: H -> X)<(h :: H)>> = 1; +pub fun @expr : () -> console/console () + = fn(){ + test/cgen/match-merge/main(); + }; +pub fun @main : () -> ,console/console,div,fsys,ndet,net,ui> () + = fn(){ + test/cgen/match-merge/main(); + }; +linking : test/cgen/match-merge/@main +compile : kklib from: /home/tim/koka/kklib +created : .koka/v3.1.2/gcc-debug-612e08/test_cgen_match_dash_merge__main +2 diff --git a/src/Compile/TypeCheck.hs b/src/Compile/TypeCheck.hs index 162de438e..8afb28559 100644 --- a/src/Compile/TypeCheck.hs +++ b/src/Compile/TypeCheck.hs @@ -26,7 +26,7 @@ import Syntax.RangeMap import Syntax.Syntax import Static.FixityResolve( fixitiesCompose, fixitiesNew, fixityResolve ) import Static.BindingGroups( bindingGroups ) -import Core.Pretty( prettyDef ) +import Core.Pretty( prettyDef, prettyCore ) import Core.CoreVar( extractDepsFromSignatures ) import Core.Check( checkCore ) @@ -117,9 +117,11 @@ typeCheck flags defs coreImports program0 let borrowed = borrowedExtendICore (coreProgram{ Core.coreProgDefs = coreDefs }) (defsBorrowed defs) checkFBIP penv (platform flags) newtypes borrowed gamma matchMergeDefs - -- coreDefs <- Core.getCoreDefs - -- let coreDoc2 = Core.Pretty.prettyCore (prettyEnvFromFlags flags){ coreIface = False, coreShowDef = True } (C CDefault) [] + -- trace "Finished match merging" $ return () + -- coreDefs <- Core.getCoreDefs + -- let coreDoc2 = Core.Pretty.prettyCore (prettyEnvFromFlags flags){ coreIface = False, coreShowDef = True } (C CDefault) [] -- (coreProgram{ Core.coreProgDefs = coreDefs }) + -- trace (show coreDoc2) $ return () -- initial simplify let ndebug = optimize flags > 0 diff --git a/src/Core/MatchMerge.hs b/src/Core/MatchMerge.hs index 1171db7e4..210763ef7 100644 --- a/src/Core/MatchMerge.hs +++ b/src/Core/MatchMerge.hs @@ -32,7 +32,7 @@ trace s x = Lib.Trace.trace s x -matchMergeDefs :: CorePhase () +matchMergeDefs :: CorePhase b () matchMergeDefs = liftCorePhaseUniq $ \uniq defs -> runUnique uniq $ matchMergeDefGroups defs @@ -63,9 +63,8 @@ matchMergeExpr body Case exprs branches -> do (branches', changed) <- mergeBranches branches - -- if changed then trace ("matchMergeExpr:\n" ++ show branches ++ "\nrewrote to: \n" ++ show branches' ++ "\n") + -- if changed then trace ("matchMergeExpr:\n" ++ show branches ++ "\nrewrote to: \n" ++ show branches' ++ "\n") (return ()) else return () return $ Case exprs branches' - -- else return $ Case exprs branches _ -> return body -- Takes a set of branches, and transforms them by merging branches that have some shared superstructure. @@ -79,22 +78,22 @@ mergeBranches branches@(b@(Branch [pat@PatCon{patConPatterns=ps}] _): rst) splitted <- splitBranchConstructors b rst case splitted of -- Single branch, return itself unchanged - ([b], [], err, pat') -> return ([b], False) + ([b], [], err, tns, pat') -> return ([b], False) -- Single branch shares structure, rest do not, merge the rest and append - ([b], rst, err, pat') -> + ([b], rst, err, tns, pat') -> do (rest, v) <- mergeBranches rst - return (b:rest, v) + return (b : rest, v) -- Multiple branches share structure - (bs, rst, err, pat') -> + (bs, rst, err, tns, pat') -> do - trace ("mergeBranches:\n" ++ intercalate "\n" (map (show . branchPatterns) bs) ++ "\n with common superstructure:\n" ++ show pat' ++ "\n\n") $ return () + -- trace ("mergeBranches:\n" ++ " has error? " ++ show err ++ "\n" ++ intercalate "\n" (map (show . branchPatterns) bs) ++ "\n with common superstructure:\n" ++ show pat' ++ "\n\n") $ return () let - vars' = collectPatsVars pat' -- Collect the variables introduced by the shared structure + vars' = tns -- collectPatsVars pat' -- Collect the variables introduced by the shared structure varsMatch = [Var tn InfoNone | tn <- vars'] -- Create expressions for those vars -- Get rid of the common superstructure from the branches that share superstructure -- Also add the implicit error branch if it exists - subBranches = map (stripOuterConstructors pat') bs ++ maybeToList err + subBranches = (map (stripOuterConstructors pat') bs) ++ maybeToList err (newSubBranches, innerV) <- mergeBranches subBranches (rest, v) <- mergeBranches rst -- Merge the branches that do not share structure with the current set -- Replace the set of common branches, with a single branch that matches on the shared superstructure, and delegates @@ -123,15 +122,15 @@ collectVars p -- - a possible (implicit error) branch found -- - and the pattern that unifies the matched branches -- Greedily in order processing, The first branch is the branch under consideration and the others are the next parameter -splitBranchConstructors :: Branch -> [Branch] -> Unique ([Branch], [Branch], Maybe Branch, [Pattern]) +splitBranchConstructors :: Branch -> [Branch] -> Unique ([Branch], [Branch], Maybe Branch, [TName], [Pattern]) splitBranchConstructors b@(Branch ps _) branches = case branches of -- Only one branch, it matches it's own pattern - [] -> return ([b], [], if isErrorBranch b then Just b else Nothing, ps) + [] -> return ([b], [], if isErrorBranch b then Just b else Nothing, [], ps) b'@(Branch ps' _):bs -> do -- First do the rest other than b' - (bs', bs2', e, accP) <- splitBranchConstructors b bs + (bs', bs2', e, restTns, accP) <- splitBranchConstructors b bs -- keep track of error branch to propagate into sub branches let newError = case (e, b') of (Just e, _) -> Just e -- implicit error is in the rest of the branches @@ -139,14 +138,15 @@ splitBranchConstructors b@(Branch ps _) branches = _ -> Nothing -- no error branch -- Acumulated pattern and p' patNew <- zipWithM patternsMatch accP ps' - if not $ isSimpleMatches patNew then + let (newVars, patNews) = unzip patNew + if not $ isSimpleMatches patNews then -- Restrict the pattern to the smallest that matches multiple branches -- Add the new branch to the list of branches that match partially - trace ("splitConstructors:\n" ++ show accP ++ "\nand\n" ++ show ps' ++ "\n have common superstructure:\n" ++ show patNew ++ "\n\n") - $ return (bs' ++ [b'], bs2', newError, patNew) + -- trace ("splitConstructors:\n" ++ show accP ++ "\nand\n" ++ show ps' ++ "\n have common superstructure:\n" ++ show patNew ++ "\n\n") $ + return (bs' ++ [b'], bs2', newError, concat newVars, patNews) -- Didn't match the current branch, keep the old pattern -- Add the new branch to the list of branches that don't match any subpattern - else return (bs', b':bs2', newError, accP) + else return (bs', b':bs2', newError, restTns, accP) isPatWild :: Pattern -> Bool isPatWild PatWild = True @@ -164,33 +164,45 @@ isSimpleMatch p = -- Checks to see if the branch is an error branch isErrorBranch:: Branch -> Bool -isErrorBranch (Branch _ [Guard _ (App (TypeApp (Var name _) _) _)]) = getName name == namePatternMatchError +isErrorBranch (Branch _ [Guard _ e]) = isErrorExpr e isErrorBranch _ = False +isErrorExpr :: Expr -> Bool +isErrorExpr (TypeApp (Var name _) _) = + let nm = getName name in + nm == namePatternMatchError +isErrorExpr (App (App (TypeApp (Var name _) _) [e]) args) = + getName name == nameEffectOpen && isErrorExpr e +isErrorExpr (App e args) = isErrorExpr e +isErrorExpr (TypeApp e args) = isErrorExpr e +isErrorExpr e = + False + generalErrorBranch:: Branch -> Branch generalErrorBranch b@(Branch p g) | isErrorBranch b = Branch [PatWild] g generalErrorBranch b = b --- Returns largest common pattern superstructure, with variables added where needed -patternsMatch :: Pattern -> Pattern -> Unique Pattern +-- Returns largest common pattern superstructure, with variables added where needed, and the distinguishing variables returned +patternsMatch :: Pattern -> Pattern -> Unique ([TName], Pattern) patternsMatch p p' = case (p, p') of (PatLit l1, PatLit l2) -> - if l1 == l2 then return p -- Literals that match, just match the literal + if l1 == l2 then return ([], p) -- Literals that match, just match the literal else do -- Match a variable of the literal's type name <- newVarName - return $ PatVar (TName name (typeOf l1)) PatWild + let tn = TName name (typeOf l1) + return ([tn], PatVar tn PatWild) (PatVar tn1 v1, PatVar tn2 v2) | tn1 == tn2 -> do -- Same pattern variable, reuse the variable name, but find common substructure - sub <- patternsMatch v1 v2 - return $ PatVar tn1 sub + (tns, sub) <- patternsMatch v1 v2 + return (tns, PatVar tn1 sub) (PatVar tn1 v1, PatVar tn2 v2) -> do -- Variables that don't match name, but (should match types because of type checking) -- Create a common name to match for name <- newVarName - sub <- patternsMatch v1 v2 - return $ PatVar (TName name (typeOf tn1)) sub - (PatWild, PatWild) -> return PatWild -- Wilds match trivially + (tns, sub) <- patternsMatch v1 v2 + return (tns, PatVar (TName name (typeOf tn1)) sub) + (PatWild, PatWild) -> return ([], PatWild) -- Wilds match trivially (PatCon name1 patterns1 cr targs1 exists1 res1 ci sk, PatCon name2 patterns2 _ targs2 exists2 res2 _ _) -> if -- Same constructor (name, and types) -- types should match due to type checking, but names could differ name1 == name2 && @@ -199,27 +211,50 @@ patternsMatch p p' res1 == res2 then do -- Same constructor, match substructure - subs <- zipWithM patternsMatch patterns1 patterns2 - return $ PatCon name1 subs cr targs1 exists1 res1 ci sk + res <- zipWithM patternsMatch patterns1 patterns2 + let (subs, pats) = unzip res + return (concat subs, PatCon name1 pats cr targs1 exists1 res1 ci sk) else do name <- newVarName - return $ PatVar (TName name res1) PatWild -- Different constructors, no match + let tn = TName name res1 + return ([tn], PatVar tn PatWild) -- Different constructors, no match + (PatVar tn PatWild, PatWild) -> do + return ([], PatVar tn PatWild) + (PatWild, PatVar tn PatWild) -> do + return ([], PatVar tn PatWild) + (PatVar tn PatWild, _) -> do + return ([tn], PatVar tn PatWild) + (_, PatVar tn PatWild) -> do + return ([tn], PatVar tn PatWild) (PatVar tn pat, _) -> do - sub <- patternsMatch pat p' - return $ PatVar tn sub + (tns, sub) <- patternsMatch pat p' + return (tns, PatVar tn sub) (_, PatVar tn pat) -> do - sub <- patternsMatch p pat - return $ PatVar tn sub - (_, PatWild) -> return PatWild - (PatWild, _) -> return PatWild + (tns, sub) <- patternsMatch p pat + return (tns, PatVar tn sub) + -- Double sided wilds already handled so we can safely request the type, as well as one sided vars + (_, PatWild) -> do + name <- newVarName + let tn = TName name (patternType p) + return ([tn], PatVar tn PatWild) + (PatWild, _) -> do + name <- newVarName + let tn = TName name (patternType p') + return ([tn], PatVar tn PatWild) (_, _) -> failure $ "patternsMatch: " ++ show p ++ " " ++ show p' ++ " " where newVarName = uniqueId "case" >>= (\id -> return $ newHiddenName ("case" ++ show id)) +patternType :: Pattern -> Type +patternType p = case p of + PatLit l -> typeOf l + PatVar tn _ -> typeOf tn + PatCon tn _ _ targs _ resTp _ _ -> resTp + -- Strip the outer constructors and propagate variable substitution into branch expressions stripOuterConstructors :: [Pattern] -> Branch -> Branch -stripOuterConstructors templates (Branch pts exprs) - = trace ("Using template\n" ++ show templates ++ "\nand outer subpattern from\n" ++ show pts ++ "\ngot:\n" ++ show (patNew, replaceMap) ++ "\n") $ - Branch (concatMap (fromMaybe [PatWild]) patNew) $ map replaceInGuard exprs +stripOuterConstructors templates (Branch pts exprs) = + -- trace ("Using template\n" ++ show templates ++ "\nand outer subpattern from\n" ++ show pts ++ "\ngot:\n" ++ show (patNew, replaceMap) ++ "\n") $ + Branch (concatMap (fromMaybe [PatWild]) patNew) $ map replaceInGuard exprs where replaceInGuard (Guard tst expr) = Guard (rewriteBottomUp replaceInExpr tst) (rewriteBottomUp replaceInExpr expr) @@ -250,8 +285,8 @@ getReplaceMap template p' (patterns', replaceMaps) = unzip res replaceMap = concat replaceMaps in (Just (concatMap (fromMaybe []) patterns'), replaceMap) + (PatVar tn PatWild, PatWild) -> (Nothing, []) (PatVar tn PatWild, pat2) -> (Just [pat2], []) (PatVar tn pat, pat2) -> getReplaceMap pat pat2 - (pat, PatVar tn pat2) -> getReplaceMap pat pat2 (PatWild, pat2) -> (Just [pat2], []) _ -> failure $ "\ngetReplaceMap:\n" ++ show template ++ "\n:" ++ show p' ++ "\n" diff --git a/src/Core/Pretty.hs b/src/Core/Pretty.hs index 55f26cec3..b9f035734 100644 --- a/src/Core/Pretty.hs +++ b/src/Core/Pretty.hs @@ -431,9 +431,9 @@ prettyGuard env (Guard test expr) prettyPatterns :: Env -> [Pattern] -> (Env,[Doc]) prettyPatterns env pats - = foldl f (env,[]) pats + = foldr f (env,[]) pats where - f (env,docs) pat = let (env',doc) = prettyPattern env pat + f pat (env,docs) = let (env',doc) = prettyPattern env pat in (env',doc:docs) prettyPatternType (pat,tp) (env,docs) diff --git a/test/cgen/match-merge.kk b/test/cgen/match-merge.kk index 8353be7d5..b80139d7d 100644 --- a/test/cgen/match-merge.kk +++ b/test/cgen/match-merge.kk @@ -1,5 +1,6 @@ fun implicit-error-fallthrough(e: list) match e + Cons(_, Cons(2, Cons(3, Nil))) -> 100.show.println Cons(a, Cons(1, Cons(c, Nil))) -> (a + c).show.println Cons(_, Cons(b, _)) -> b.show.println Nil -> "Nothing".println @@ -13,7 +14,14 @@ fun reflow(e: list) fun main() try { - implicit-error-fallthrough([1, 2, 3]) - reflow([1, 2, 3]) - } fn(err) { println("Error") } + implicit-error-fallthrough([]) + implicit-error-fallthrough([1,2]) + implicit-error-fallthrough([2,1,2]) + implicit-error-fallthrough([1]) + } fn(err) { println(err.message) } + + reflow([]) + reflow([1,2]) + reflow([2,1,2]) + reflow([1]) \ No newline at end of file