Skip to content

Commit

Permalink
fix guard, clean up some substitutions
Browse files Browse the repository at this point in the history
  • Loading branch information
TimWhiting committed Jul 9, 2024
1 parent eba6d8f commit 781e60d
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 16 deletions.
38 changes: 27 additions & 11 deletions src/Core/MatchMerge.hs
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,23 @@ matchMergeExpr body
Case exprs branches ->
do
(branches', changed) <- mergeBranches branches
-- if changed then trace ("matchMergeExpr:\n" ++ show branches ++ "\nrewrote to: \n" ++ show branches' ++ "\n") (return ()) else return ()
-- if changed then
-- trace ("matchMergeExpr:\n" ++ show (vcat (map (text . show) branches)) ++ "\nrewrote to: \n" ++ show (vcat (map (text . show) branches')) ++ "\n") (return ())
-- else return ()
return $ Case exprs branches'
_ -> return body

isTrueGuard :: Guard -> Bool
isTrueGuard guard = isExprTrue (guardExpr guard)

-- Takes a set of branches, and transforms them by merging branches that have some shared superstructure.
-- Returns the new branch structure and whether any changes were made
mergeBranches :: [Branch] -> Unique ([Branch], Bool)
-- No branches, no changes
mergeBranches [] = return ([], False)
-- Skip branches with complex guards (in the future we can optimize to merge guards)
mergeBranches (b@(Branch [pat@PatCon{patConPatterns=ps}] guard):bs) | not (all isTrueGuard guard) =
mergeBranches bs >>= (\(bs', v) -> return (b:bs', v))
-- Branch with constructor pattern, try to merge it with the rest
mergeBranches branches@(b@(Branch [pat@PatCon{patConPatterns=ps}] _): rst) =
-- trace ("mergeBranches:\n" ++ show b ++ "\n\n" ++ show rst ++ "\n\n\n") $
Expand All @@ -90,7 +98,7 @@ mergeBranches branches@(b@(Branch [pat@PatCon{patConPatterns=ps}] _): rst) =
do
-- trace ("mergeBranches:\n" ++ " has error? " ++ show err ++ "\n" ++ intercalate "\n" (map (show . branchPatterns) bs) ++ "\n with common superstructure:\n" ++ show pat' ++ "\n\n") $ return ()
let
varsMatch = [Var tn InfoNone | tn <- distinguishingVars] -- Create expressions for those vars
varsMatch = [Var tn InfoNone | tn <- distinguishingVars] -- Create expressions for those vars
-- Get rid of the common superstructure from the branches that share superstructure
-- Also add the implicit error branch if it exists
subBranches = map (stripOuterConstructors distinguishingVars pat') bs ++ maybeToList any ++ maybeToList err
Expand Down Expand Up @@ -254,7 +262,7 @@ patternType p = case p of

-- Strip the outer constructors and propagate variable substitution into branch expressions
stripOuterConstructors :: [TName] -> [Pattern] -> Branch -> Branch
stripOuterConstructors tns templates (Branch pts exprs) =
stripOuterConstructors discriminatingVars templates (Branch pts exprs) =
-- trace ("Using template\n" ++ show templates ++"\n" ++ show pts ++ "\ngot:\n" ++ show (zip tns patNew) ++ "\n" ++ " with variable name mapping " ++ show replaceMap ++ "\n") $
Branch (concatMap (fromMaybe [PatWild]) patNew) $ map replaceInGuard exprs
where
Expand All @@ -267,32 +275,40 @@ stripOuterConstructors tns templates (Branch pts exprs) =
Just (Var name info) -> Var name info
_ -> e
e' -> e'
(patNew, replaceMaps) = unzip $ zipWith (getReplaceMap tns) templates pts
(patNew, replaceMaps) = unzip $ zipWith (getReplaceMap discriminatingVars) templates pts
replaceMap = concat replaceMaps

-- Get the new pattern that differs from the old pattern and the subsitution map
getReplaceMap :: [TName] -> Pattern -> Pattern -> (Maybe [Pattern], [(TName, Expr)])
getReplaceMap keepVars template p'
= let recur = getReplaceMap keepVars in
getReplaceMap discriminatingVars template p'
= let recur = getReplaceMap discriminatingVars in
case (template, p') of
(PatLit l1, PatLit l2) -> (Nothing, [])
(PatVar tn1 v1, PatVar tn2 v2) | tn1 == tn2 ->
let (pat', rp) = recur v1 v2
in case pat' of
Nothing -> if tn1 `notElem` keepVars then (Nothing, rp) else (Just [PatWild], rp)
Nothing -> if tn1 `notElem` discriminatingVars then (Nothing, rp) else (Just [PatWild], rp)
Just _ -> (pat', rp)
(PatVar tn1 v1, PatVar tn2 v2) ->
let (pat', rp) = recur v1 v2
in case pat' of
Nothing -> if tn1 `notElem` keepVars then (Nothing, rp) else (Just [PatWild], (tn1, Var tn2 InfoNone):rp)
Just _ -> (pat', (tn2, Var tn1 InfoNone):rp)
-- introduce a new variable using the template's name, and map the other name to the template
rp' = (tn2, Var tn1 InfoNone):rp in
case pat' of
Nothing -> -- Differs
-- Doesn't discriminate, but do need to propagate
if tn1 `notElem` discriminatingVars then (Nothing, rp')
-- Introduce just a wild?
else (Just [PatWild], rp')
Just _ -> -- Use the new pattern
(pat', rp')
(PatWild, PatWild) -> (Nothing, [])
(PatCon name1 patterns1 cr targs1 exists1 res1 ci _, PatCon name2 patterns2 _ targs2 exists2 res2 _ sk) ->
let res = zipWith recur patterns1 patterns2
(patterns', replaceMaps) = unzip res
replaceMap = concat replaceMaps
in (Just (concatMap (fromMaybe []) patterns'), replaceMap)
(PatVar tn PatWild, PatWild) -> (if tn `notElem` keepVars then Nothing else Just [PatWild], [])
(PatVar tn PatWild, PatWild) ->
(if tn `notElem` discriminatingVars then Nothing else Just [PatWild], [])
(PatVar tn PatWild, pat2) -> (Just [pat2], [])
(PatVar tn pat, pat2) -> recur pat pat2
(PatWild, pat2) -> (Just [pat2], [])
Expand Down
2 changes: 1 addition & 1 deletion test/parc/parc2.kk.out
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ pub fun test : forall<a> (x : list<a>) -> list<a>
(std/core/types/Nil() : (list<a>) )
-> x;
_
-> std/core/list/@unroll-xxx((std/core/types/@dup(x)), x);
-> std/core/@unroll-xxx((std/core/types/@dup(x)), x);
};
};
8 changes: 4 additions & 4 deletions test/parc/parc21.kk.out
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,18 @@ pub fun bo : (a : int, ^ b : int) -> int
pub fun print-ret : (x : int) -> console int
= fn<console>(x: int){
val _ : ()
= std/core/console/prints((std/core/int/show((std/core/types/@dup(x)))));
= std/core/prints((std/core/int/show((std/core/types/@dup(x)))));
x;
};
pub fun test : () -> console int
= fn<console>(){
val x@xxx : int
= parc/parc21/bo(3, 4);
val _ : ()
= std/core/console/prints((std/core/int/show(x@xxx)));
= std/core/prints((std/core/int/show(x@xxx)));
val _@0 : ()
= std/core/console/prints((std/core/int/show(3)));
= std/core/prints((std/core/int/show(3)));
val _@1 : ()
= std/core/console/prints((std/core/int/show(4)));
= std/core/prints((std/core/int/show(4)));
parc/parc21/bo(3, 4);
};

0 comments on commit 781e60d

Please sign in to comment.