Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor DataType to better support complicated dtypes like array and pointer #2417

Merged
merged 7 commits into from
Feb 22, 2023

Conversation

zasdfgbnm
Copy link
Collaborator

@zasdfgbnm zasdfgbnm commented Feb 6, 2023

I am currently refactoring #2282, and feel that we need to support pointers of arbitrary type. So I did a refactor of DataType. The original DataType is renamed as PrimDataType, and the new DataType is now a wrapper of std::variant<PrimDataType, ArrayOf, PointerOf>. In our python frontend, we only expose PrimDataType because non-prim data types are for codegen internal usage, and are not expected to be used by users.

@zasdfgbnm zasdfgbnm marked this pull request as ready for review February 6, 2023 09:10
third_party/nvfuser/csrc/type.cpp Show resolved Hide resolved
third_party/nvfuser/csrc/type.cpp Show resolved Hide resolved
third_party/nvfuser/csrc/type.cpp Show resolved Hide resolved
struct DataType;

struct ArrayOf {
std::shared_ptr<DataType> type;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this needs to be a pointer?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because DataType is a forward declaration. We need to know the definition of DataType to define ArrayOf, but we need the definition of ArrayOf to define std::variant<ArrayOf, ...>, so we need to use pointer to workaround this cyclic dependency.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. I just feel this seems like a (mentally) significantly different from the current DataType, which is just a plain enum.

I'm curious if the type field of ArrayOf and PointersOf needs to be DataType. Do we need to be able to express arrays of pointers, etc?

Copy link
Collaborator Author

@zasdfgbnm zasdfgbnm Feb 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably yes. In matmul kernel, there are things like

char *base_ptr[2] = {...};
#pragma unroll
for(i : {0, 1, 2, 3}) {
  for(j : {0, 1}) {
    char *address = base_ptr[j] + func(i)

where base_ptr is a array of pointer.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably yes. In matmul kernel, there are things like

char *base_ptr[2] = {...};
#pragma unroll
for(i : {0, 1, 2, 3}) {
  for(j : {0, 1}) {
    char *address = base_ptr[j] + func(i)

where base_ptr is a array of pointer.

I am not sure if the SASS will be different if we do it differently, but still it is good to support this case.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, the kernel profiler uses a TensorView as a profile buffer:

https://github.com/csarofeen/pytorch/blob/devel/third_party/nvfuser/csrc/lower_instrument.cpp#L82

Can we replace it with an array?

No. I think this is an example of when we don't want to use an array, because we want it to be located on global memory. I expect an array to be something like when you write std::array<type, size> variable;. So it has to be on local memory of each thread.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this PR support actually creating arrays? Can we have tests?

I think it is always supported, and this PR does not change it:

PYTORCH_NVFUSER_DUMP="cuda_kernel" ./build/bin/nvfuser_tests --gtest_filter=*ViewAs*
__global__ void kernel1(Tensor<std::complex<float>, 1> T0, Tensor<std::complex<float>, 1> T1, Tensor<float, 2> T2, Tensor<float, 2> T7) {
  NVFUSER_DEFINE_MAGIC_ZERO
  int i45;
  i45 = ((nvfuser_index_t)threadIdx.x) * T2.stride[0];
  int i59;
  i59 = ((nvfuser_index_t)threadIdx.x) * 2;
  std::complex<float> T4[1];
  T4[0] = 0;
  T4[0]
    = T0[(((nvfuser_index_t)threadIdx.x) * T0.stride[0])]
    + T1[(((nvfuser_index_t)threadIdx.x) * T1.stride[0])];
  Array<float, 2, 1> T5[1];
  T5[0]
     = erase_type(T4[0]);
  #pragma unroll
  for(nvfuser_index_t i23 = 0; i23 < 2; ++i23) {
    int i36;
    i36 = i23 + nvfuser_zero;
    float T3[1];
    T3[0]
      = T2[(i45 + (T2.stride[1] * i36))]
      + (float) 1.00000000000000000e+00;
    float T6[1];
    T6[0] = T5[0][i23];
    T7[(i59 + i36)]
      = T3[0]
      + T6[0];
  }
  NVFUSER_UPDATE_MAGIC_ZERO
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a tensor of arrays, right? Isn't it possible to just create an array?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a tensor of arrays, right? Isn't it possible to just create an array?

I don't think we have the infrastructure ready to allocate an array and assign value to it right now. But I don't mind stacking another PR on top of this to demonstrate what I want to do, and how I will use it. It's already on my plate, and hopefully will provide a better justification of this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, that's fine. I'll approve the PR. Thanks for the discussion!

third_party/nvfuser/csrc/executor_params.h Show resolved Hide resolved
@@ -24,7 +24,7 @@ void initNvFuserPythonBindings(PyObject* module) {
auto nvfuser = py::handle(module).cast<py::module>();

//! DataTypes supported by nvFuser in the FusionDefinition
py::enum_<Nvf::DataType>(nvfuser, "DataType")
py::enum_<Nvf::PrimDataType>(nvfuser, "DataType")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we are renaming it here, I'm surprised that I didn't see the renaming on python side? Are we passing python tests?!

I'd like to keep this renaming on the cpp side if that makes sense. I think we can get away with mapping the rename in our python wrapper, so we don't have to change a bunch of user scripts.

cc'ing @kevinstephano for opinion on whether to expose this to user.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I checked python tests, and they are passing. The python side is not renamed. I think C++ generally doesn't care about names, so there is nothing that prevents you from exposing a class BigAndBeautiful as "SmallAndUgly" on python.

Copy link
Collaborator

@naoyam naoyam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approving the PR for the codegen side. For the frontend part, I assume @jjsjann123 will give a stamp.

@@ -836,7 +837,7 @@ struct CastOpRecord : RecordFunctor {
_name,
RecordType::CastOp),
fusion_op_(fusion_op),
dtype_(dtype) {}
dtype_(std::get<Nvf::PrimDataType>(dtype.type)) {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the Python side is generally using the PrimDataType to expose data types. Why did you leave the leave the RecordFunctor child class constructors passing in the Nvf::DataType instead of have them pass in the Nvf::PrimDataType?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. I changed RecordFunctor subclasses to use PrimDataType.

@zasdfgbnm zasdfgbnm merged commit b7c866b into devel Feb 22, 2023
@zasdfgbnm zasdfgbnm deleted the type-system branch February 22, 2023 23:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants