Skip to content

Commit

Permalink
Merge branch 'dg/api_make_parameter_public' of github.com:halide/Hali…
Browse files Browse the repository at this point in the history
…de into dg/api_make_parameter_public
  • Loading branch information
Derek Gerstmann committed Sep 15, 2023
2 parents b2f3375 + 9452ad9 commit 00ba0b2
Show file tree
Hide file tree
Showing 21 changed files with 642 additions and 91 deletions.
32 changes: 19 additions & 13 deletions apps/simd_op_check/driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ struct filter {
const char *name;
int (*fn)(halide_buffer_t *, // float32
halide_buffer_t *, // float64
halide_buffer_t *, // float16
halide_buffer_t *, // bfloat16
halide_buffer_t *, // int8
halide_buffer_t *, // uint8
halide_buffer_t *, // int16
Expand All @@ -33,7 +35,7 @@ extern "C" void halide_print(void *, const char *msg) {
}

template<typename T>
halide_buffer_t make_buffer(int w, int h) {
halide_buffer_t make_buffer(int w, int h, halide_type_t halide_type) {
T *mem = NULL;
#ifdef __APPLE__
// memalign() isn't present on OSX, but posix_memalign is
Expand All @@ -53,7 +55,7 @@ halide_buffer_t make_buffer(int w, int h) {
buf.host = (uint8_t *)mem;
buf.dim[0].extent = w;
buf.dim[1].extent = h;
buf.type = halide_type_of<T>();
buf.type = halide_type;
buf.dim[0].stride = 1;
buf.dim[1].stride = w;
buf.dim[0].min = -128;
Expand All @@ -73,18 +75,20 @@ int main(int argc, char **argv) {
bool error = false;
// Make some input buffers
halide_buffer_t bufs[] = {
make_buffer<float>(W, H),
make_buffer<double>(W, H),
make_buffer<int8_t>(W, H),
make_buffer<uint8_t>(W, H),
make_buffer<int16_t>(W, H),
make_buffer<uint16_t>(W, H),
make_buffer<int32_t>(W, H),
make_buffer<uint32_t>(W, H),
make_buffer<int64_t>(W, H),
make_buffer<uint64_t>(W, H)};
make_buffer<float>(W, H, halide_type_of<float>()),
make_buffer<double>(W, H, halide_type_of<double>()),
make_buffer<uint16_t>(W, H, halide_type_t(halide_type_float, 16)),
make_buffer<uint16_t>(W, H, halide_type_t(halide_type_bfloat, 16)),
make_buffer<int8_t>(W, H, halide_type_of<int8_t>()),
make_buffer<uint8_t>(W, H, halide_type_of<uint8_t>()),
make_buffer<int16_t>(W, H, halide_type_of<int16_t>()),
make_buffer<uint16_t>(W, H, halide_type_of<uint16_t>()),
make_buffer<int32_t>(W, H, halide_type_of<int32_t>()),
make_buffer<uint32_t>(W, H, halide_type_of<uint32_t>()),
make_buffer<int64_t>(W, H, halide_type_of<int64_t>()),
make_buffer<uint64_t>(W, H, halide_type_of<uint64_t>())};

halide_buffer_t out = make_buffer<double>(1, 1);
halide_buffer_t out = make_buffer<double>(1, 1, halide_type_of<double>());

double *out_value = (double *)(out.host);

Expand All @@ -101,6 +105,8 @@ int main(int argc, char **argv) {
bufs + 7,
bufs + 8,
bufs + 9,
bufs + 10,
bufs + 11,
&out);
if (*out_value) {
printf("Error: %f\n", *out_value);
Expand Down
1 change: 1 addition & 0 deletions python_bindings/src/halide/halide_/PyEnums.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ void define_enums(py::module &m) {
.value("AVX512_KNL", Target::Feature::AVX512_KNL)
.value("AVX512_Skylake", Target::Feature::AVX512_Skylake)
.value("AVX512_Cannonlake", Target::Feature::AVX512_Cannonlake)
.value("AVX512_Zen4", Target::Feature::AVX512_Zen4)
.value("AVX512_SapphireRapids", Target::Feature::AVX512_SapphireRapids)
.value("TraceLoads", Target::Feature::TraceLoads)
.value("TraceStores", Target::Feature::TraceStores)
Expand Down
35 changes: 34 additions & 1 deletion src/CSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,39 @@ Expr common_subexpression_elimination(const Expr &e_in, bool lift_all) {

debug(4) << "After removing lets: " << e << "\n";

// CSE is run on unsanitized Exprs from the user, and may contain Vars with
// the same name as the temporaries we intend to introduce. Find any such
// Vars so that we know not to use those names.
class UniqueNameProvider : public IRGraphVisitor {
    using IRGraphVisitor::visit;

    // Prefix of the temporaries CSE introduces (names look like "t12").
    // Annoyingly, this can't be static because this is a local class.
    const char prefix = 't';

    void visit(const Variable *op) override {
        // It would be legal to just add all names found to the tracked set,
        // but because we know the form of the new names we're going to
        // introduce, we can save some time by only adding names that could
        // plausibly collide. In the vast majority of cases, this check will
        // result in the set being empty.
        //
        // Cast to unsigned char before calling isdigit(): passing a plain
        // char that may hold a negative value is undefined behavior.
        if (op->name.size() > 1 &&
            op->name[0] == prefix &&
            isdigit(static_cast<unsigned char>(op->name[1]))) {
            vars.insert(op->name);
        }
    }
    // Names seen in the Expr that could collide with generated temporaries.
    std::set<string> vars;

public:
    // Returns a fresh name of the form "t<N>" that is guaranteed not to
    // collide with any Variable name seen while visiting the Expr.
    string make_unique_name() {
        string name;
        do {
            name = unique_name(prefix);
        } while (vars.count(name));
        return name;
    }
} namer;
e.accept(&namer);

GVN gvn;
e = gvn.mutate(e);

Expand All @@ -298,7 +331,7 @@ Expr common_subexpression_elimination(const Expr &e_in, bool lift_all) {
for (size_t i = 0; i < gvn.entries.size(); i++) {
const auto &e = gvn.entries[i];
if (e->use_count > 1) {
string name = unique_name('t');
string name = namer.make_unique_name();
lets.emplace_back(name, e->expr);
// Point references to this expr to the variable instead.
replacements[e->expr] = Variable::make(e->expr.type(), name);
Expand Down
80 changes: 48 additions & 32 deletions src/CodeGen_X86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ namespace {
// oldest feature flag that supports an instruction.
Target complete_x86_target(Target t) {
if (t.has_feature(Target::AVX512_SapphireRapids)) {
t.set_feature(Target::AVX512_Zen4);
}
if (t.has_feature(Target::AVX512_Zen4)) {
t.set_feature(Target::AVX512_Cannonlake);
}
if (t.has_feature(Target::AVX512_Cannonlake)) {
Expand Down Expand Up @@ -67,8 +70,6 @@ class CodeGen_X86 : public CodeGen_Posix {

int vector_lanes_for_slice(const Type &t) const;

llvm::Type *llvm_type_of(const Type &t) const override;

using CodeGen_Posix::visit;

void init_module() override;
Expand Down Expand Up @@ -210,12 +211,19 @@ const x86Intrinsic intrinsic_defs[] = {
{"llvm.x86.sse2.pmulhu.w", UInt(16, 8), "pmulh", {UInt(16, 8), UInt(16, 8)}},
{"llvm.x86.ssse3.pmul.hr.sw.128", Int(16, 8), "pmulhrs", {Int(16, 8), Int(16, 8)}, Target::SSE41},

// As of LLVM main September 5 2023, LLVM only has partial handling of
// bfloat16. The below rules will match fine for simple examples, but bfloat
// conversion will get folded through any nearby shuffles and cause
// unimplemented errors in llvm's x86 instruction selection for the shuffle
// node. Disabling them for now. See https://github.com/halide/Halide/issues/7219
/*
// Convert FP32 to BF16
{"vcvtne2ps2bf16x32", BFloat(16, 32), "f32_to_bf16", {Float(32, 32)}, Target::AVX512_SapphireRapids},
{"llvm.x86.avx512bf16.cvtneps2bf16.512", BFloat(16, 16), "f32_to_bf16", {Float(32, 16)}, Target::AVX512_SapphireRapids},
{"llvm.x86.avx512bf16.cvtneps2bf16.256", BFloat(16, 8), "f32_to_bf16", {Float(32, 8)}, Target::AVX512_SapphireRapids},
{"vcvtne2ps2bf16x32", BFloat(16, 32), "f32_to_bf16", {Float(32, 32)}, Target::AVX512_Zen4},
{"llvm.x86.avx512bf16.cvtneps2bf16.512", BFloat(16, 16), "f32_to_bf16", {Float(32, 16)}, Target::AVX512_Zen4},
{"llvm.x86.avx512bf16.cvtneps2bf16.256", BFloat(16, 8), "f32_to_bf16", {Float(32, 8)}, Target::AVX512_Zen4},
// LLVM does not provide an unmasked 128bit cvtneps2bf16 intrinsic, so provide a wrapper around the masked version.
{"vcvtneps2bf16x4", BFloat(16, 4), "f32_to_bf16", {Float(32, 4)}, Target::AVX512_SapphireRapids},
{"vcvtneps2bf16x4", BFloat(16, 4), "f32_to_bf16", {Float(32, 4)}, Target::AVX512_Zen4},
*/

// 2-way dot products
{"llvm.x86.avx2.pmadd.ub.sw", Int(16, 16), "saturating_dot_product", {UInt(8, 32), Int(8, 32)}, Target::AVX2},
Expand All @@ -242,23 +250,23 @@ const x86Intrinsic intrinsic_defs[] = {

// 4-way dot product vector reduction
// The LLVM intrinsics combine the bf16 pairs into i32, so provide a wrapper to correctly call the intrinsic.
{"dpbf16psx16", Float(32, 16), "dot_product", {Float(32, 16), BFloat(16, 32), BFloat(16, 32)}, Target::AVX512_SapphireRapids},
{"dpbf16psx16", Float(32, 16), "dot_product", {Float(32, 16), BFloat(16, 32), BFloat(16, 32)}, Target::AVX512_Zen4},
{"dpbf16psx8", Float(32, 8), "dot_product", {Float(32, 8), BFloat(16, 16), BFloat(16, 16)}, Target::AVX512_SapphireRapids},
{"dpbf16psx4", Float(32, 4), "dot_product", {Float(32, 4), BFloat(16, 8), BFloat(16, 8)}, Target::AVX512_SapphireRapids},

{"dpbusdx16", Int(32, 16), "dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_SapphireRapids},
{"dpbusdx16", Int(32, 16), "dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_Zen4},
{"dpbusdx8", Int(32, 8), "dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_SapphireRapids},
{"dpbusdx4", Int(32, 4), "dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_SapphireRapids},

{"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids},
{"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_Zen4},
{"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids},
{"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids},

{"dpbusdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_SapphireRapids},
{"dpbusdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_Zen4},
{"dpbusdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_SapphireRapids},
{"dpbusdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_SapphireRapids},

{"dpwssdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_SapphireRapids},
{"dpwssdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_Zen4},
{"dpwssdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_SapphireRapids},
{"dpwssdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_SapphireRapids},

Expand Down Expand Up @@ -488,9 +496,25 @@ void CodeGen_X86::visit(const Select *op) {
}

void CodeGen_X86::visit(const Cast *op) {
Type src = op->value.type();
Type dst = op->type;

if (target.has_feature(Target::F16C) &&
dst.code() == Type::Float &&
src.code() == Type::Float &&
(dst.bits() == 16 || src.bits() == 16)) {
// Note: we use code() == Type::Float instead of is_float(), because we
// don't want to catch bfloat casts.

// This target doesn't support full float16 arithmetic, but it *does*
// support float16 casts, so we emit a vanilla LLVM cast node.
value = codegen(op->value);
value = builder->CreateFPCast(value, llvm_type_of(dst));
return;
}

if (!op->type.is_vector()) {
// We only have peephole optimizations for vectors in here.
if (!dst.is_vector()) {
// We only have peephole optimizations for vectors after this point.
CodeGen_Posix::visit(op);
return;
}
Expand All @@ -513,20 +537,20 @@ void CodeGen_X86::visit(const Cast *op) {
vector<Expr> matches;
for (const Pattern &p : patterns) {
if (expr_match(p.pattern, op, matches)) {
value = call_overloaded_intrin(op->type, p.intrin, matches);
value = call_overloaded_intrin(dst, p.intrin, matches);
if (value) {
return;
}
}
}

if (const Call *mul = Call::as_intrinsic(op->value, {Call::widening_mul})) {
if (op->value.type().bits() < op->type.bits() && op->type.bits() <= 32) {
if (src.bits() < dst.bits() && dst.bits() <= 32) {
// LLVM/x86 really doesn't like 8 -> 16 bit multiplication. If we're
// widening to 32-bits after a widening multiply, LLVM prefers to see a
// widening multiply directly to 32-bits. This may result in extra
// casts, so simplify to remove them.
value = codegen(simplify(Mul::make(Cast::make(op->type, mul->args[0]), Cast::make(op->type, mul->args[1]))));
value = codegen(simplify(Mul::make(Cast::make(dst, mul->args[0]), Cast::make(dst, mul->args[1]))));
return;
}
}
Expand Down Expand Up @@ -871,6 +895,8 @@ string CodeGen_X86::mcpu_target() const {
// The CPU choice here *WILL* affect -mattrs!
if (target.has_feature(Target::AVX512_SapphireRapids)) {
return "sapphirerapids";
} else if (target.has_feature(Target::AVX512_Zen4)) {
return "znver4";
} else if (target.has_feature(Target::AVX512_Cannonlake)) {
return "cannonlake";
} else if (target.has_feature(Target::AVX512_Skylake)) {
Expand Down Expand Up @@ -917,6 +943,8 @@ string CodeGen_X86::mcpu_tune() const {
return "znver2";
case Target::Processor::ZnVer3:
return "znver3";
case Target::Processor::ZnVer4:
return "znver4";

case Target::Processor::ProcessorGeneric:
break;
Expand Down Expand Up @@ -958,8 +986,11 @@ string CodeGen_X86::mattrs() const {
if (target.has_feature(Target::AVX512_Cannonlake)) {
features += ",+avx512ifma,+avx512vbmi";
}
if (target.has_feature(Target::AVX512_Zen4)) {
features += ",+avx512bf16,+avx512vnni,+avx512bitalg,+avx512vbmi2";
}
if (target.has_feature(Target::AVX512_SapphireRapids)) {
features += ",+avx512bf16,+avx512vnni,+amx-int8,+amx-bf16";
features += ",+avxvnni,+amx-int8,+amx-bf16";
}
}
return features;
Expand Down Expand Up @@ -997,21 +1028,6 @@ int CodeGen_X86::vector_lanes_for_slice(const Type &t) const {
return slice_bits / t.bits();
}

llvm::Type *CodeGen_X86::llvm_type_of(const Type &t) const {
if (t.is_float() && t.bits() < 32) {
// LLVM as of August 2019 has all sorts of issues in the x86
// backend for half types. It injects expensive calls to
// convert between float and half for seemingly no reason
// (e.g. to do a select), and bitcasting to int16 doesn't
// help, because it simplifies away the bitcast for you.
// See: https://bugs.llvm.org/show_bug.cgi?id=43065
// and: https://github.com/halide/Halide/issues/4166
return llvm_type_of(t.with_code(halide_type_uint));
} else {
return CodeGen_Posix::llvm_type_of(t);
}
}

} // namespace

std::unique_ptr<CodeGen_Posix> new_CodeGen_X86(const Target &target) {
Expand Down
14 changes: 12 additions & 2 deletions src/Definition.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,23 @@ class Definition {
* definition. */
void mutate(IRMutator *);

/** Get the default (no-specialization) arguments (left-hand-side) of the definition */
/** Get the default (no-specialization) arguments (left-hand-side) of the definition.
*
* Warning: Any Vars in the Exprs are not qualified with the Func name, so
* the Exprs may contain names which collide with names provided by
* unique_name.
*/
// @{
const std::vector<Expr> &args() const;
std::vector<Expr> &args();
// @}

/** Get the default (no-specialization) right-hand-side of the definition */
/** Get the default (no-specialization) right-hand-side of the definition.
*
* Warning: Any Vars in the Exprs are not qualified with the Func name, so
* the Exprs may contain names which collide with names provided by
* unique_name.
*/
// @{
const std::vector<Expr> &values() const;
std::vector<Expr> &values();
Expand Down
28 changes: 28 additions & 0 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,34 @@ void Stage::split(const string &old, const string &outer, const string &inner, c
<< "Use TailStrategy::GuardWithIf instead.";
}

bool predicate_loads_ok = !exact;
if (predicate_loads_ok && tail == TailStrategy::PredicateLoads) {
// If it's the outermost split in this dimension, PredicateLoads
// is OK. Otherwise we can't prove it's safe.
std::set<string> inner_vars;
for (const Split &s : definition.schedule().splits()) {
if (s.is_split()) {
inner_vars.insert(s.inner);
if (inner_vars.count(s.old_var)) {
inner_vars.insert(s.outer);
}
} else if (s.is_rename() || s.is_purify()) {
if (inner_vars.count(s.old_var)) {
inner_vars.insert(s.outer);
}
} else if (s.is_fuse()) {
if (inner_vars.count(s.inner) || inner_vars.count(s.outer)) {
inner_vars.insert(s.old_var);
}
}
}
predicate_loads_ok = !inner_vars.count(old_name);
user_assert(predicate_loads_ok || tail != TailStrategy::PredicateLoads)
<< "Can't use TailStrategy::PredicateLoads for splitting " << old_name
<< " in the definition of " << name() << ". "
<< "PredicateLoads may not be used to split a Var stemming from the inner Var of a prior split.";
}

if (tail == TailStrategy::Auto) {
// Select a tail strategy
if (exact) {
Expand Down
7 changes: 6 additions & 1 deletion src/Function.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,12 @@ class Function {
int required_dimensions() const;

/** Get the right-hand-side of the pure definition. Returns an
* empty vector if there is no pure definition. */
* empty vector if there is no pure definition.
*
* Warning: Any Vars in the Exprs are not qualified with the Func name, so
* the Exprs may contain names which collide with names provided by
* unique_name.
*/
const std::vector<Expr> &values() const;

/** Does this function have a pure definition? */
Expand Down
6 changes: 5 additions & 1 deletion src/Simplify_Stmts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ Stmt Simplify::visit(const IfThenElse *op) {
if (else_unreachable) {
return then_case;
} else if (then_unreachable) {
return else_case;
if (else_case.defined()) {
return else_case;
} else {
return Evaluate::make(0);
}
}

if (is_no_op(else_case)) {
Expand Down
Loading

0 comments on commit 00ba0b2

Please sign in to comment.