Refactor DataType to better support complicated dtypes like array and pointer #2417
Conversation
struct DataType;

struct ArrayOf {
  std::shared_ptr<DataType> type;
Why does this need to be a pointer?
Because DataType is a forward declaration. We need the definition of DataType to define ArrayOf, but we need the definition of ArrayOf to define std::variant<ArrayOf, ...>, so we use a pointer to work around this cyclic dependency.
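A minimal sketch of the shape being described (the exact nvFuser declarations may differ; the ArrayOf size field and the enum members are illustrative assumptions):

#include <cstddef>
#include <memory>
#include <variant>

struct DataType;  // forward declaration; the full definition comes later

struct ArrayOf {
  // std::shared_ptr only needs the incomplete DataType here; holding a
  // DataType member by value would require its complete definition.
  std::shared_ptr<DataType> type;
  std::size_t size;  // illustrative field
};

struct PointerOf {
  std::shared_ptr<DataType> type;
};

enum class PrimDataType { Char, Float, Int /* illustrative subset */ };

// std::variant requires complete alternative types, which is why ArrayOf
// above cannot hold a DataType by value: the dependency is cyclic.
struct DataType {
  std::variant<PrimDataType, ArrayOf, PointerOf> type;
};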
Ah, I see. I just feel this seems (mentally) significantly different from the current DataType, which is just a plain enum. I'm curious whether the type field of ArrayOf and PointerOf needs to be DataType. Do we need to be able to express arrays of pointers, etc.?
Probably yes. In the matmul kernel, there are things like

char *base_ptr[2] = {...};
#pragma unroll
for (auto i : {0, 1, 2, 3}) {
  for (auto j : {0, 1}) {
    char *address = base_ptr[j] + func(i);

where base_ptr is an array of pointers.
I am not sure whether the SASS will be different if we do it differently, but it is still good to support this case.
For example, the kernel profiler uses a TensorView as a profile buffer: https://github.com/csarofeen/pytorch/blob/devel/third_party/nvfuser/csrc/lower_instrument.cpp#L82
Can we replace it with an array?
No. I think this is an example of where we don't want to use an array, because we want it to be located in global memory. I expect an array to be something like what you get when you write std::array<type, size> variable;, so it has to be in the local memory of each thread.
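A hedged CUDA sketch of the distinction being drawn (the names are illustrative, not from nvFuser): the profile buffer is a global-memory allocation shared by all threads, while the kind of array this PR targets is private to each thread.

__global__ void sketch(float* profile_buffer) {
  // Per-thread array: lives in registers or local memory, private to
  // each thread -- the std::array<type, size>-like case.
  float local_array[4] = {0.f, 1.f, 2.f, 3.f};
  // Global-memory buffer: a single allocation visible to every thread,
  // like the TensorView-backed profile buffer linked above.
  profile_buffer[threadIdx.x] = local_array[threadIdx.x % 4];
}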
Does this PR support actually creating arrays? Can we have tests?
I think it is always supported, and this PR does not change it:
PYTORCH_NVFUSER_DUMP="cuda_kernel" ./build/bin/nvfuser_tests --gtest_filter=*ViewAs*
__global__ void kernel1(Tensor<std::complex<float>, 1> T0, Tensor<std::complex<float>, 1> T1, Tensor<float, 2> T2, Tensor<float, 2> T7) {
  NVFUSER_DEFINE_MAGIC_ZERO
  int i45;
  i45 = ((nvfuser_index_t)threadIdx.x) * T2.stride[0];
  int i59;
  i59 = ((nvfuser_index_t)threadIdx.x) * 2;
  std::complex<float> T4[1];
  T4[0] = 0;
  T4[0]
    = T0[(((nvfuser_index_t)threadIdx.x) * T0.stride[0])]
    + T1[(((nvfuser_index_t)threadIdx.x) * T1.stride[0])];
  Array<float, 2, 1> T5[1];
  T5[0]
    = erase_type(T4[0]);
  #pragma unroll
  for(nvfuser_index_t i23 = 0; i23 < 2; ++i23) {
    int i36;
    i36 = i23 + nvfuser_zero;
    float T3[1];
    T3[0]
      = T2[(i45 + (T2.stride[1] * i36))]
      + (float) 1.00000000000000000e+00;
    float T6[1];
    T6[0] = T5[0][i23];
    T7[(i59 + i36)]
      = T3[0]
      + T6[0];
  }
  NVFUSER_UPDATE_MAGIC_ZERO
}
That's a tensor of arrays, right? Isn't it possible to just create an array?
I don't think we have the infrastructure ready to allocate an array and assign values to it right now. But I don't mind stacking another PR on top of this one to demonstrate what I want to do and how I will use it. It's already on my plate, and hopefully that will provide a better justification for this.
OK, that's fine. I'll approve the PR. Thanks for the discussion!
@@ -24,7 +24,7 @@ void initNvFuserPythonBindings(PyObject* module) {
   auto nvfuser = py::handle(module).cast<py::module>();

   //! DataTypes supported by nvFuser in the FusionDefinition
-  py::enum_<Nvf::DataType>(nvfuser, "DataType")
+  py::enum_<Nvf::PrimDataType>(nvfuser, "DataType")
If we are renaming it here, I'm surprised that I didn't see the renaming on the python side? Are we passing python tests?!
I'd like to keep this renaming on the cpp side, if that makes sense. I think we can get away with mapping the rename in our python wrapper, so we don't have to change a bunch of user scripts.
cc'ing @kevinstephano for an opinion on whether to expose this to users.
Yes, I checked the python tests, and they are passing. The python side is not renamed. I think C++ generally doesn't care about names, so there is nothing that prevents you from exposing a class BigAndBeautiful as "SmallAndUgly" in python.
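As a hedged illustration (the module name and the empty struct are made up for the example), pybind11, which these bindings use, happily binds a C++ type under a different Python name:

#include <pybind11/pybind11.h>
namespace py = pybind11;

struct BigAndBeautiful {};

PYBIND11_MODULE(example, m) {
  // The Python-visible name is whatever string is passed here; it does
  // not have to match the C++ identifier.
  py::class_<BigAndBeautiful>(m, "SmallAndUgly");
}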
Approving the PR for the codegen side. For the frontend part, I assume @jjsjann123 will give a stamp.
@@ -836,7 +837,7 @@ struct CastOpRecord : RecordFunctor {
           _name,
           RecordType::CastOp),
         fusion_op_(fusion_op),
-        dtype_(dtype) {}
+        dtype_(std::get<Nvf::PrimDataType>(dtype.type)) {}
It looks like the Python side is generally using the PrimDataType to expose data types. Why did you leave the RecordFunctor child class constructors passing in the Nvf::DataType instead of having them pass in the Nvf::PrimDataType?
You are right. I changed RecordFunctor subclasses to use PrimDataType.
I am currently refactoring #2282, and I feel that we need to support pointers of arbitrary types. So I did a refactor of DataType. The original DataType is renamed to PrimDataType, and the new DataType is now a wrapper of std::variant<PrimDataType, ArrayOf, PointerOf>. In our python frontend, we only expose PrimDataType, because non-prim data types are for codegen internal usage and are not expected to be used by users.
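For illustration only, reusing the sketch from earlier in the thread (PrimDataType::Char and the aggregate-initialization style are assumptions, not the actual nvFuser API), the dtype of char *base_ptr[2] from the matmul example could nest as an array of pointers:

#include <memory>

DataType makeCharPtrArrayType() {
  DataType c{PrimDataType::Char};                              // char
  DataType p{PointerOf{std::make_shared<DataType>(c)}};        // char*
  return DataType{ArrayOf{std::make_shared<DataType>(p), 2}};  // char*[2]
}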