
Optimization Passes for dynamic_gather #184

Open · avik-pal opened this issue Dec 12, 2024 · 2 comments

@avik-pal (Collaborator)

module {
  func.func @main(%arg0: tensor<6x6xf64>) -> tensor<6x6xf64> {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<6x6xf64>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<6x6xf64>
    %c = stablehlo.constant dense<[[1, 0], [2, 1], [3, 2], [4, 3], [5, 4], [0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [0, 1], [1, 2], [2, 3], [3, 4], [4, 5]]> : tensor<16x2xi64>
    %c_1 = stablehlo.constant dense<[[0, 1], [1, 2], [2, 3], [3, 4], [4, 5]]> : tensor<5x2xi64>
    %c_2 = stablehlo.constant dense<[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]> : tensor<6x2xi64>
    %c_3 = stablehlo.constant dense<1> : tensor<2xi64>
    %c_4 = stablehlo.constant dense<[[1, 0], [2, 1], [3, 2], [4, 3], [5, 4]]> : tensor<5x2xi64>
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<6x6xf64>) -> tensor<6x6xf64>
    %1 = "stablehlo.dynamic_gather"(%0, %c_4, %c_3) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>}> : (tensor<6x6xf64>, tensor<5x2xi64>, tensor<2xi64>) -> tensor<5xf64>
    %2 = "stablehlo.dynamic_gather"(%0, %c_2, %c_3) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>}> : (tensor<6x6xf64>, tensor<6x2xi64>, tensor<2xi64>) -> tensor<6xf64>
    %3 = "stablehlo.dynamic_gather"(%0, %c_1, %c_3) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>}> : (tensor<6x6xf64>, tensor<5x2xi64>, tensor<2xi64>) -> tensor<5xf64>
    %4 = stablehlo.concatenate %1, %2, %3, dim = 0 : (tensor<5xf64>, tensor<6xf64>, tensor<5xf64>) -> tensor<16xf64>
    %5 = "stablehlo.scatter"(%cst_0, %c, %4) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>}> ({
    ^bb0(%arg1: tensor<f64>, %arg2: tensor<f64>):
      stablehlo.return %arg2 : tensor<f64>
    }) : (tensor<6x6xf64>, tensor<16x2xi64>, tensor<16xf64>) -> tensor<6x6xf64>
    %6 = stablehlo.add %5, %cst : tensor<6x6xf64>
    %7 = stablehlo.transpose %6, dims = [1, 0] : (tensor<6x6xf64>) -> tensor<6x6xf64>
    return %7 : tensor<6x6xf64>
  }
}

This IR comes from the Julia function fn(x) = Tridiagonal(x) .+ 1.

Essentially, if we fuse the dynamic_gathers into a single dynamic_gather followed by slices, the slice + concatenate pattern will then be eliminated by another pass.
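
A rough sketch of the fused form (hand-written, %fused is an illustrative name): the index constants %c_4, %c_2, %c_1 stacked in that order are exactly the existing %c, so a single dynamic_gather over %c plus three slices reproduces %1, %2, %3, and the slices then cancel against the concatenate:

%fused = "stablehlo.dynamic_gather"(%0, %c, %c_3) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>}> : (tensor<6x6xf64>, tensor<16x2xi64>, tensor<2xi64>) -> tensor<16xf64>
%1 = stablehlo.slice %fused [0:5] : (tensor<16xf64>) -> tensor<5xf64>
%2 = stablehlo.slice %fused [5:11] : (tensor<16xf64>) -> tensor<6xf64>
%3 = stablehlo.slice %fused [11:16] : (tensor<16xf64>) -> tensor<5xf64>
// concatenate(%1, %2, %3) is now a concat of consecutive slices of %fused and folds to %fused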

(In this particular case even the dynamic_gather + scatter pair can be eliminated, but that is probably much harder to write a pass for.)

@wsmoses (Member)

wsmoses commented Dec 14, 2024

So there are several opts here worth doing:

  • dynamic_gather of transpose -> rewritten dynamic_gather (sketched below)
  • concat of consecutive dynamic_gathers replaced with one larger dynamic_gather
  • eventually, scatter of dynamic_gather
  • dynamic_gather -> gather for static slice_sizes (sketched below)
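
For the IR above, the first and last of these are fairly mechanical; a rough sketch for the fully-collapsed, unit-slice case only (%g1 and %g2 are illustrative names, not compiler output):

// 1. dynamic_gather of transpose: with both operand dims collapsed and unit
//    slice sizes, the transpose of %arg0 can be absorbed by permuting start_index_map.
%g1 = "stablehlo.dynamic_gather"(%arg0, %c_4, %c_3) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [1, 0], index_vector_dim = 1>}> : (tensor<6x6xf64>, tensor<5x2xi64>, tensor<2xi64>) -> tensor<5xf64>
// 4. dynamic_gather -> gather: %c_3 = dense<1> is a constant, so it can be
//    materialized as the static slice_sizes attribute of a plain gather.
%g2 = "stablehlo.gather"(%0, %c_4) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<6x6xf64>, tensor<5x2xi64>) -> tensor<5xf64>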

@avik-pal changed the title from "Fuse multiple dynamic_gather into a single dynamic_gather + multiple slice?" to "Optimization Passes for dynamic_gather" on Dec 15, 2024
@glou-nes

glou-nes commented Dec 28, 2024

@wsmoses I have a question about "eventually scatter of dynamic gather": do we want to replace scatter + gather with a mask, as in the rewrite below? I have no idea how to deal with more complex transformations. @avik-pal do you have ideas for that?

%c = stablehlo.constant dense<[[1, 0], [2, 1], [3, 2], [4, 3], [5, 4], [0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [0, 1], [1, 2], [2, 3], [3, 4], [4, 5]]> : tensor<16x2xi64>
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<6x6xf64>) -> tensor<6x6xf64>
%1 = "stablehlo.gather"(%0, %c) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<6x6xf64>, tensor<16x2xi64>) -> tensor<16xf64>
%2 = "stablehlo.scatter"(%cst_0, %c, %1) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>}> ({
^bb0(%arg1: tensor<f64>, %arg2: tensor<f64>):
  stablehlo.return %arg2 : tensor<f64>
}) : (tensor<6x6xf64>, tensor<16x2xi64>, tensor<16xf64>) -> tensor<6x6xf64>

=>

%c = stablehlo.constant dense<[[1, 0], [2, 1], [3, 2], [4, 3], [5, 4], [0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [0, 1], [1, 2], [2, 3], [3, 4], [4, 5]]> : tensor<16x2xi64>
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<6x6xf64>) -> tensor<6x6xf64>
%C = stablehlo.constant dense<[[1.0, 1.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0, 1.0, 1.0]]> : tensor<6x6xf64> // 0/1 mask with ones exactly at the %c indices
%2 = stablehlo.multiply %0, %C : tensor<6x6xf64>
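
For intuition (assuming the indices in %c are unique, as they are here): the scatter writes into zeros exactly the values that %0 holds at the %c positions, which is the same as zeroing every other entry of %0. The mask itself can be materialized by scattering ones with the same dimension numbers, which constant-folds whenever %c is constant (a sketch; %mask, %ones, %zeros are illustrative names):

%ones = stablehlo.constant dense<1.000000e+00> : tensor<16xf64>
%zeros = stablehlo.constant dense<0.000000e+00> : tensor<6x6xf64>
%mask = "stablehlo.scatter"(%zeros, %c, %ones) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>}> ({
^bb0(%arg1: tensor<f64>, %arg2: tensor<f64>):
  stablehlo.return %arg2 : tensor<f64>
}) : (tensor<6x6xf64>, tensor<16x2xi64>, tensor<16xf64>) -> tensor<6x6xf64>
// then multiply as above: %2 = stablehlo.multiply %0, %mask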

Thanks in advance!
