Add more quantization support for burn-jit (#2275)
* Add cubecl quantization kernels and QTensorOps for burn-jit

* Fix typo

* Fix output vec factor

* Fix output dtype size_of

* Remove unused code in dequantize test

* Fix dequantize vectorization

* Handle tensors when number of elems is not a multiple of 4

* Support quantize for tensors with less than 4 elems (no vectorization)

* Fix equal 0 test

* Add quantize/dequantize tests

* Add q_to_device

* Refactor kernels for latest cubecl

* intermediate i32 cast

* Fix size_of output type

* Use strict=false to ignore floating point precision issues with qparams equality

* Only check that lhs & rhs strategies match (but not strict on qparams values)

* Use assert_approx_eq on dequant values

* Reduce precision for flaky test

* Remove todo comment

* Add comment for cast to unsigned

* More comment

---------

Co-authored-by: louisfd <[email protected]>
Authored by laggui and louisfd on Sep 17, 2024
1 parent 834005e · commit aa79e36
Showing 12 changed files with 744 additions and 50 deletions.
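As background for the kernels in this diff, here is a minimal CPU sketch of per-tensor affine int8 quantization. The quantize direction assumes the standard formulation and is not taken from this commit; the dequantize line mirrors the dequantize_affine_int8 cube function added below. Function names are illustrative.

// Hypothetical CPU reference, not part of the diff.
fn quantize_affine_i8(x: f32, scale: f32, offset: i32) -> i8 {
    // Standard affine quantization (assumed): q = clamp(round(x / scale) + offset, -128, 127)
    ((x / scale).round() as i32 + offset).clamp(-128, 127) as i8
}

fn dequantize_affine_i8(q: i8, scale: f32, offset: i32) -> f32 {
    // Mirrors the kernel below: x = scale * (x_q - offset)
    scale * (q as i32 - offset) as f32
}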
3 changes: 2 additions & 1 deletion crates/burn-jit/src/backend.rs
@@ -34,7 +34,8 @@ where
type FloatTensorPrimitive<const D: usize> = JitTensor<R, Self::FloatElem, D>;
type IntTensorPrimitive<const D: usize> = JitTensor<R, Self::IntElem, D>;
type BoolTensorPrimitive<const D: usize> = JitTensor<R, u32, D>;
type QuantizedTensorPrimitive<const D: usize> = QJitTensor<R, D>;
type QuantizedTensorPrimitive<const D: usize> =
QJitTensor<R, Self::FloatElem, Self::IntElem, D>;

fn name() -> String {
format!("jit<{}>", R::name())
2 changes: 2 additions & 0 deletions crates/burn-jit/src/kernel/mod.rs
@@ -25,6 +25,8 @@ pub mod matmul;
pub mod pool;
/// Pseudo-random number generator kernels
pub mod prng;
/// Quantization operations
pub mod quantization;
/// Reduction algorithms
pub mod reduce;

211 changes: 211 additions & 0 deletions crates/burn-jit/src/kernel/quantization/dequantize.rs
@@ -0,0 +1,211 @@
use crate::tensor::{JitTensor, QJitTensor};
use crate::FloatElement;
use crate::{IntElement, JitElement, JitRuntime};
use burn_tensor::quantization::{QuantizationScheme, QuantizationType};
use cubecl::calculate_cube_count_elemwise;
use cubecl::prelude::*;

#[cube]
pub(crate) fn dequantize_affine_int8<F: Float>(value: i32, scale: F, offset: i32) -> F {
// x = scale * (x_q - offset)
scale * (F::cast_from(value) - F::cast_from(offset))
}

#[cube]
pub(crate) fn extract_i8(value: u32, offset: u32) -> i32 {
// Extract 8-bit segment
let value = (value >> offset) & 0xFF;
// Check if the value is negative by inspecting the MSB and subtract 256 if it is
// Subtract 0 or 256 to circumvent unsupported conditional assignment (let x = if {} else {};)
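// e.g. a segment of 0xF6 (246 unsigned) has its MSB set, so 246 - 256 = -10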
let sub = i32::cast_from(value & 0x80 != 0) * 256;
i32::cast_from(value) - sub
}

#[cube(launch_unchecked)]
pub(crate) fn dequantize_per_tensor_affine_int8_kernel(
input: &Tensor<u32>,
scale: &Tensor<f32>,
offset: &Tensor<i32>,
output: &mut Tensor<f32>,
#[comptime] vectorized: bool,
) {
if ABSOLUTE_POS >= output.len() {
return;
}

let scale = scale[0];
let offset = offset[0];

let num_packed = 4;
let value = input[ABSOLUTE_POS];
let output_pos = ABSOLUTE_POS * num_packed;
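// Each packed u32 input word expands into num_packed (4) dequantized output values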

if vectorized {
let vectorization_factor = vectorization_of(input);
#[unroll]
for i in 0..vectorization_factor {
// Extract each 8-bit segment
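// (most significant byte first: bit offsets 24, 16, 8, 0 map to outputs 0..=3)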
let v1 = extract_i8(value[i], 24);
let v2 = extract_i8(value[i], 16);
let v3 = extract_i8(value[i], 8);
let v4 = extract_i8(value[i], 0);

output[output_pos * vectorization_factor + i * num_packed] =
dequantize_affine_int8::<f32>(v1, scale, offset);
output[output_pos * vectorization_factor + i * num_packed + 1] =
dequantize_affine_int8::<f32>(v2, scale, offset);
output[output_pos * vectorization_factor + i * num_packed + 2] =
dequantize_affine_int8::<f32>(v3, scale, offset);
output[output_pos * vectorization_factor + i * num_packed + 3] =
dequantize_affine_int8::<f32>(v4, scale, offset);
}
} else {
// Extract each 8-bit segment
let v1 = extract_i8(value, 24);
let v2 = extract_i8(value, 16);
let v3 = extract_i8(value, 8);
let v4 = extract_i8(value, 0);

output[output_pos] = dequantize_affine_int8::<f32>(v1, scale, offset);
output[output_pos + 1] = dequantize_affine_int8::<f32>(v2, scale, offset);
output[output_pos + 2] = dequantize_affine_int8::<f32>(v3, scale, offset);
output[output_pos + 3] = dequantize_affine_int8::<f32>(v4, scale, offset);
}
}

#[cube]
pub(crate) fn dequantize_symmetric_int8<F: Float>(value: i32, scale: F) -> F {
// x = scale * x_q
scale * F::cast_from(value)
}

// The symmetric case would have been wrapped into the same affine kernel, but cube doesn't support Option<Tensor> for the offset.
#[cube(launch_unchecked)]
pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel(
input: &Tensor<u32>,
scale: &Tensor<f32>,
output: &mut Tensor<f32>,
#[comptime] vectorized: bool,
) {
if ABSOLUTE_POS >= output.len() {
return;
}

let scale = scale[0];

let num_packed = 4;
let value = input[ABSOLUTE_POS];
let output_pos = ABSOLUTE_POS * num_packed;

if vectorized {
let vectorization_factor = vectorization_of(input);
#[unroll]
for i in 0..vectorization_factor {
for j in 0..num_packed {
let output_idx = output_pos * vectorization_factor + i * num_packed + j;
if output_idx >= output.len() {
return; // value not quantized (padding)
}
// Extract each 8-bit segment
let v = extract_i8(value[i], (3 - j) * 8);
output[output_idx] = dequantize_symmetric_int8::<f32>(v, scale);
}
}
} else {
for j in 0..num_packed {
let output_idx = output_pos + j;
if output_idx >= output.len() {
return; // value not quantized (padding)
}
// Extract each 8-bit segment
let v = extract_i8(value, (3 - j) * 8);
output[output_pos + j] = dequantize_symmetric_int8::<f32>(v, scale);
}
}
}

pub(crate) fn dequantize_per_tensor<R, F, I, const D: usize>(
tensor: JitTensor<R, u32, D>,
scale: JitTensor<R, F, 1>,
offset: Option<JitTensor<R, I, 1>>,
) -> JitTensor<R, F, D>
where
R: JitRuntime,
F: JitElement,
I: IntElement,
{
// The actual number of input elements is 1/4 of the output count (four int8 values
// are packed into a single u32), so we choose a vectorization factor that matches a
// valid input binding size.
let num_out_elems = tensor.shape.num_elements();
let num_elems = usize::div_ceil(num_out_elems, 4);
let vectorization_factor = [4u8, 2, 1]
.iter()
.filter_map(|&v| {
if num_elems >= v as usize {
Some(v)
} else {
None
}
})
.next()
.unwrap();
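// Example with hypothetical sizes: 6 output f32 values pack into ceil(6 / 4) = 2 u32 words,
// so the largest factor in [4, 2, 1] that does not exceed 2 is selected, i.e. 2.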
let cube_dim = CubeDim::default();
let cube_count =
calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);

let shape_output = tensor.shape.clone();
let client = tensor.client.clone();
let handle = client.empty(num_out_elems * core::mem::size_of::<F>());
let output =
JitTensor::new_contiguous(client.clone(), tensor.device.clone(), shape_output, handle);

let dummy_array = [1; D];
if let Some(offset) = offset {
unsafe {
dequantize_per_tensor_affine_int8_kernel::launch_unchecked::<R>(
&client,
cube_count,
cube_dim,
tensor.as_tensor_arg(vectorization_factor),
// Ignore shape and stride
TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1),
TensorArg::from_raw_parts(&offset.handle, &dummy_array, &dummy_array, 1),
output.as_tensor_arg(1),
vectorization_factor > 1,
)
};
} else {
unsafe {
dequantize_per_tensor_symmetric_int8_kernel::launch_unchecked::<R>(
&client,
cube_count,
cube_dim,
tensor.as_tensor_arg(vectorization_factor),
// Ignore shape and stride
TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1),
output.as_tensor_arg(1),
vectorization_factor > 1,
)
};
}

output
}

/// Convert the tensor back to a higher precision data type.
pub fn dequantize<R, F, I, const D: usize>(tensor: QJitTensor<R, F, I, D>) -> JitTensor<R, F, D>
where
R: JitRuntime,
F: FloatElement,
I: IntElement,
{
match tensor.scheme {
QuantizationScheme::PerTensorAffine(dtype)
| QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
QuantizationType::QInt8 => {
dequantize_per_tensor(tensor.qtensor, tensor.qparams.scale, tensor.qparams.offset)
}
},
}
}
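For reference outside the cube DSL, a plain-Rust sketch of the same byte extraction and symmetric dequantization; the function names are illustrative and the code is not part of the diff.

// Unpack four int8 values from one u32 word (most significant byte first, i.e. bit
// offsets 24, 16, 8, 0) and apply symmetric dequantization, mirroring the kernels above.
fn extract_i8_cpu(value: u32, offset: u32) -> i32 {
    let byte = (value >> offset) & 0xFF;
    // Sign-extend: subtract 256 when the MSB of the 8-bit segment is set
    if byte & 0x80 != 0 {
        byte as i32 - 256
    } else {
        byte as i32
    }
}

fn dequantize_word_symmetric(word: u32, scale: f32) -> [f32; 4] {
    let mut out = [0.0f32; 4];
    for j in 0..4u32 {
        let q = extract_i8_cpu(word, (3 - j) * 8);
        out[j as usize] = scale * q as f32;
    }
    out
}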
5 changes: 5 additions & 0 deletions crates/burn-jit/src/kernel/quantization/mod.rs
@@ -0,0 +1,5 @@
mod dequantize;
mod quantize;

pub use dequantize::*;
pub use quantize::*;
(Diffs for the remaining changed files are not shown.)