Skip to content

Commit

Permalink
[ET-VK] [ET-VK] Reduced int precision for all int storage in conv pw …
Browse files Browse the repository at this point in the history
…op to improve performance.

Differential Revision: [D67674212](https://our.internmc.facebook.com/intern/diff/D67674212/)

ghstack-source-id: 259643761
Pull Request resolved: #7447
  • Loading branch information
trivedivivek committed Dec 27, 2024
1 parent ed15042 commit a2f8811
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -32,35 +32,37 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

/*
* Computes a depthwise convolution. Each shader invocation calculates the
* output at a single output location.
*/
void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
const u16vec3 pos = u16vec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos, out_limits))) {
return;
}

// Compute the index of the top-left element of the overlay region. Negative
// indices indicate that the top-left element is in a region added by padding.
const ivec2 ipos = pos.xy * stride - padding;
const u16vec2 ipos = pos.xy * u16vec2(stride) - u16vec2(padding);

// Compute the start and end of the input indices to load. Padding is assumed
// to be constant 0 padding, so any reads from the padding region is skipped.
const ivec2 start = ipos;
const ivec2 end = ipos + overlay_region.xy;
const u16vec2 start = ipos;
const u16vec2 end = ipos + u16vec2(overlay_region.xy);

VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
int kx = 0;
for (int y = start.y, i = 0; i < TILE_SIZE; y += dilation.y, i++) {
for (int x = start.x, j = 0; j < TILE_SIZE; x += dilation.x, j++) {
VEC4_T sum = texelFetch(t_bias, u16vec2(pos.z, 0), 0);
uint16_t kx = uint16_t(0);
for (uint16_t y = start.y, i = uint16_t(0); i < uint16_t(TILE_SIZE); y += uint16_t(dilation.y), i++) {
for (uint16_t x = start.x, j = uint16_t(0); j < uint16_t(TILE_SIZE); x += uint16_t(dilation.x), j++) {
// The weight kernel was rearranged such that every NxN filter is
// flattened to fit in one row. Each filter was then stacked on top of
// each other vertically.
const vec4 in_texel = texelFetch(t_in, ivec3(x, y, pos.z), 0);
sum = fma(in_texel, texelFetch(t_kernel, ivec2(kx, pos.z), 0), sum);
const vec4 in_texel = texelFetch(t_in, u16vec3(x, y, pos.z), 0);
sum = fma(in_texel, texelFetch(t_kernel, u16vec2(kx, pos.z), 0), sum);
kx++;
}
}
Expand Down

0 comments on commit a2f8811

Please sign in to comment.