Skip to content

Commit

Permalink
[PEx] Correct corner cases with while, receive, and after statements
Browse files Browse the repository at this point in the history
  • Loading branch information
aman-goel committed May 14, 2024
1 parent dcf2b44 commit 4637281
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,8 @@ private bool WriteStmt(Function function, CompilationContext context, StringWrit
break;
case ReceiveSplitStmt splitStmt:
context.WriteLine(output, $"{CompilationContext.CurrentMachine}.blockUntil(\"{context.GetContinuationName(splitStmt.Cont)}\");");
context.Write(output, "return;");
exited = true;
break;
default:
throw new NotImplementedException($"Statement type '{stmt.GetType().Name}' is not supported, found in {function.Name}");
Expand Down
24 changes: 15 additions & 9 deletions Src/PCompiler/CompilerCore/Backend/PExplicit/TransformASTPass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ static private void GenerateInline(Function caller, Function callee, IReadOnlyLi
callNum++;
}

static private List<IPStmt> ReplaceReturn(IReadOnlyList<IPStmt> body, IPExpr location)
private static List<IPStmt> ReplaceReturn(IReadOnlyList<IPStmt> body, IPExpr location)
{
var newBody = new List<IPStmt>();
foreach (var stmt in body)
Expand Down Expand Up @@ -468,7 +468,7 @@ static private IPExpr ReplaceVars(IPExpr expr, Dictionary<Variable,Variable> var
}
}

static private Function TransformFunction(Function function, Machine machine)
private static Function TransformFunction(Function function, Machine machine)
{
if (function.CanReceive != true) {
return function;
Expand All @@ -492,7 +492,7 @@ static private Function TransformFunction(Function function, Machine machine)
return transformedFunction;
}

static private IPStmt ReplaceBreaks(IPStmt stmt, List<IPStmt> afterStmts)
static private IPStmt InlineAfterAndReplaceBreaks(IPStmt stmt, List<IPStmt> afterStmts)
{
if (stmt == null) return null;
var statements = new List<IPStmt>();
Expand All @@ -501,11 +501,11 @@ static private IPStmt ReplaceBreaks(IPStmt stmt, List<IPStmt> afterStmts)
case CompoundStmt compoundStmt:
foreach (var inner in compoundStmt.Statements)
{
statements.Add(ReplaceBreaks(inner, afterStmts));
statements.Add(InlineAfterAndReplaceBreaks(inner, afterStmts));
}
return new CompoundStmt(compoundStmt.SourceLocation, statements);
case IfStmt ifStmt:
return new IfStmt(ifStmt.SourceLocation, ifStmt.Condition, ReplaceBreaks(ifStmt.ThenBranch, afterStmts), ReplaceBreaks(ifStmt.ElseBranch, afterStmts));
return new IfStmt(ifStmt.SourceLocation, ifStmt.Condition, InlineAfterAndReplaceBreaks(ifStmt.ThenBranch, afterStmts), InlineAfterAndReplaceBreaks(ifStmt.ElseBranch, afterStmts));
case ReceiveStmt receiveStmt:
var cases = new Dictionary<PEvent, Function>();
foreach(var entry in receiveStmt.Cases)
Expand All @@ -521,7 +521,7 @@ static private IPStmt ReplaceBreaks(IPStmt stmt, List<IPStmt> afterStmts)
foreach (var param in entry.Value.Signature.Parameters) replacement.Signature.Parameters.Add(param);
replacement.Signature.ReturnType = entry.Value.Signature.ReturnType;
foreach (var callee in entry.Value.Callees) replacement.AddCallee(callee);
replacement.Body = (CompoundStmt) ReplaceBreaks(entry.Value.Body, afterStmts);
replacement.Body = (CompoundStmt) InlineAfterAndReplaceBreaks(entry.Value.Body, afterStmts);
cases.Add(entry.Key, replacement);
}
return new ReceiveStmt(receiveStmt.SourceLocation, cases);
Expand All @@ -542,7 +542,7 @@ static private IPStmt ReplaceBreaks(IPStmt stmt, List<IPStmt> afterStmts)
}
}

static private bool CanReceive(IPStmt stmt)
private static bool CanReceive(IPStmt stmt)
{
if (stmt == null) return false;
switch(stmt)
Expand All @@ -567,7 +567,7 @@ static private bool CanReceive(IPStmt stmt)
}
}

static private IPStmt HandleReceives(IPStmt statement, Function function, Machine machine)
private static IPStmt HandleReceives(IPStmt statement, Function function, Machine machine)
{
switch (statement)
{
Expand Down Expand Up @@ -604,6 +604,12 @@ static private IPStmt HandleReceives(IPStmt statement, Function function, Machin
thenStmts = new List<IPStmt>(cond.ThenBranch.Statements);
if (cond.ElseBranch != null)
elseStmts = new List<IPStmt>(cond.ElseBranch.Statements);
if (CanReceive(cond) && (after != null))
{
thenStmts.Add(after);
elseStmts.Add(after);
after = null;
}
IPStmt thenBody = new CompoundStmt(cond.SourceLocation, thenStmts);
IPStmt elseBody = new CompoundStmt(cond.SourceLocation, elseStmts);
thenBody = HandleReceives(thenBody, function, machine);
Expand Down Expand Up @@ -747,7 +753,7 @@ static private IPStmt HandleReceives(IPStmt statement, Function function, Machin
while (bodyEnumerator.MoveNext())
{
var stmt = bodyEnumerator.Current;
var replaceBreak = ReplaceBreaks(stmt, afterStmts);
var replaceBreak = InlineAfterAndReplaceBreaks(stmt, afterStmts);
if (replaceBreak != null) {
loopBody.Add(ReplaceVars(replaceBreak, newVarMap));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ public static void logCurrentDataChoice(PValue<?> choice, int step, int idx) {
public static void logNewState(int step, int idx, Object stateKey, SortedSet<PMachine> machines) {
if (verbosity > 3) {
log.info(String.format(" @%d::%d new state with key %s", step, idx, stateKey));
if (verbosity > 4) {
if (verbosity > 6) {
log.info(String.format(" %s", ComputeHash.getExactString(machines)));
}
}
Expand Down

0 comments on commit 4637281

Please sign in to comment.