From c31e8f7c32893978a6e04a846b904a694f1f0d09 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Wed, 4 Oct 2023 11:12:09 -0700 Subject: [PATCH] Don't deduce unreachability from predicated out of bounds stores (#7874) Fixes #7873 --- src/Simplify_Stmts.cpp | 5 +++-- test/correctness/fuzz_schedule.cpp | 20 ++++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp index 6a8e53ccfa73..09b4aed1036d 100644 --- a/src/Simplify_Stmts.cpp +++ b/src/Simplify_Stmts.cpp @@ -300,12 +300,13 @@ Stmt Simplify::visit(const Store *op) { ExprInfo index_info; Expr index = mutate(op->index, &index_info); - // If the store is fully out of bounds, drop it. + // If the store is fully unconditional and out of bounds, drop it. // This should only occur inside branches that make the store unreachable, // but perhaps the branch was hard to prove constant true or false. This // provides an alternative mechanism to simplify these unreachable stores. string alloc_extent_name = op->name + ".total_extent_bytes"; - if (bounds_and_alignment_info.contains(alloc_extent_name)) { + if (is_const_one(op->predicate) && + bounds_and_alignment_info.contains(alloc_extent_name)) { if (index_info.max_defined && index_info.max < 0) { in_unreachable = true; return Evaluate::make(unreachable()); diff --git a/test/correctness/fuzz_schedule.cpp b/test/correctness/fuzz_schedule.cpp index 4a027c02cb51..60a780a89d6e 100644 --- a/test/correctness/fuzz_schedule.cpp +++ b/test/correctness/fuzz_schedule.cpp @@ -54,6 +54,26 @@ int main(int argc, char **argv) { check_blur_output(buf, correct); } + // https://github.com/halide/Halide/issues/7873 + { + Func input("input"); + Func local_sum("local_sum"); + Func blurry("blurry"); + Var x("x"), y("y"); + RVar yryf; + input(x, y) = 2 * x + 5 * y; + RDom r(-2, 5, -2, 5, "rdom_r"); + local_sum(x, y) = 0; + local_sum(x, y) += input(x + r.x, y + r.y); + blurry(x, y) = cast(local_sum(x, y) / 25); + Var xo, xi; + local_sum.split(x, xo, xi, 4, TailStrategy::PredicateStores); + local_sum.update(0).unscheduled(); + Pipeline p({blurry}); + Buffer buf = p.realize({32, 32}); + check_blur_output(buf, correct); + } + printf("Success!\n"); return 0;