diff --git a/src/InferArguments.cpp b/src/InferArguments.cpp index d2f55b1fa781..9fabb47b9ef8 100644 --- a/src/InferArguments.cpp +++ b/src/InferArguments.cpp @@ -5,6 +5,7 @@ #include "ExternFuncArgument.h" #include "Function.h" +#include "IROperator.h" #include "IRVisitor.h" #include "InferArguments.h" @@ -197,7 +198,9 @@ class InferArguments : public IRGraphVisitor { ArgumentEstimates argument_estimates = p.get_argument_estimates(); if (!p.is_buffer()) { - argument_estimates.scalar_def = p.scalar_expr(); + // We don't want to crater here if a scalar param isn't set; + // instead, default to a zero of the right type, like we used to. + argument_estimates.scalar_def = p.has_scalar_expr() ? p.scalar_expr() : make_zero(p.type()); argument_estimates.scalar_min = p.min_value(); argument_estimates.scalar_max = p.max_value(); argument_estimates.scalar_estimate = p.estimate(); diff --git a/src/Parameter.cpp b/src/Parameter.cpp index 1155d5468f30..f3455c699f47 100644 --- a/src/Parameter.cpp +++ b/src/Parameter.cpp @@ -20,6 +20,7 @@ struct ParameterContents { std::vector buffer_constraints; Expr scalar_default, scalar_min, scalar_max, scalar_estimate; const bool is_buffer; + bool data_ever_set = false; MemoryType memory_type = MemoryType::Auto; ParameterContents(Type t, bool b, int d, const std::string &n) @@ -46,6 +47,10 @@ void destroy(const ParameterContents *p) { } // namespace Internal +void Parameter::check_data_ever_set() const { + user_assert(contents->data_ever_set) << "Parameter " << name() << " has never had a scalar value set.\n"; +} + void Parameter::check_defined() const { user_assert(defined()) << "Parameter is undefined\n"; } @@ -123,8 +128,14 @@ bool Parameter::is_buffer() const { return contents->is_buffer; } +bool Parameter::has_scalar_expr() const { + return defined() && !contents->is_buffer && contents->data_ever_set; +} + Expr Parameter::scalar_expr() const { check_is_scalar(); + // Redundant here, since every call to scalar<>() also checks this. + // check_data_ever_set(); const Type t = type(); if (t.is_float()) { switch (t.bits()) { @@ -198,16 +209,31 @@ void Parameter::set_buffer(const Buffer<> &b) { contents->buffer = b; } -void *Parameter::scalar_address() const { +const void *Parameter::read_only_scalar_address() const { check_is_scalar(); + // Code that calls this method is (presumably) going + // to read from the address, so complain if the scalar value + // has never been set. + check_data_ever_set(); return &contents->data; } uint64_t Parameter::scalar_raw_value() const { check_is_scalar(); + check_data_ever_set(); return contents->data; } +void Parameter::set_scalar(const Type &val_type, halide_scalar_value_t val) { + check_type(val_type); + // Setting this to zero isn't strictly necessary, but it does + // mean that the 'unused' bits of the field are never affected by what + // may have previously been there. + contents->data = 0; + memcpy(&contents->data, &val, val_type.bytes()); + contents->data_ever_set = true; +} + /** Tests if this handle is the same as another handle */ bool Parameter::same_as(const Parameter &other) const { return contents.same_as(other.contents); diff --git a/src/Parameter.h b/src/Parameter.h index 712380c2576e..424555ed9253 100644 --- a/src/Parameter.h +++ b/src/Parameter.h @@ -35,6 +35,7 @@ struct ParameterContents; /** A reference-counted handle to a parameter to a halide * pipeline. May be a scalar parameter or a buffer */ class Parameter { + void check_data_ever_set() const; void check_defined() const; void check_is_buffer() const; void check_is_scalar() const; @@ -54,8 +55,9 @@ class Parameter { /** Get the pointer to the current value of the scalar * parameter. For a given parameter, this address will never - * change. Only relevant when jitting. */ - void *scalar_address() const; + * change. Note that this can only be used to *read* from -- it must + * not be written to, so don't cast away the constness. Only relevant when jitting. */ + const void *read_only_scalar_address() const; /** Get the raw data of the current value of the scalar * parameter. Only relevant when serializing. */ @@ -112,27 +114,30 @@ class Parameter { template HALIDE_NO_USER_CODE_INLINE T scalar() const { check_type(type_of()); - return *((const T *)(scalar_address())); + check_data_ever_set(); + return *((const T *)(read_only_scalar_address())); } /** This returns the current value of scalar() - * as an Expr. */ + * as an Expr. If no value has ever been set, it will assert-fail */ Expr scalar_expr() const; + /** This returns true if scalar_expr() would return a valid Expr, + * false if not. */ + bool has_scalar_expr() const; + /** If the parameter is a scalar parameter, set its current * value. Only relevant when jitting */ template HALIDE_NO_USER_CODE_INLINE void set_scalar(T val) { - check_type(type_of()); - *((T *)(scalar_address())) = val; + halide_scalar_value_t sv; + memcpy(&sv, &val, sizeof(val)); + set_scalar(type_of(), sv); } /** If the parameter is a scalar parameter, set its current * value. Only relevant when jitting */ - HALIDE_NO_USER_CODE_INLINE void set_scalar(const Type &val_type, halide_scalar_value_t val) { - check_type(val_type); - memcpy(scalar_address(), &val, val_type.bytes()); - } + void set_scalar(const Type &val_type, halide_scalar_value_t val); /** If the parameter is a buffer parameter, get its currently * bound buffer. Only relevant when jitting */ diff --git a/src/Pipeline.cpp b/src/Pipeline.cpp index d51438e275fe..631033404137 100644 --- a/src/Pipeline.cpp +++ b/src/Pipeline.cpp @@ -843,7 +843,7 @@ void Pipeline::prepare_jit_call_arguments(RealizationArg &outputs, const Target } debug(2) << "JIT input ImageParam argument "; } else { - args_result.store[arg_index++] = p.scalar_address(); + args_result.store[arg_index++] = p.read_only_scalar_address(); debug(2) << "JIT input scalar argument "; } } diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index 6e69490657f5..69e0979163b5 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -107,6 +107,7 @@ tests(GROUPS error undefined_pipeline_compile.cpp undefined_pipeline_realize.cpp undefined_rdom_dimension.cpp + uninitialized_param.cpp unknown_target.cpp vector_tile.cpp vectorize_dynamic.cpp diff --git a/test/error/uninitialized_param.cpp b/test/error/uninitialized_param.cpp new file mode 100644 index 000000000000..2fa8d47b2ec3 --- /dev/null +++ b/test/error/uninitialized_param.cpp @@ -0,0 +1,22 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + ImageParam image_param(Int(32), 2, "image_param"); + Param scalar_param("scalar_param"); + + Var x("x"), y("y"); + Func f("f"); + + f(x, y) = image_param(x, y) + scalar_param; + + Buffer b(10, 10); + image_param.set(b); + + f.realize({10, 10}); + + printf("Success!\n"); + return 0; +}