diff --git a/src/exo/LoopIR_scheduling.py b/src/exo/LoopIR_scheduling.py index 0c65580d..317d0f6e 100644 --- a/src/exo/LoopIR_scheduling.py +++ b/src/exo/LoopIR_scheduling.py @@ -2410,7 +2410,7 @@ def wrapper(body): if cur_c._node in par_s.body: def wrapper(body): - return par_s.update(body=body) + return par_s.update(body=body, orelse=[]) ir, fwd_wrap = pre_c._wrap(wrapper, "body") fwd = _compose(fwd_wrap, fwd) @@ -2424,7 +2424,9 @@ def wrapper(body): assert cur_c._node in par_s.orelse def wrapper(orelse): - return par_s.update(body=None, orelse=orelse) + return par_s.update( + body=[LoopIR.Pass(par_s.srcinfo)], orelse=orelse + ) ir, fwd_wrap = post_c._wrap(wrapper, "orelse") fwd = _compose(fwd_wrap, fwd) diff --git a/tests/test_schedules.py b/tests/test_schedules.py index a0e42ebb..f8be4728 100644 --- a/tests/test_schedules.py +++ b/tests/test_schedules.py @@ -481,6 +481,48 @@ def foo(): fission(foo, foo.find("x = 0.0").after(), n_lifts=2) +def test_if_fission(): + @proc + def before(x: size, y: f32): + if x < 10: + y += 1 + y += 2 + else: + y += 3 + y += 4 + + @proc + def fission_if(x: size, y: f32): + if x < 10: + y += 1 + if x < 10: + y += 2 + else: + y += 3 + y += 4 + + @proc + def fission_else(x: size, y: f32): + if x < 10: + y += 1 + y += 2 + else: + y += 3 + if x < 10: + pass + else: + y += 4 + + test_fission_if = rename(before, "fission_if") + test_fission_if = fission(test_fission_if, test_fission_if.find("y += 1").after()) + assert str(fission_if) == str(test_fission_if) + test_fission_else = rename(before, "fission_else") + test_fission_else = fission( + test_fission_else, test_fission_else.find("y += 3").after() + ) + assert str(fission_else) == str(test_fission_else) + + def test_resize_dim(golden): @proc def foo():