Skip to content

Commit

Permalink
Prevent use of uninitialized scalar Parameters in JIT code (halide#7847
Browse files Browse the repository at this point in the history
…, partial) (halide#7853)

* Prevent use of uninitialized scalar Parameters in JIT code (halide#7847, partial)

* Fix broken tests

* Update Parameter.h

* Update func_clone.cpp

* Fix Generators too

* Fixes

* Update InferArguments.cpp

* Fixes

* pacify clang-tidy

* fixes
  • Loading branch information
steven-johnson authored and ardier committed Mar 3, 2024
1 parent e921910 commit cf60895
Show file tree
Hide file tree
Showing 14 changed files with 213 additions and 66 deletions.
5 changes: 0 additions & 5 deletions python_bindings/src/halide/halide_/PyParameter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@ void define_parameter(py::module &m) {
.def(py::init<const Parameter &>(), py::arg("p"))
.def(py::init<const Type &, bool, int>())
.def(py::init<const Type &, bool, int, const std::string &>())
.def(py::init<const Type &, bool, int, const std::string &,
const Buffer<void> &, int, const std::vector<BufferConstraint> &,
MemoryType>())
.def(py::init<const Type &, bool, int, const std::string &,
uint64_t, const Expr &, const Expr &, const Expr &, const Expr &>())
.def("_to_argument", [](const Parameter &p) -> Argument {
return Argument(p.name(),
p.is_buffer() ? Argument::InputBuffer : Argument::InputScalar,
Expand Down
16 changes: 13 additions & 3 deletions src/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1159,14 +1159,24 @@ Parameter Deserializer::deserialize_parameter(const Serialize::Parameter *parame
deserialize_vector<Serialize::BufferConstraint, BufferConstraint>(parameter->buffer_constraints(),
&Deserializer::deserialize_buffer_constraint);
const auto memory_type = deserialize_memory_type(parameter->memory_type());
return Parameter(type, is_buffer, dimensions, name, Buffer<>(), host_alignment, buffer_constraints, memory_type);
return Parameter(type, dimensions, name, Buffer<>(), host_alignment, buffer_constraints, memory_type);
} else {
const uint64_t data = parameter->data();
static_assert(FLATBUFFERS_USE_STD_OPTIONAL);
const auto make_optional_halide_scalar_value_t = [](const std::optional<uint64_t> &v) -> std::optional<halide_scalar_value_t> {
if (v.has_value()) {
halide_scalar_value_t scalar_data;
scalar_data.u.u64 = v.value();
return std::optional<halide_scalar_value_t>(scalar_data);
} else {
return std::nullopt;
}
};
const std::optional<halide_scalar_value_t> scalar_data = make_optional_halide_scalar_value_t(parameter->scalar_data());
const auto scalar_default = deserialize_expr(parameter->scalar_default_type(), parameter->scalar_default());
const auto scalar_min = deserialize_expr(parameter->scalar_min_type(), parameter->scalar_min());
const auto scalar_max = deserialize_expr(parameter->scalar_max_type(), parameter->scalar_max());
const auto scalar_estimate = deserialize_expr(parameter->scalar_estimate_type(), parameter->scalar_estimate());
return Parameter(type, is_buffer, dimensions, name, data, scalar_default, scalar_min, scalar_max, scalar_estimate);
return Parameter(type, dimensions, name, scalar_data, scalar_default, scalar_min, scalar_max, scalar_estimate);
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/Generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -2001,7 +2001,8 @@ class GeneratorInput_Scalar : public GeneratorInputImpl<T, Expr> {

void set_def_min_max() override {
for (Parameter &p : this->parameters_) {
p.set_scalar<TBase>(def_);
// No: we want to leave the Parameter unset here.
// p.set_scalar<TBase>(def_);
p.set_default_value(def_expr_);
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/InferArguments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "ExternFuncArgument.h"
#include "Function.h"
#include "IROperator.h"
#include "IRVisitor.h"
#include "InferArguments.h"

Expand Down Expand Up @@ -197,7 +198,9 @@ class InferArguments : public IRGraphVisitor {

ArgumentEstimates argument_estimates = p.get_argument_estimates();
if (!p.is_buffer()) {
argument_estimates.scalar_def = p.scalar_expr();
// We don't want to crater here if a scalar param isn't set;
// instead, default to a zero of the right type, like we used to.
argument_estimates.scalar_def = p.has_scalar_value() ? p.scalar_expr() : make_zero(p.type());
argument_estimates.scalar_min = p.min_value();
argument_estimates.scalar_max = p.max_value();
argument_estimates.scalar_estimate = p.estimate();
Expand Down
90 changes: 64 additions & 26 deletions src/Parameter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ struct ParameterContents {
const int dimensions;
const std::string name;
Buffer<> buffer;
uint64_t data = 0;
std::optional<halide_scalar_value_t> scalar_data;
int host_alignment;
std::vector<BufferConstraint> buffer_constraints;
Expr scalar_default, scalar_min, scalar_max, scalar_estimate;
Expand Down Expand Up @@ -82,21 +82,21 @@ Parameter::Parameter(const Type &t, bool is_buffer, int d, const std::string &na
internal_assert(is_buffer || d == 0) << "Scalar parameters should be zero-dimensional";
}

Parameter::Parameter(const Type &t, bool is_buffer, int dimensions, const std::string &name,
Parameter::Parameter(const Type &t, int dimensions, const std::string &name,
const Buffer<void> &buffer, int host_alignment, const std::vector<BufferConstraint> &buffer_constraints,
MemoryType memory_type)
: contents(new Internal::ParameterContents(t, is_buffer, dimensions, name)) {
: contents(new Internal::ParameterContents(t, /*is_buffer*/ true, dimensions, name)) {
contents->buffer = buffer;
contents->host_alignment = host_alignment;
contents->buffer_constraints = buffer_constraints;
contents->memory_type = memory_type;
}

Parameter::Parameter(const Type &t, bool is_buffer, int dimensions, const std::string &name,
uint64_t data, const Expr &scalar_default, const Expr &scalar_min,
Parameter::Parameter(const Type &t, int dimensions, const std::string &name,
const std::optional<halide_scalar_value_t> &scalar_data, const Expr &scalar_default, const Expr &scalar_min,
const Expr &scalar_max, const Expr &scalar_estimate)
: contents(new Internal::ParameterContents(t, is_buffer, dimensions, name)) {
contents->data = data;
: contents(new Internal::ParameterContents(t, /*is_buffer*/ false, dimensions, name)) {
contents->scalar_data = scalar_data;
contents->scalar_default = scalar_default;
contents->scalar_min = scalar_min;
contents->scalar_max = scalar_max;
Expand All @@ -123,51 +123,55 @@ bool Parameter::is_buffer() const {
return contents->is_buffer;
}

bool Parameter::has_scalar_value() const {
return defined() && !contents->is_buffer && contents->scalar_data.has_value();
}

Expr Parameter::scalar_expr() const {
check_is_scalar();
const auto sv = scalar_data_checked();
const Type t = type();
if (t.is_float()) {
switch (t.bits()) {
case 16:
if (t.is_bfloat()) {
return Expr(scalar<bfloat16_t>());
return Expr(bfloat16_t::make_from_bits(sv.u.u16));
} else {
return Expr(scalar<float16_t>());
return Expr(float16_t::make_from_bits(sv.u.u16));
}
case 32:
return Expr(scalar<float>());
return Expr(sv.u.f32);
case 64:
return Expr(scalar<double>());
return Expr(sv.u.f64);
}
} else if (t.is_int()) {
switch (t.bits()) {
case 8:
return Expr(scalar<int8_t>());
return Expr(sv.u.i8);
case 16:
return Expr(scalar<int16_t>());
return Expr(sv.u.i16);
case 32:
return Expr(scalar<int32_t>());
return Expr(sv.u.i32);
case 64:
return Expr(scalar<int64_t>());
return Expr(sv.u.i64);
}
} else if (t.is_uint()) {
switch (t.bits()) {
case 1:
return Internal::make_bool(scalar<bool>());
return Internal::make_bool(sv.u.b);
case 8:
return Expr(scalar<uint8_t>());
return Expr(sv.u.u8);
case 16:
return Expr(scalar<uint16_t>());
return Expr(sv.u.u16);
case 32:
return Expr(scalar<uint32_t>());
return Expr(sv.u.u32);
case 64:
return Expr(scalar<uint64_t>());
return Expr(sv.u.u64);
}
} else if (t.is_handle()) {
// handles are always uint64 internally.
switch (t.bits()) {
case 64:
return Expr(scalar<uint64_t>());
return Expr(sv.u.u64);
}
}
internal_error << "Unsupported type " << t << " in scalar_expr\n";
Expand Down Expand Up @@ -198,14 +202,48 @@ void Parameter::set_buffer(const Buffer<> &b) {
contents->buffer = b;
}

void *Parameter::scalar_address() const {
const void *Parameter::read_only_scalar_address() const {
check_is_scalar();
return &contents->data;
// Use explicit if here (rather than user_assert) so that we don't
// have to disable bugprone-unchecked-optional-access in clang-tidy,
// which is a useful check.
const auto &sv = contents->scalar_data;
if (sv.has_value()) {
return std::addressof(sv.value());
} else {
user_error << "Parameter " << name() << " does not have a valid scalar value.\n";
return nullptr;
}
}

std::optional<halide_scalar_value_t> Parameter::scalar_data() const {
return defined() ? contents->scalar_data : std::nullopt;
}

uint64_t Parameter::scalar_raw_value() const {
halide_scalar_value_t Parameter::scalar_data_checked() const {
check_is_scalar();
return contents->data;
// Use explicit if here (rather than user_assert) so that we don't
// have to disable bugprone-unchecked-optional-access in clang-tidy,
// which is a useful check.
halide_scalar_value_t result;
const auto &sv = contents->scalar_data;
if (sv.has_value()) {
result = sv.value();
} else {
user_error << "Parameter " << name() << " does not have a valid scalar value.\n";
result.u.u64 = 0; // silence "possibly uninitialized" compiler warning
}
return result;
}

halide_scalar_value_t Parameter::scalar_data_checked(const Type &val_type) const {
check_type(val_type);
return scalar_data_checked();
}

void Parameter::set_scalar(const Type &val_type, halide_scalar_value_t val) {
check_type(val_type);
contents->scalar_data = std::optional<halide_scalar_value_t>(val);
}

/** Tests if this handle is the same as another handle */
Expand Down
71 changes: 45 additions & 26 deletions src/Parameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
/** \file
* Defines the internal representation of parameters to halide piplines
*/
#include <optional>
#include <string>

#include "Buffer.h"
Expand All @@ -25,7 +26,9 @@ struct BufferConstraint {
};

namespace Internal {

#ifdef WITH_SERIALIZATION
class Deserializer;
class Serializer;
#endif
struct ParameterContents;
Expand All @@ -45,21 +48,41 @@ class Parameter {
Internal::IntrusivePtr<Internal::ParameterContents> contents;

#ifdef WITH_SERIALIZATION
friend class Internal::Serializer; //< for scalar_raw_value()
friend class Internal::Deserializer; //< for scalar_data()
friend class Internal::Serializer; //< for scalar_data()
#endif
friend class Pipeline; //< for scalar_address()
friend class Pipeline; //< for read_only_scalar_address()

/** Get the raw currently-bound buffer. null if unbound */
const halide_buffer_t *raw_buffer() const;

/** Get the pointer to the current value of the scalar
* parameter. For a given parameter, this address will never
* change. Only relevant when jitting. */
void *scalar_address() const;
* change. Note that this can only be used to *read* from -- it must
* not be written to, so don't cast away the constness. Only relevant when jitting. */
const void *read_only_scalar_address() const;

/** If the Parameter is a scalar, and the scalar data is valid, return
* the scalar data. Otherwise, return nullopt. */
std::optional<halide_scalar_value_t> scalar_data() const;

/** If the Parameter is a scalar and has a valid scalar value, return it.
* Otherwise, assert-fail. */
halide_scalar_value_t scalar_data_checked() const;

/** If the Parameter is a scalar *of the given type* and has a valid scalar value, return it.
* Otherwise, assert-fail. */
halide_scalar_value_t scalar_data_checked(const Type &val_type) const;

/** Get the raw data of the current value of the scalar
* parameter. Only relevant when serializing. */
uint64_t scalar_raw_value() const;
/** Construct a new buffer parameter via deserialization. */
Parameter(const Type &t, int dimensions, const std::string &name,
const Buffer<void> &buffer, int host_alignment, const std::vector<BufferConstraint> &buffer_constraints,
MemoryType memory_type);

/** Construct a new scalar parameter via deserialization. */
Parameter(const Type &t, int dimensions, const std::string &name,
const std::optional<halide_scalar_value_t> &scalar_data, const Expr &scalar_default, const Expr &scalar_min,
const Expr &scalar_max, const Expr &scalar_estimate);

public:
/** Construct a new undefined handle */
Expand All @@ -81,15 +104,6 @@ class Parameter {
* explicitly specified (as opposed to autogenerated). */
Parameter(const Type &t, bool is_buffer, int dimensions, const std::string &name);

/** Construct a new parameter via deserialization. */
Parameter(const Type &t, bool is_buffer, int dimensions, const std::string &name,
const Buffer<void> &buffer, int host_alignment, const std::vector<BufferConstraint> &buffer_constraints,
MemoryType memory_type);

Parameter(const Type &t, bool is_buffer, int dimensions, const std::string &name,
uint64_t data, const Expr &scalar_default, const Expr &scalar_min,
const Expr &scalar_max, const Expr &scalar_estimate);

Parameter(const Parameter &) = default;
Parameter &operator=(const Parameter &) = default;
Parameter(Parameter &&) = default;
Expand All @@ -111,28 +125,33 @@ class Parameter {
* bound value. Only relevant when jitting */
template<typename T>
HALIDE_NO_USER_CODE_INLINE T scalar() const {
check_type(type_of<T>());
return *((const T *)(scalar_address()));
static_assert(sizeof(T) <= sizeof(halide_scalar_value_t));
const auto sv = scalar_data_checked(type_of<T>());
T t;
memcpy(&t, &sv.u.u64, sizeof(t));
return t;
}

/** This returns the current value of scalar<type()>()
* as an Expr. */
/** This returns the current value of scalar<type()>() as an Expr.
* If the Parameter is not scalar, or its scalar data is not valid, this will assert-fail. */
Expr scalar_expr() const;

/** This returns true if scalar_expr() would return a valid Expr,
* false if not. */
bool has_scalar_value() const;

/** If the parameter is a scalar parameter, set its current
* value. Only relevant when jitting */
template<typename T>
HALIDE_NO_USER_CODE_INLINE void set_scalar(T val) {
check_type(type_of<T>());
*((T *)(scalar_address())) = val;
halide_scalar_value_t sv;
memcpy(&sv.u.u64, &val, sizeof(val));
set_scalar(type_of<T>(), sv);
}

/** If the parameter is a scalar parameter, set its current
* value. Only relevant when jitting */
HALIDE_NO_USER_CODE_INLINE void set_scalar(const Type &val_type, halide_scalar_value_t val) {
check_type(val_type);
memcpy(scalar_address(), &val, val_type.bytes());
}
void set_scalar(const Type &val_type, halide_scalar_value_t val);

/** If the parameter is a buffer parameter, get its currently
* bound buffer. Only relevant when jitting */
Expand Down
2 changes: 1 addition & 1 deletion src/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ void Pipeline::prepare_jit_call_arguments(RealizationArg &outputs, const Target
}
debug(2) << "JIT input ImageParam argument ";
} else {
args_result.store[arg_index++] = p.scalar_address();
args_result.store[arg_index++] = p.read_only_scalar_address();
debug(2) << "JIT input scalar argument ";
}
}
Expand Down
10 changes: 8 additions & 2 deletions src/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1298,13 +1298,19 @@ Offset<Serialize::Parameter> Serializer::serialize_parameter(FlatBufferBuilder &
return Serialize::CreateParameter(builder, defined, is_buffer, type_serialized, dimensions, name_serialized, host_alignment,
builder.CreateVector(buffer_constraints_serialized), memory_type_serialized);
} else {
const uint64_t data = parameter.scalar_raw_value();
static_assert(FLATBUFFERS_USE_STD_OPTIONAL);
const auto make_optional_u64 = [](const std::optional<halide_scalar_value_t> &v) -> std::optional<uint64_t> {
return v.has_value() ?
std::optional<uint64_t>(v.value().u.u64) :
std::nullopt;
};
const auto scalar_data = make_optional_u64(parameter.scalar_data());
const auto scalar_default_serialized = serialize_expr(builder, parameter.default_value());
const auto scalar_min_serialized = serialize_expr(builder, parameter.min_value());
const auto scalar_max_serialized = serialize_expr(builder, parameter.max_value());
const auto scalar_estimate_serialized = serialize_expr(builder, parameter.estimate());
return Serialize::CreateParameter(builder, defined, is_buffer, type_serialized,
dimensions, name_serialized, 0, 0, Serialize::MemoryType_Auto, data,
dimensions, name_serialized, 0, 0, Serialize::MemoryType_Auto, scalar_data,
scalar_default_serialized.first, scalar_default_serialized.second,
scalar_min_serialized.first, scalar_min_serialized.second,
scalar_max_serialized.first, scalar_max_serialized.second,
Expand Down
2 changes: 1 addition & 1 deletion src/halide_ir.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ table Parameter {
host_alignment: int32;
buffer_constraints: [BufferConstraint];
memory_type: MemoryType;
data: uint64;
scalar_data: uint64 = null; // Note: it is valid for this to be omitted, even if is_buffer = false.
scalar_default: Expr;
scalar_min: Expr;
scalar_max: Expr;
Expand Down
Loading

0 comments on commit cf60895

Please sign in to comment.