From c13d77559b66059bcad2ae15fb59cc7e8afdce54 Mon Sep 17 00:00:00 2001 From: Tim Whiting Date: Tue, 24 Oct 2023 23:02:18 -0600 Subject: [PATCH] start debugging issues with match merge --- src/Core/MatchMerge.hs | 152 ++++++++++++++++++++++++----------------- 1 file changed, 90 insertions(+), 62 deletions(-) diff --git a/src/Core/MatchMerge.hs b/src/Core/MatchMerge.hs index 96ff6c3ac..1171db7e4 100644 --- a/src/Core/MatchMerge.hs +++ b/src/Core/MatchMerge.hs @@ -26,6 +26,7 @@ import qualified Core.Core as Core import Core.Pretty import Core.CoreVar import Core.Uniquefy +import Data.List (intercalate) trace s x = Lib.Trace.trace s @@ -67,31 +68,52 @@ matchMergeExpr body -- else return $ Case exprs branches _ -> return body +-- Takes a set of branches, and transforms them by merging branches that have some shared superstructure. +-- Returns the new branch structure and whether any changes were made mergeBranches :: [Branch] -> Unique ([Branch], Bool) +-- No branches, no changes mergeBranches [] = return ([], False) +-- Branch with constructor pattern, try to merge it with the rest mergeBranches branches@(b@(Branch [pat@PatCon{patConPatterns=ps}] _): rst) = do splitted <- splitBranchConstructors b rst case splitted of + -- Single branch, return itself unchanged ([b], [], err, pat') -> return ([b], False) + -- Single branch shares structure, rest do not, merge the rest and append ([b], rst, err, pat') -> do (rest, v) <- mergeBranches rst return (b:rest, v) + -- Multiple branches share structure (bs, rst, err, pat') -> do + trace ("mergeBranches:\n" ++ intercalate "\n" (map (show . branchPatterns) bs) ++ "\n with common superstructure:\n" ++ show pat' ++ "\n\n") $ return () let - vars' = collectVars pat' - varsMatch = [Var tn InfoNone | tn <- vars'] - (rest, v) <- mergeBranches rst - (newBranches, innerV) <- mergeBranches $ map (stripOuterConstructors pat') bs ++ maybeToList (fmap (generalErrorBranch) err) -- Add back error to sub branches - return (Branch [pat'] [Guard exprTrue (Case varsMatch newBranches)] : rest, True) + vars' = collectPatsVars pat' -- Collect the variables introduced by the shared structure + varsMatch = [Var tn InfoNone | tn <- vars'] -- Create expressions for those vars + -- Get rid of the common superstructure from the branches that share superstructure + -- Also add the implicit error branch if it exists + subBranches = map (stripOuterConstructors pat') bs ++ maybeToList err + (newSubBranches, innerV) <- mergeBranches subBranches + (rest, v) <- mergeBranches rst -- Merge the branches that do not share structure with the current set + -- Replace the set of common branches, with a single branch that matches on the shared superstructure, and delegates + -- to another case expression to distinguish between the different substructures + return (Branch pat' [Guard exprTrue (Case varsMatch newSubBranches)] : rest, True) +-- Default (non-constructor patterns), just merge the rest, and add the first branch back mergeBranches (b:bs) = mergeBranches bs >>= (\(bs', v) -> return (b:bs', v)) +-- TODO: Add support for var patterns +-- TODO: Add support for branches with multiple patterns + +-- Collects the vars from a pattern in a canonical order (instead of fvs which uses sets) +collectPatsVars :: [Pattern] -> [TName] +collectPatsVars = concatMap collectVars collectVars :: Pattern -> [TName] collectVars p = case p of - PatVar name _ -> [name] + PatVar name PatWild -> [name] + PatVar name pt -> collectVars pt PatCon{patConPatterns = ps} -> concatMap collectVars ps _ -> [] @@ -101,12 +123,12 @@ collectVars p -- - a possible (implicit error) branch found -- - and the pattern that unifies the matched branches -- Greedily in order processing, The first branch is the branch under consideration and the others are the next parameter -splitBranchConstructors :: Branch -> [Branch] -> Unique ([Branch], [Branch], Maybe Branch, Pattern) -splitBranchConstructors b@(Branch [p] _) branches = +splitBranchConstructors :: Branch -> [Branch] -> Unique ([Branch], [Branch], Maybe Branch, [Pattern]) +splitBranchConstructors b@(Branch ps _) branches = case branches of -- Only one branch, it matches it's own pattern - [] -> return ([b], [], if isErrorBranch b then Just b else Nothing, p) - b'@(Branch [p'] _):bs -> + [] -> return ([b], [], if isErrorBranch b then Just b else Nothing, ps) + b'@(Branch ps' _):bs -> do -- First do the rest other than b' (bs', bs2', e, accP) <- splitBranchConstructors b bs @@ -116,19 +138,30 @@ splitBranchConstructors b@(Branch [p] _) branches = (_, b') | isErrorBranch b' -> Just b' -- b' is the error branch _ -> Nothing -- no error branch -- Acumulated pattern and p' - patNew <- patternsMatch accP p' - case patNew of + patNew <- zipWithM patternsMatch accP ps' + if not $ isSimpleMatches patNew then -- Restrict the pattern to the smallest that matches multiple branches -- Add the new branch to the list of branches that match partially - Just p | not (isPatWild p) -> return (bs' ++ [b'], bs2', newError, p) + trace ("splitConstructors:\n" ++ show accP ++ "\nand\n" ++ show ps' ++ "\n have common superstructure:\n" ++ show patNew ++ "\n\n") + $ return (bs' ++ [b'], bs2', newError, patNew) -- Didn't match the current branch, keep the old pattern -- Add the new branch to the list of branches that don't match any subpattern - _ -> return (bs', b':bs2', newError, accP) + else return (bs', b':bs2', newError, accP) isPatWild :: Pattern -> Bool isPatWild PatWild = True isPatWild _ = False +isSimpleMatches :: [Pattern] -> Bool +isSimpleMatches = all isSimpleMatch + +isSimpleMatch :: Pattern -> Bool +isSimpleMatch p = + case p of + PatVar _ p -> isSimpleMatch p + PatWild -> True + _ -> False + -- Checks to see if the branch is an error branch isErrorBranch:: Branch -> Bool isErrorBranch (Branch _ [Guard _ (App (TypeApp (Var name _) _) _)]) = getName name == namePatternMatchError @@ -138,66 +171,56 @@ generalErrorBranch:: Branch -> Branch generalErrorBranch b@(Branch p g) | isErrorBranch b = Branch [PatWild] g generalErrorBranch b = b --- Returns largest common subpattern, with variables added where needed -patternsMatch :: Pattern -> Pattern -> Unique (Maybe Pattern) +-- Returns largest common pattern superstructure, with variables added where needed +patternsMatch :: Pattern -> Pattern -> Unique Pattern patternsMatch p p' = case (p, p') of - (PatLit l1, PatLit l2) -> if l1 == l2 then return (Just p) else newVarName >>= \name -> return $ Just $ PatVar (TName name (typeOf l1)) PatWild - (PatVar tn1 v1, PatVar tn2 v2) | tn1 == tn2 -> do + (PatLit l1, PatLit l2) -> + if l1 == l2 then return p -- Literals that match, just match the literal + else do -- Match a variable of the literal's type + name <- newVarName + return $ PatVar (TName name (typeOf l1)) PatWild + (PatVar tn1 v1, PatVar tn2 v2) | tn1 == tn2 -> do + -- Same pattern variable, reuse the variable name, but find common substructure sub <- patternsMatch v1 v2 - case sub of - Nothing -> return Nothing - Just sub -> return $ Just $ PatVar tn1 sub + return $ PatVar tn1 sub (PatVar tn1 v1, PatVar tn2 v2) -> do + -- Variables that don't match name, but (should match types because of type checking) + -- Create a common name to match for name <- newVarName sub <- patternsMatch v1 v2 - case sub of - Nothing -> return Nothing - Just sub -> return $ Just $ PatVar (TName name (typeOf tn1)) sub - (PatWild, PatWild) -> return $ Just PatWild + return $ PatVar (TName name (typeOf tn1)) sub + (PatWild, PatWild) -> return PatWild -- Wilds match trivially (PatCon name1 patterns1 cr targs1 exists1 res1 ci sk, PatCon name2 patterns2 _ targs2 exists2 res2 _ _) -> - if + if -- Same constructor (name, and types) -- types should match due to type checking, but names could differ name1 == name2 && targs1 == targs2 && exists1 == exists2 && res1 == res2 - then do - subs <- mapM orVar (zip3 patterns1 patterns2 targs1) - return $ Just $ PatCon name1 subs cr targs1 exists1 res1 ci sk - else return Nothing + then do + -- Same constructor, match substructure + subs <- zipWithM patternsMatch patterns1 patterns2 + return $ PatCon name1 subs cr targs1 exists1 res1 ci sk + else do + name <- newVarName + return $ PatVar (TName name res1) PatWild -- Different constructors, no match (PatVar tn pat, _) -> do - name <- newVarName - return $ Just $ PatVar (TName name (typeOf tn)) PatWild + sub <- patternsMatch pat p' + return $ PatVar tn sub (_, PatVar tn pat) -> do - name <- newVarName - return $ Just $ PatVar (TName name (typeOf tn)) PatWild - (_, PatWild) -> return $ Just PatWild - (PatWild, _) -> return $ Just PatWild - (_, _) -> return Nothing - where - orVar (p1,p2,t) = do - v <- patternsMatch p1 p2 - case v of - Nothing -> do - name <- newVarName - return $ PatVar (TName name t) PatWild - Just x -> return x - newVarName = uniqueId "case" >>= (\id -> return $ newHiddenName ("case" ++ show id)) + sub <- patternsMatch p pat + return $ PatVar tn sub + (_, PatWild) -> return PatWild + (PatWild, _) -> return PatWild + (_, _) -> failure $ "patternsMatch: " ++ show p ++ " " ++ show p' ++ " " + where newVarName = uniqueId "case" >>= (\id -> return $ newHiddenName ("case" ++ show id)) -- Strip the outer constructors and propagate variable substitution into branch expressions -stripOuterConstructors :: Pattern -> Branch -> Branch -stripOuterConstructors template (Branch [pt] exprs) - = -- trace ("Using template\n" ++ show template ++ "\nand outer subpattern from\n" ++ show pt ++ "\ngot:\n" ++ show (patNew, replaceMap) ++ "\n") $ - Branch (fromMaybe [PatWild] patNew) $ map replaceInGuard exprs +stripOuterConstructors :: [Pattern] -> Branch -> Branch +stripOuterConstructors templates (Branch pts exprs) + = trace ("Using template\n" ++ show templates ++ "\nand outer subpattern from\n" ++ show pts ++ "\ngot:\n" ++ show (patNew, replaceMap) ++ "\n") $ + Branch (concatMap (fromMaybe [PatWild]) patNew) $ map replaceInGuard exprs where - replaceInPattern :: Pattern -> Pattern - replaceInPattern p - = case p of - PatVar name _ -> case lookup name replaceMap of - Just (Var name info) -> PatVar name PatWild - _ -> p - PatCon name patterns cr targs exists res ci sk -> PatCon name (map replaceInPattern patterns) cr targs exists res ci sk - _ -> p replaceInGuard (Guard tst expr) = Guard (rewriteBottomUp replaceInExpr tst) (rewriteBottomUp replaceInExpr expr) replaceInExpr :: Expr -> Expr @@ -207,23 +230,28 @@ stripOuterConstructors template (Branch [pt] exprs) Just (Var name info) -> Var name info _ -> e e' -> e' - (patNew, replaceMap) = getReplaceMap template pt + (patNew, replaceMaps) = unzip $ zipWith getReplaceMap templates pts + replaceMap = concat replaceMaps -- Get the new pattern that differs from the old pattern and the subsitution map getReplaceMap :: Pattern -> Pattern -> (Maybe [Pattern], [(TName, Expr)]) getReplaceMap template p' = case (template, p') of (PatLit l1, PatLit l2) -> (Nothing, []) + (PatVar tn1 v1, PatVar tn2 v2) | tn1 == tn2 -> + let (pat', rp) = getReplaceMap v1 v2 + in (pat', rp) (PatVar tn1 v1, PatVar tn2 v2) -> let (pat', rp) = getReplaceMap v1 v2 in (pat', (tn2, Var tn1 InfoNone):rp) (PatWild, PatWild) -> (Nothing, []) (PatCon name1 patterns1 cr targs1 exists1 res1 ci _, PatCon name2 patterns2 _ targs2 exists2 res2 _ sk) -> - let res = map (\(p1,p2) -> getReplaceMap p1 p2) (zip patterns1 patterns2) + let res = zipWith getReplaceMap patterns1 patterns2 (patterns', replaceMaps) = unzip res replaceMap = concat replaceMaps - in (Just (concatMap (\m -> (fromMaybe [PatWild] m)) patterns'), replaceMap) + in (Just (concatMap (fromMaybe []) patterns'), replaceMap) + (PatVar tn PatWild, pat2) -> (Just [pat2], []) (PatVar tn pat, pat2) -> getReplaceMap pat pat2 (pat, PatVar tn pat2) -> getReplaceMap pat pat2 (PatWild, pat2) -> (Just [pat2], []) - _ -> failure $ "getReplaceMap: " ++ show template ++ " " ++ show p' + _ -> failure $ "\ngetReplaceMap:\n" ++ show template ++ "\n:" ++ show p' ++ "\n"