cp.async access global tensor via pointer #2282

Merged: 24 commits from access-gmem-ptr into devel on Mar 6, 2023

Conversation

@zasdfgbnm (Collaborator) commented Dec 19, 2022

According to #1974, it looks like accessing global memory via char * delivers better performance. Plus, the base pointer is just an integral scalar and can be passed through the index hoisting and simplification pipeline just like any other scalar. It does not need to be handled separately.
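For a rough sense of the difference, here is a minimal sketch of the two addressing styles. It is illustrative only, not the actual generated code (the real output is the kernel below); `in`, `in_base`, `stride`, and `n` are placeholder names.

```cpp
// Sketch only, not nvfuser codegen output.
__global__ void read_indexed(const float* in, long long stride, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Conventional form: element-wise subscripting of the data pointer.
    out[i] = in[i * stride];
  }
}

__global__ void read_via_byte_pointer(const char* in_base, long long stride,
                                      float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Byte-pointer form: the base is a single scalar (char*) that can be
    // hoisted and simplified like any other scalar; the offset is in bytes.
    out[i] = *reinterpret_cast<const float*>(in_base + 4 * (i * stride));
  }
}
```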

Example kernel from NVFuserTest.FusionLargeWelfordNormalization_CUDA:

__global__ void kernel20(Tensor<float, 2> T0, Tensor<float, 1> T2, Tensor<float, 1> T4, Tensor<float, 1> T9, Tensor<int64_t, 1> T10, Tensor<float, 1> T11, Tensor<float, 1> T12, Tensor<int, 1> T13, Tensor<int64_t, 1> T14) {
  alignas(16) extern __shared__ char array[];
  void* shared_mem = array;
  nvfuser_index_t block_size = blockDim.x*blockDim.y*blockDim.z;
  float *shared_mem_var = static_cast<float*>(shared_mem);
  float *shared_mem_avg = shared_mem_var + block_size;
  float *shared_mem_n = shared_mem_avg + block_size;
  int i86;
  i86 = ((nvfuser_index_t)blockIdx.y) * T0.stride[0];
  int i87;
  i87 = ((nvfuser_index_t)blockIdx.x) * (ceilDiv((ceilDiv(T0.size[1], ((nvfuser_index_t)blockDim.x))), ((nvfuser_index_t)gridDim.x)));
  int i100;
  i100 = ((nvfuser_index_t)blockIdx.y) * 4;
  char* ptr101;
  ptr101 = (char*)(T4.data) + i100;
  char* ptr174;
  ptr174 = (char*)(T2.data) + i100;
  // Allocate global tensor T9
  // Allocate global tensor T10
  float T8[1];
  T8[0] = 0.00000000000000000e+00;
  #pragma unroll 1
  for(nvfuser_index_t i43 = 0; i43 < (ceilDiv((ceilDiv(T0.size[1], ((nvfuser_index_t)blockDim.x))), ((nvfuser_index_t)gridDim.x))); ++i43) {
    int i90;
    i90 = ((nvfuser_index_t)threadIdx.x) + (((nvfuser_index_t)blockDim.x) * (i87 + i43));
    if ((i90 < T0.size[1])) {
      T8[0]
        = T8[0]
        + *(float *)((char*)(T0.data) + (4 * (i86 + (T0.stride[1] * i90))));
    }
  }
  *(float *)ptr101 = 0.00000000000000000e+00;
  reduction::gridReduce<true, false, false, true, false, false, false>(
    *(float *)ptr101,
    T8[0],
    [](float &a, float b) { a = a + b; },
    &T9[0],
    &T10[0],
    static_cast<float*>(shared_mem),
    true,
    true,
    float(0.00000000000000000e+00),
    0,
    1);
  // Allocate global tensor T11
  // Allocate global tensor T12
  // Allocate global tensor T13
  // Allocate global tensor T14
  float T5[1];
  T5[0] = 0.00000000000000000e+00;
  float T7[1];
  T7[0] = 0.00000000000000000e+00;
  int T6[1];
  T6[0] = 0;
  #pragma unroll 1
  for(nvfuser_index_t i41 = 0; i41 < (ceilDiv((ceilDiv(T0.size[1], ((nvfuser_index_t)blockDim.x))), ((nvfuser_index_t)gridDim.x))); ++i41) {
    int i162;
    i162 = ((nvfuser_index_t)threadIdx.x) + (((nvfuser_index_t)blockDim.x) * (i87 + i41));
    if ((i162 < T0.size[1])) {
      welfordCombine (
        T5[0],
        T7[0],
        T6[0],
        *(float *)((char*)(T0.data) + (4 * (i86 + (T0.stride[1] * i162)))),
        (float)0,
        (int)1);
    }
  }
  float T1[1];
  T1[0] = 0.00000000000000000e+00;
  *(float *)ptr174 = 0.00000000000000000e+00;
  int T3[1];
  T3[0] = 0;
  float block_result_avg_0 = 0.00000000000000000e+00;
  float block_result_var_0 = 0.00000000000000000e+00;
  int block_result_n_0 = 0;
  blockWelford<true, false, false>(
    block_result_avg_0,
    block_result_var_0,
    block_result_n_0,
    T5[0],
    float(T7[0]),
    int(T6[0]),
    threadIdx,
    blockDim,
    reinterpret_cast<float*>(shared_mem_avg),
    reinterpret_cast<float*>(shared_mem_var),
    reinterpret_cast<int*>(shared_mem_n),
    true,
    true,
    float(0));
  welford::gridWelford<true, false, false, false, true, true, false>(
    T1[0],
    *(float *)ptr174,
    T3[0],
    block_result_avg_0,
    block_result_var_0,
    block_result_n_0,
    &T11[0],
    &T12[0],
    &T13[0],
    T14,
    reinterpret_cast<float*>(shared_mem_avg),
    reinterpret_cast<float*>(shared_mem_var),
    reinterpret_cast<int*>(shared_mem_n),
    true,
    true,
    float(0),
    0,
    1);
}

TODO: run benchmark

Comment on lines 313 to 318
      tensor_most_positive_index += (tensor_input.size(dim_i) - 1) *
          tensor_input.stride(dim_i) * tensor_input.element_size();
    } else {
      // Acuumulate negative stride
      tensor_most_negative_index +=
          (tensor_input.size(dim_i) - 1) * tensor_input.stride(dim_i);
      tensor_most_negative_index += (tensor_input.size(dim_i) - 1) *
          tensor_input.stride(dim_i) * tensor_input.element_size();
zasdfgbnm (Collaborator, Author):

WARNING: index mode change. Some kernels that were 32-bit indexable will now need 64-bit indexing.
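
To make the concern concrete, here is a worked example (numbers chosen for illustration, not taken from the diff or the benchmarks): byte-level addressing multiplies every offset by the element size, so a tensor whose largest element offset still fits in 32 bits can produce byte offsets that do not.

```cpp
// Illustration only: a hypothetical contiguous float tensor with 2^30 elements.
#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  const int64_t num_elements = int64_t{1} << 30;  // ~1.07e9 elements
  const int64_t element_size = 4;                 // sizeof(float)

  const int64_t max_element_offset = num_elements - 1;                // fits in int32_t
  const int64_t max_byte_offset = max_element_offset * element_size;  // ~4.29e9, does not

  const int64_t i32_max = std::numeric_limits<int32_t>::max();
  std::printf("element offset fits in int32: %d\n", max_element_offset <= i32_max);
  std::printf("byte offset fits in int32:    %d\n", max_byte_offset <= i32_max);
  return 0;
}
```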

Collaborator:

Do we really need this change?

Owner:

Is it worthwhile to do this if it requires switching the indexing for an entire kernel from 32-bit to 64-bit?

@zasdfgbnm (Collaborator, Author) commented Dec 21, 2022

Overall, the perf is fine:

[benchmark comparison chart omitted]

However, I do see a few real regressions caused by the 32-bit indexing -> 64-bit indexing switch:

NvFuserScheduler_TIMM_BatchNorm_nhwc_fp16___GRAPH/NvFuserScheduler_TIMM_BatchNorm_nhwc_fp16/256/184/112/manual_time
NvFuserScheduler_TIMM_BatchNorm_nhwc_fp16___GRAPH/NvFuserScheduler_TIMM_BatchNorm_nhwc_fp16/256/200/112/manual_time
NvFuserScheduler_TIMM_BatchNorm_nhwc_fp16___GRAPH/NvFuserScheduler_TIMM_BatchNorm_nhwc_fp16/128/368/112/manual_time
NvFuserScheduler_TIMM_BatchNorm_nhwc_fp16___GRAPH/NvFuserScheduler_TIMM_BatchNorm_nhwc_fp16/2048/152/56/manual_time

I think the only way to solve this regression is to restrict this PR to be effective only for cpAsync; however, this would complicate collectIndexMode, because it would need to know whether there is a cpAsync in the fusion.

@naoyam (Collaborator) commented Jan 30, 2023

@zasdfgbnm What's the status of this PR?

@zasdfgbnm (Collaborator, Author) replied:

> @zasdfgbnm What's the status of this PR?

I haven't had time to work on this yet. I will work on it this week.

@zasdfgbnm zasdfgbnm changed the title from "Access global tensor via byte pointer" to "cp.async access global tensor via pointer" on Mar 2, 2023
@zasdfgbnm zasdfgbnm marked this pull request as ready for review March 2, 2023 01:37
@zasdfgbnm (Collaborator, Author):

I changed this PR to use T* instead of char*, so I don't think this will have the 32-bit indexing issue anymore. I am not sure whether this will give us as good register usage as char*, but in this PR I prefer to focus on building out the infrastructure; I will come back to register usage later.
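
A minimal sketch of why the typed pointer avoids the problem (placeholder names, not the generated identifiers): with char* the hoisted offsets are byte counts, while with T* they stay element counts, the same magnitude as ordinary subscripting.

```cpp
// Sketch only; "data", "stride", "i0", "i1" are placeholder names.
__device__ float load_via_char_ptr(const float* data, long long stride,
                                   long long i0, long long i1) {
  // char* form: offsets are scaled by sizeof(float), so they can outgrow
  // the 32-bit range even when the element index still fits.
  const char* base = reinterpret_cast<const char*>(data);
  return *reinterpret_cast<const float*>(base + 4 * (i0 + stride * i1));
}

__device__ float load_via_typed_ptr(const float* data, long long stride,
                                    long long i0, long long i1) {
  // T* form: pointer arithmetic is in elements, so index magnitudes match
  // ordinary T0[...] subscripting.
  return data[i0 + stride * i1];
}
```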

@zasdfgbnm zasdfgbnm requested review from naoyam and csarofeen March 2, 2023 01:41
@naoyam (Collaborator) commented Mar 6, 2023

Does this only affect tensors read with cp.async as suggested by the PR title?

@zasdfgbnm (Collaborator, Author) replied:

> Does this only affect tensors read with cp.async as suggested by the PR title?

Right, should only affect cp.async. Other reads are not affected.
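
For context, this is roughly what a cp.async copy looks like at the source level (a generic hand-written sketch for sm_80+, not nvfuser's actual runtime helper): the instruction takes a shared-memory destination address and a global-memory source address, which is why the generated code wants a raw pointer plus offset for the global operand rather than a subscripted tensor expression.

```cpp
// Generic sketch for sm_80+; not nvfuser's cp.async helper.
__device__ void cp_async_16B(void* smem_dst, const void* gmem_src) {
  // Convert the generic shared-memory pointer to a shared-state-space address.
  unsigned smem_addr =
      static_cast<unsigned>(__cvta_generic_to_shared(smem_dst));
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(smem_addr),
               "l"(gmem_src));
}
```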

@naoyam (Collaborator) left a review comment:

Just a few comments for now. Trying to remember what this PR is about.

    TORCH_CUDA_CU_API bool isIntegralType(DataType dtype);
    // Returns if the datatype is a pointer type
    TORCH_CUDA_CU_API bool isPointerType(DataType dtype);
    // Returns if the datatype is an boolean type
Collaborator:

nit: "a boolean" (thanks for fixing the misplaced comments)

Review thread on third_party/nvfuser/csrc/lower_index.cpp (resolved).
@naoyam (Collaborator) commented Mar 6, 2023

> Does this only affect tensors read with cp.async as suggested by the PR title?
>
> Right, should only affect cp.async. Other reads are not affected.

Where is this logic implemented?

@naoyam (Collaborator) commented Mar 6, 2023

Which tests would show the pointer addressing?

@zasdfgbnm (Collaborator, Author) replied:

> Does this only affect tensors read with cp.async as suggested by the PR title?
>
> Right, should only affect cp.async. Other reads are not affected.
>
> Where is this logic implemented?

It is controlled by the generate_pointer parameter, which is only true in IndexLowering::handle(const LoadStoreOp* ldst)
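
A rough, purely hypothetical sketch of the gating pattern described above (the names below are made up for illustration and are not nvfuser's actual API): when the flag is set, the lowering emits a base-pointer-plus-offset expression instead of a subscripted access.

```cpp
// Hypothetical illustration only; not nvfuser's real lowering code.
#include <string>

std::string lowerGlobalAccess(const std::string& tensor,
                              const std::string& index,
                              bool generate_pointer) {
  if (generate_pointer) {
    // cp.async path: hand the instruction a raw pointer plus element offset.
    return "(" + tensor + ".data + (" + index + "))";
  }
  // Default path: ordinary subscripted access.
  return tensor + "[" + index + "]";
}
```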

@zasdfgbnm (Collaborator, Author) replied:

> Which tests would show the pointer addressing?

There is no test for it. But I can change some existing tests to check this.

@naoyam (Collaborator) left a review:

LGTM. It would be great to add some check to a test that should use this pointer addressing.

Resolved review threads (outdated): third_party/nvfuser/csrc/ir_builder.cpp, third_party/nvfuser/csrc/codegen.cpp
@zasdfgbnm zasdfgbnm merged commit 16a26a1 into devel Mar 6, 2023
@zasdfgbnm zasdfgbnm deleted the access-gmem-ptr branch March 6, 2023 21:59