Refactor DataType to better support complicated dtypes like array and pointer #2417

Merged Feb 22, 2023 (7 commits). The diff below shows changes from all commits.
third_party/nvfuser/csrc/codegen.cpp (2 changes: 1 addition & 1 deletion)

@@ -159,7 +159,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
    // The type of an integer literal is automatically picked from
    // int, long int, and long long int, so no suffix should be
    // required. https://en.cppreference.com/w/cpp/language/integer_literal
-   switch (dtype) {
+   switch (std::get<PrimDataType>(dtype.type)) {
      case DataType::Float:
        return "f";
      default:
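Across this PR, scalar-type switches change from switch (dtype) to switch (std::get<PrimDataType>(dtype.type)), which implies DataType is no longer a plain enum but a struct wrapping a std::variant. Below is a minimal sketch of the shape the diff implies; the member names type, ArrayOf::type, and ArrayOf::size are visible in the diff, while everything else (the enumerators, constructors, and exact variant alternatives) is an assumption:

// Hypothetical sketch inferred from the accessors used in this PR;
// not the actual nvfuser header.
#include <cstddef>
#include <memory>
#include <variant>

enum class PrimDataType { Bool, Float, Double, Int, Null /*, ... */ };

struct DataType;

struct ArrayOf {
  std::shared_ptr<DataType> type; // element dtype ("type" is visible in the diff)
  size_t size;                    // element count ("size" is visible in the diff)
};

struct DataType {
  // The PR title suggests further alternatives such as a pointer type; omitted here.
  std::variant<PrimDataType, ArrayOf> type = PrimDataType::Null;

  DataType() = default;
  DataType(PrimDataType t) : type(t) {} // keeps `DataType d = DataType::Int;` working

  // Constants like DataType::Float stay usable as switch-case labels because
  // they are plain PrimDataType values:
  static constexpr PrimDataType Bool = PrimDataType::Bool;
  static constexpr PrimDataType Float = PrimDataType::Float;
  static constexpr PrimDataType Double = PrimDataType::Double;
  static constexpr PrimDataType Int = PrimDataType::Int;
  static constexpr PrimDataType Null = PrimDataType::Null;
};

One behavioral consequence: std::get<PrimDataType>(dtype.type) throws std::bad_variant_access when the variant holds an ArrayOf, so each converted switch now implicitly asserts that the dtype is primitive.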
third_party/nvfuser/csrc/dispatch.cpp (6 changes: 3 additions & 3 deletions)

@@ -41,7 +41,7 @@ template <typename T>
void Val::dispatch(T handler, Val* val) {
  switch (*(val->getValType())) {
    case ValType::Scalar:
-     switch (*(val->getDataType())) {
+     switch (std::get<PrimDataType>(val->getDataType()->type)) {
        case DataType::Bool:
          ptr(handler)->handle(val->as<Bool>());
          return;

@@ -294,7 +294,7 @@ template <typename T>
void Val::constDispatch(T handler, const Val* val) {
  switch (*(val->getValType())) {
    case ValType::Scalar:
-     switch (*(val->getDataType())) {
+     switch (std::get<PrimDataType>(val->getDataType()->type)) {
        case DataType::Bool:
          ptr(handler)->handle(val->as<Bool>());
          return;

@@ -562,7 +562,7 @@ template <typename T>
void Val::mutatorDispatch(T mutator, Val* val) {
  switch (*(val->getValType())) {
    case ValType::Scalar:
-     switch (*(val->getDataType())) {
+     switch (std::get<PrimDataType>(val->getDataType()->type)) {
        case DataType::Bool:
          ptr(mutator)->mutate(val->as<Bool>());
          return;
third_party/nvfuser/csrc/executor_params.h (2 changes: 1 addition & 1 deletion)

@@ -4,7 +4,7 @@
namespace nvfuser {

struct TORCH_CUDA_CU_API CompileParams {
- DataType index_type = DataType::Int;
+ PrimDataType index_type = DataType::Int;
  int maxrregcount = 255;
  bool enable_magic_zero = true;
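The kernel index type can only ever be a primitive integer type, so this field narrows from DataType to the plain enum. A hedged usage sketch, assuming the constants sketched earlier:

// Sketch: DataType::Int remains assignable because it is (by assumption,
// consistent with this diff) a static constexpr PrimDataType member.
CompileParams params;
params.index_type = DataType::Int; // no variant involved; just the enum value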
third_party/nvfuser/csrc/expr_evaluator.cpp (2 changes: 1 addition & 1 deletion)

@@ -12,7 +12,7 @@ namespace nvfuser {
namespace {

bool equals(const Val* value, const EvaluatorValue& concrete_value) {
- switch (value->getDataType().value()) {
+ switch (std::get<PrimDataType>(value->getDataType()->type)) {
    case DataType::Int: {
      if (!concrete_value.isInt()) {
        return false;
third_party/nvfuser/csrc/ir_builder.cpp (2 changes: 1 addition & 1 deletion)

@@ -11,7 +11,7 @@
namespace nvfuser {

Val* IrBuilder::newScalar(DataType dtype) {
- switch (dtype) {
+ switch (std::get<PrimDataType>(dtype.type)) {
    case DataType::Bool:
      return IrBuilder::create<Bool>();
    case DataType::Float:
third_party/nvfuser/csrc/ir_builder.h (2 changes: 1 addition & 1 deletion)

@@ -88,7 +88,7 @@ class TORCH_CUDA_CU_API IrBuilder {

template <typename T>
Val* IrBuilder::newConstant(T value, DataType dtype) {
- switch (dtype) {
+ switch (std::get<PrimDataType>(dtype.type)) {
    case DataType::Bool:
      return IrBuilder::create<Bool>((bool)value);
    case DataType::Float:
third_party/nvfuser/csrc/ir_interface_nodes.h (2 changes: 1 addition & 1 deletion)

@@ -45,7 +45,7 @@ template <typename UnderlyingType>
class TORCH_CUDA_CU_API Scalar : public Val {
 public:
  using ScalarType = UnderlyingType;
- static constexpr DataType kDefaultDataType =
+ static constexpr PrimDataType kDefaultDataType =
      NativeTypeToDataType<UnderlyingType>::type;

  explicit Scalar(IrBuilderPasskey passkey, DataType dtype = kDefaultDataType)
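A likely reason for this particular change (an inference, not stated in the PR): a static constexpr member requires a literal type, and once DataType carries a std::variant whose ArrayOf alternative holds a std::shared_ptr, DataType is no longer literal. The plain enum still qualifies:

// Illustration under the assumptions sketched earlier:
// static constexpr DataType kBad = DataType::Float;   // ill-formed: DataType is
//                                                     // not a literal type anymore
static constexpr PrimDataType kOk = DataType::Float;   // fine: plain enum constant

The constructor signature keeps DataType dtype = kDefaultDataType because the enum value converts implicitly to DataType at the default-argument site.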
third_party/nvfuser/csrc/lower_predicate_elimination.cpp (2 changes: 1 addition & 1 deletion)

@@ -985,7 +985,7 @@ Val* PredicateElimination::getInitValue(TensorView* tv) const {
  if (init_val == nullptr) {
    // No reduction restriction. Just use zero
    auto dtype = *tv->getDataType();
-   if (isVectorType(dtype)) {
+   if (std::holds_alternative<ArrayOf>(dtype.type)) {
      return IrBuilder::create<NamedScalar>("{}", dtype);
    }
    return GpuLower::current()->kernel()->zeroVal();
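Unlike std::get, std::holds_alternative never throws, so it is the natural replacement for the old isVectorType helper here. A minimal sketch of the equivalent predicate (the helper name is hypothetical):

// Hypothetical helper equivalent to the inline check above.
bool isArrayDtype(const DataType& dtype) {
  // true when the variant currently holds an ArrayOf; false for a PrimDataType
  return std::holds_alternative<ArrayOf>(dtype.type);
}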
third_party/nvfuser/csrc/ops/arith.cpp (11 changes: 3 additions & 8 deletions)

@@ -1112,7 +1112,7 @@ TensorView* reductionOpRaw(
    Val* init,
    TensorView* tv,
    bool keep_dim /*=false*/,
-   DataType dtype /* DataType::Null */) {
+   DataType dtype /* DataType::Null */) {
  // TODO: should we use squeeze for size 1 broadcast dim?

  TORCH_CHECK(

@@ -2376,13 +2376,8 @@ TensorView* gather(

TensorView* viewAsScalar(TensorView* inp) {
  auto inp_type = inp->getDataType().value();
- TORCH_CHECK(
-     isVectorType(inp_type),
-     "Invalid type to viewAsScalar. A vector type is expected but ",
-     inp_type,
-     " is given.");
- int vec_size = getVectorSizeFromType(inp_type);
- auto out_type = getTypeFromVectorType(inp_type);
+ int vec_size = std::get<ArrayOf>(inp_type.type).size;
+ auto out_type = *std::get<ArrayOf>(inp_type.type).type;

  std::vector<IterDomain*> out_domain;
  auto inp_domain = TensorDomain::noReductions(inp->getMaybeRFactorDomain());
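viewAsScalar now reads the vector size and element type straight off the ArrayOf alternative. A hedged sketch of those accessors, using the struct shape assumed earlier:

// Sketch: build an array dtype, then read it back the way viewAsScalar does.
DataType inp_type;
inp_type.type = ArrayOf{std::make_shared<DataType>(DataType::Float), 4};

int vec_size = std::get<ArrayOf>(inp_type.type).size;       // 4
DataType out_type = *std::get<ArrayOf>(inp_type.type).type; // DataType::Float

Note the dropped TORCH_CHECK: std::get itself throws std::bad_variant_access if inp_type does not hold an array type, so the explicit isVectorType check became redundant, at the cost of a less descriptive error message.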
third_party/nvfuser/csrc/ops/composite.cpp (3 changes: 2 additions & 1 deletion)

@@ -214,7 +214,8 @@ TensorView* view_as_real(TensorView* x) {
      isComplexType(input_type),
      "Operand of view_as_real must have complex type");

- auto vec_type = getVectorType(getTypeFromComplexType(input_type), 2);
+ auto vec_type = ArrayOf{
+     std::make_shared<DataType>(getTypeFromComplexType(input_type)), 2};
  auto tv_vector = bitCastOp(vec_type, x);
  return viewAsScalar(tv_vector);
}
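Concretely: for a complex<float> input, getTypeFromComplexType yields Float, so the bitcast target is a length-2 array of floats, which viewAsScalar then unpacks into an extra innermost dimension. A hedged sketch of the dtype being built:

// Sketch: the dtype view_as_real bitcasts to, for a ComplexFloat input.
// ArrayOf{element_type, size} mirrors the aggregate init in the diff above.
auto real_elem = std::make_shared<DataType>(DataType::Float);
DataType vec_type;
vec_type.type = ArrayOf{real_elem, 2}; // the {real, imag} pair as a 2-element array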
third_party/nvfuser/csrc/ops/utils.cpp (6 changes: 3 additions & 3 deletions)

@@ -104,7 +104,7 @@ Val* newScalar(ValType vtype, DataType dtype) {
  switch (vtype) {
    case (ValType::NamedScalar):
    case (ValType::Scalar):
-     switch (dtype) {
+     switch (std::get<PrimDataType>(dtype.type)) {
        case DataType::Bool:
          return IrBuilder::create<Bool>();
        case DataType::Float:

@@ -320,7 +320,7 @@ Val* newValLike(Val* val, DataType dtype) {
// lowest value for integer type;
// false for bool.
Val* getMinimumValue(DataType v) {
- switch (v) {
+ switch (std::get<PrimDataType>(v.type)) {
    case (DataType::Double):
      return IrBuilder::create<Double>(
          -std::numeric_limits<double>::infinity());

@@ -357,7 +357,7 @@ Val* getMinimumValue(DataType v) {
// highest value for integer type;
// true for bool.
Val* getMaximumValue(DataType v) {
- switch (v) {
+ switch (std::get<PrimDataType>(v.type)) {
    case (DataType::Double):
      return IrBuilder::create<Double>(std::numeric_limits<double>::infinity());
      break;
third_party/nvfuser/csrc/parser.cpp (2 changes: 1 addition & 1 deletion)

@@ -2416,7 +2416,7 @@ class IrParser {
      TORCH_INTERNAL_ASSERT(
          dim_value.has_value(), "dim in softmax is not valid");

-     auto data_type = DataType::Null;
+     DataType data_type = DataType::Null;
      if (const auto opt_ivalue = toIValue(node->input(2))) {
        if (!opt_ivalue->isNone()) {
          data_type = aten_to_data_type(opt_ivalue->toScalarType());
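Why auto had to go here (an inference from the assumed constants, not stated in the PR): DataType::Null is now a plain enum value, so auto would deduce the enum rather than the wrapper, while the later assignment needs the full DataType:

// Sketch of the deduction difference, under the earlier assumptions:
auto a = DataType::Null;     // a is PrimDataType
DataType b = DataType::Null; // b is the variant-holding wrapper type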
@@ -8,7 +8,7 @@ using namespace nvfuser::inst;

namespace nvfuser::python_frontend {

-const char* dtypeToPyString(nvfuser::DataType t) {
+const char* dtypeToPyString(nvfuser::PrimDataType t) {
  switch (t) {
    case nvfuser::DataType::Bool:
      return "DataType.Bool";
@@ -14,7 +14,7 @@ struct UserSchedule;

//! This is helper function used to print a python formated
//! Fusion IR DataType when printing a fusion definition.

-TORCH_CUDA_CU_API const char* dtypeToPyString(nvfuser::DataType t);
+TORCH_CUDA_CU_API const char* dtypeToPyString(nvfuser::PrimDataType t);

//! The State and the StateType enum are used to define state objects to
//! encapsulate the recording of state in the FusionDefinition.
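Both dtypeToPyString changes narrow the Python-frontend printing path from DataType to PrimDataType. Presumably the Python API only ever names primitive dtypes, so a full DataType argument (which can now hold an array or pointer type) would admit inputs the function cannot render; the nvfuser::DataType::* case labels keep compiling because, under the scheme sketched earlier, they are plain PrimDataType constants.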
third_party/nvfuser/csrc/python_frontend/fusion_record.h (29 changes: 15 additions & 14 deletions)

@@ -9,6 +9,7 @@

 #include <algorithm>
 #include <complex>
+#include <variant>

 namespace nvfuser::python_frontend {

@@ -838,7 +839,7 @@ struct CastOpRecord : RecordFunctor {
      std::vector<State> _outputs,
      std::string _name,
      std::function<OutType(DataType, ArgType)> fusion_op,
-     DataType dtype)
+     PrimDataType dtype)
      : RecordFunctor(
            std::move(_args),
            std::move(_outputs),

@@ -912,14 +913,14 @@ struct CastOpRecord : RecordFunctor {
  //! nvFuser arith function signature
  std::function<OutType(DataType, ArgType)> fusion_op_;
  //! Type to cast to.
- DataType dtype_;
+ PrimDataType dtype_;
};

//! Specialized Record Functor for recording FusionDefinition constant state.

template <typename ExprType, typename ValueType>
struct ConstantRecord : RecordFunctor {
- ConstantRecord(std::vector<State> _outputs, ValueType val, DataType dtype)
+ ConstantRecord(std::vector<State> _outputs, ValueType val, PrimDataType dtype)
      : RecordFunctor(
            {},
            std::move(_outputs),

@@ -980,7 +981,7 @@ struct ConstantRecord : RecordFunctor {
  ValueType value_;

  //! The DataType provided
- DataType dtype_;
+ PrimDataType dtype_;
};

//! Specialized Record Functor for recording FusionDefinition End.

@@ -1018,7 +1019,7 @@ struct TensorRecord : RecordFunctor {
      std::vector<State> _outputs,
      std::vector<int64_t> _symbolic_sizes,
      std::vector<bool> _contiguous_info,
-     DataType _dtype,
+     PrimDataType _dtype,
      bool _is_cpu = false)
      : RecordFunctor(
            {},

@@ -1148,7 +1149,7 @@ struct TensorRecord : RecordFunctor {
  //! with the dimension just to its right.
  std::vector<bool> contiguous_info_;
  //! Tensor data type.
- DataType dtype_;
+ PrimDataType dtype_;
  //! Notes a scalar CPU Tensor
  bool is_cpu_;
};

@@ -1217,7 +1218,7 @@ struct ReductionOpRecord : RecordFunctor {
      fusion_op,
      std::vector<int> axes,
      bool keep_dim,
-     DataType dtype)
+     PrimDataType dtype)
      : RecordFunctor(
            std::move(_args),
            std::move(_outputs),

@@ -1336,7 +1337,7 @@ struct ReductionOpRecord : RecordFunctor {
  //! Indicates whether to keep the reduced dimension(s).
  bool keep_dim_;
  //! The output data type.
- DataType dtype_;
+ PrimDataType dtype_;
};

struct IndexSelectOpRecord : RecordFunctor {

@@ -1416,7 +1417,7 @@ struct TorchGatherOpRecord : RecordFunctor {
//! Specialized Record Functor for recording FusionDefinition input scalars.

struct ScalarRecord : RecordFunctor {
- ScalarRecord(std::vector<State> _outputs, DataType dtype)
+ ScalarRecord(std::vector<State> _outputs, PrimDataType dtype)
      : RecordFunctor(
            {},
            std::move(_outputs),

@@ -1472,7 +1473,7 @@ struct ScalarRecord : RecordFunctor {
 private:
  //! Scalar data type.
- DataType dtype_;
+ PrimDataType dtype_;
};

//! Specialized Record Functor for recording FusionDefinition Start.

@@ -1765,7 +1766,7 @@ struct FullOpRecord : RecordFunctor {
      std::vector<State> _args,
      std::vector<State> _outputs,
      std::vector<int64_t>& shape,
-     DataType dtype)
+     PrimDataType dtype)
      : RecordFunctor(
            std::move(_args),
            std::move(_outputs),

@@ -1836,14 +1837,14 @@ struct FullOpRecord : RecordFunctor {
  //! Represents shape of new tensor
  std::vector<int64_t> shape_;
  //! Type of output
- DataType dtype_;
+ PrimDataType dtype_;
};

struct IotaOpRecord : RecordFunctor {
  IotaOpRecord(
      std::vector<State> _args,
      std::vector<State> _outputs,
-     DataType dtype)
+     PrimDataType dtype)
      : RecordFunctor(
            std::move(_args),
            std::move(_outputs),

@@ -1893,7 +1894,7 @@ struct IotaOpRecord : RecordFunctor {
 private:
  //! Type of output
- DataType dtype_;
+ PrimDataType dtype_;
};

} // namespace nvfuser::python_frontend
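Two closing observations on this header (both inferences from the diff, not stated in the PR): the new #include <variant> supports the std::get and std::holds_alternative calls that now appear in headers, and every record that stores a dtype narrows to PrimDataType, presumably because these python-frontend records are hashed and compared by value and the plain enum supports that trivially:

// Hedged sketch: a plain enum hashes as an integer, which is presumably why
// the records prefer PrimDataType over the recursive variant-based DataType.
size_t hashDtype(PrimDataType t) {
  return static_cast<size_t>(t);
}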