From 781e60de4ed28a713ce10de439bf72fe3ac104b4 Mon Sep 17 00:00:00 2001 From: Tim Whiting Date: Tue, 9 Jul 2024 15:56:17 -0600 Subject: [PATCH] fix guard, clean up some substitutions --- src/Core/MatchMerge.hs | 38 +++++++++++++++++++++++++++----------- test/parc/parc2.kk.out | 2 +- test/parc/parc21.kk.out | 8 ++++---- 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/src/Core/MatchMerge.hs b/src/Core/MatchMerge.hs index f5375d9ed..92eaec134 100644 --- a/src/Core/MatchMerge.hs +++ b/src/Core/MatchMerge.hs @@ -63,15 +63,23 @@ matchMergeExpr body Case exprs branches -> do (branches', changed) <- mergeBranches branches - -- if changed then trace ("matchMergeExpr:\n" ++ show branches ++ "\nrewrote to: \n" ++ show branches' ++ "\n") (return ()) else return () + -- if changed then + -- trace ("matchMergeExpr:\n" ++ show (vcat (map (text . show) branches)) ++ "\nrewrote to: \n" ++ show (vcat (map (text . show) branches')) ++ "\n") (return ()) + -- else return () return $ Case exprs branches' _ -> return body +isTrueGuard :: Guard -> Bool +isTrueGuard guard = isExprTrue (guardExpr guard) + -- Takes a set of branches, and transforms them by merging branches that have some shared superstructure. -- Returns the new branch structure and whether any changes were made mergeBranches :: [Branch] -> Unique ([Branch], Bool) -- No branches, no changes mergeBranches [] = return ([], False) +-- Skip branches with complex guards (in the future we can optimize to merge guards) +mergeBranches (b@(Branch [pat@PatCon{patConPatterns=ps}] guard):bs) | not (all isTrueGuard guard) = + mergeBranches bs >>= (\(bs', v) -> return (b:bs', v)) -- Branch with constructor pattern, try to merge it with the rest mergeBranches branches@(b@(Branch [pat@PatCon{patConPatterns=ps}] _): rst) = -- trace ("mergeBranches:\n" ++ show b ++ "\n\n" ++ show rst ++ "\n\n\n") $ @@ -90,7 +98,7 @@ mergeBranches branches@(b@(Branch [pat@PatCon{patConPatterns=ps}] _): rst) = do -- trace ("mergeBranches:\n" ++ " has error? " ++ show err ++ "\n" ++ intercalate "\n" (map (show . branchPatterns) bs) ++ "\n with common superstructure:\n" ++ show pat' ++ "\n\n") $ return () let - varsMatch = [Var tn InfoNone | tn <- distinguishingVars] -- Create expressions for those vars + varsMatch = [Var tn InfoNone | tn <- distinguishingVars] -- Create expressions for those vars -- Get rid of the common superstructure from the branches that share superstructure -- Also add the implicit error branch if it exists subBranches = map (stripOuterConstructors distinguishingVars pat') bs ++ maybeToList any ++ maybeToList err @@ -254,7 +262,7 @@ patternType p = case p of -- Strip the outer constructors and propagate variable substitution into branch expressions stripOuterConstructors :: [TName] -> [Pattern] -> Branch -> Branch -stripOuterConstructors tns templates (Branch pts exprs) = +stripOuterConstructors discriminatingVars templates (Branch pts exprs) = -- trace ("Using template\n" ++ show templates ++"\n" ++ show pts ++ "\ngot:\n" ++ show (zip tns patNew) ++ "\n" ++ " with variable name mapping " ++ show replaceMap ++ "\n") $ Branch (concatMap (fromMaybe [PatWild]) patNew) $ map replaceInGuard exprs where @@ -267,32 +275,40 @@ stripOuterConstructors tns templates (Branch pts exprs) = Just (Var name info) -> Var name info _ -> e e' -> e' - (patNew, replaceMaps) = unzip $ zipWith (getReplaceMap tns) templates pts + (patNew, replaceMaps) = unzip $ zipWith (getReplaceMap discriminatingVars) templates pts replaceMap = concat replaceMaps -- Get the new pattern that differs from the old pattern and the subsitution map getReplaceMap :: [TName] -> Pattern -> Pattern -> (Maybe [Pattern], [(TName, Expr)]) -getReplaceMap keepVars template p' - = let recur = getReplaceMap keepVars in +getReplaceMap discriminatingVars template p' + = let recur = getReplaceMap discriminatingVars in case (template, p') of (PatLit l1, PatLit l2) -> (Nothing, []) (PatVar tn1 v1, PatVar tn2 v2) | tn1 == tn2 -> let (pat', rp) = recur v1 v2 in case pat' of - Nothing -> if tn1 `notElem` keepVars then (Nothing, rp) else (Just [PatWild], rp) + Nothing -> if tn1 `notElem` discriminatingVars then (Nothing, rp) else (Just [PatWild], rp) Just _ -> (pat', rp) (PatVar tn1 v1, PatVar tn2 v2) -> let (pat', rp) = recur v1 v2 - in case pat' of - Nothing -> if tn1 `notElem` keepVars then (Nothing, rp) else (Just [PatWild], (tn1, Var tn2 InfoNone):rp) - Just _ -> (pat', (tn2, Var tn1 InfoNone):rp) + -- introduce a new variable using the template's name, and map the other name to the template + rp' = (tn2, Var tn1 InfoNone):rp in + case pat' of + Nothing -> -- Differs + -- Doesn't discriminate, but do need to propagate + if tn1 `notElem` discriminatingVars then (Nothing, rp') + -- Introduce just a wild? + else (Just [PatWild], rp') + Just _ -> -- Use the new pattern + (pat', rp') (PatWild, PatWild) -> (Nothing, []) (PatCon name1 patterns1 cr targs1 exists1 res1 ci _, PatCon name2 patterns2 _ targs2 exists2 res2 _ sk) -> let res = zipWith recur patterns1 patterns2 (patterns', replaceMaps) = unzip res replaceMap = concat replaceMaps in (Just (concatMap (fromMaybe []) patterns'), replaceMap) - (PatVar tn PatWild, PatWild) -> (if tn `notElem` keepVars then Nothing else Just [PatWild], []) + (PatVar tn PatWild, PatWild) -> + (if tn `notElem` discriminatingVars then Nothing else Just [PatWild], []) (PatVar tn PatWild, pat2) -> (Just [pat2], []) (PatVar tn pat, pat2) -> recur pat pat2 (PatWild, pat2) -> (Just [pat2], []) diff --git a/test/parc/parc2.kk.out b/test/parc/parc2.kk.out index f755bd947..115df3182 100644 --- a/test/parc/parc2.kk.out +++ b/test/parc/parc2.kk.out @@ -24,6 +24,6 @@ pub fun test : forall (x : list) -> list (std/core/types/Nil() : (list) ) -> x; _ - -> std/core/list/@unroll-xxx((std/core/types/@dup(x)), x); + -> std/core/@unroll-xxx((std/core/types/@dup(x)), x); }; }; \ No newline at end of file diff --git a/test/parc/parc21.kk.out b/test/parc/parc21.kk.out index ba9c15ff3..7e1e2cbe1 100644 --- a/test/parc/parc21.kk.out +++ b/test/parc/parc21.kk.out @@ -25,7 +25,7 @@ pub fun bo : (a : int, ^ b : int) -> int pub fun print-ret : (x : int) -> console int = fn(x: int){ val _ : () - = std/core/console/prints((std/core/int/show((std/core/types/@dup(x))))); + = std/core/prints((std/core/int/show((std/core/types/@dup(x))))); x; }; pub fun test : () -> console int @@ -33,10 +33,10 @@ pub fun test : () -> console int val x@xxx : int = parc/parc21/bo(3, 4); val _ : () - = std/core/console/prints((std/core/int/show(x@xxx))); + = std/core/prints((std/core/int/show(x@xxx))); val _@0 : () - = std/core/console/prints((std/core/int/show(3))); + = std/core/prints((std/core/int/show(3))); val _@1 : () - = std/core/console/prints((std/core/int/show(4))); + = std/core/prints((std/core/int/show(4))); parc/parc21/bo(3, 4); }; \ No newline at end of file