Skip to content

Commit

Permalink
Fix fissioning of if/else (#726)
Browse files Browse the repository at this point in the history
Fix incorrect fission of if/else statements. See test_if_fission for
repro.

(Old behavior: fissioning the body of an if statement causes the orelse
part to be duplicated. Fissioning within the orelse causes an exception)
  • Loading branch information
akeley98 authored Oct 17, 2024
1 parent 36a8ec4 commit cfbdd3c
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/exo/LoopIR_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2410,7 +2410,7 @@ def wrapper(body):
if cur_c._node in par_s.body:

def wrapper(body):
return par_s.update(body=body)
return par_s.update(body=body, orelse=[])

ir, fwd_wrap = pre_c._wrap(wrapper, "body")
fwd = _compose(fwd_wrap, fwd)
Expand All @@ -2424,7 +2424,9 @@ def wrapper(body):
assert cur_c._node in par_s.orelse

def wrapper(orelse):
return par_s.update(body=None, orelse=orelse)
return par_s.update(
body=[LoopIR.Pass(par_s.srcinfo)], orelse=orelse
)

ir, fwd_wrap = post_c._wrap(wrapper, "orelse")
fwd = _compose(fwd_wrap, fwd)
Expand Down
42 changes: 42 additions & 0 deletions tests/test_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,48 @@ def foo():
fission(foo, foo.find("x = 0.0").after(), n_lifts=2)


def test_if_fission():
@proc
def before(x: size, y: f32):
if x < 10:
y += 1
y += 2
else:
y += 3
y += 4

@proc
def fission_if(x: size, y: f32):
if x < 10:
y += 1
if x < 10:
y += 2
else:
y += 3
y += 4

@proc
def fission_else(x: size, y: f32):
if x < 10:
y += 1
y += 2
else:
y += 3
if x < 10:
pass
else:
y += 4

test_fission_if = rename(before, "fission_if")
test_fission_if = fission(test_fission_if, test_fission_if.find("y += 1").after())
assert str(fission_if) == str(test_fission_if)
test_fission_else = rename(before, "fission_else")
test_fission_else = fission(
test_fission_else, test_fission_else.find("y += 3").after()
)
assert str(fission_else) == str(test_fission_else)


def test_resize_dim(golden):
@proc
def foo():
Expand Down

0 comments on commit cfbdd3c

Please sign in to comment.