Skip to content

Commit

Permalink
Fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui committed Dec 9, 2024
1 parent c60ab2f commit 8f5f219
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/quantization/dequantize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub(crate) fn dequantize_per_tensor_affine_int8_kernel(
#[comptime] scheme: QuantizationScheme,
) {
// Last two positions contain the qparams
if ABSOLUTE_POS >= output.len() - 2 {
if ABSOLUTE_POS >= input.len() - 2 {
return;
}

Expand Down
13 changes: 6 additions & 7 deletions crates/burn-jit/src/kernel/quantization/qtensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ impl QParams {
QuantizationScheme::PerTensorAffine(_) => match comptime!(tensor.line_size()) {
// For line size of 1, scale is the last value in the buffer
1 => f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]),
// For any other line size > 1, scale and zero-point offset are the first two elements of the last line
_ => f32::bitcast_from(tensor[len - 1][tensor.line_size() - 2]),
// For any other line size > 1, scale and zero-point offset are the last two elements
_ => f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]),
},
// Symmetric quantization only contains the scaling factor as the last element
QuantizationScheme::PerTensorSymmetric(_) => {
Expand All @@ -40,13 +40,12 @@ impl QParams {
/// Get the zero-point offset.
pub fn offset(&self, tensor: &QTensor) -> i32 {
let len = tensor.len();
let line_size = comptime!(tensor.line_size());
match comptime!(self.scheme) {
QuantizationScheme::PerTensorAffine(_) => match line_size {
QuantizationScheme::PerTensorAffine(_) => match comptime!(tensor.line_size()) {
// For line size of 1, the zero-point offset is the penultimate value in the buffer
1 => i32::cast_from(tensor[len - 2][line_size]),
// For any other line size > 1, scale and zero-point offset are the first two elements of the last line
_ => i32::cast_from(tensor[len - 1][line_size]),
1 => i32::cast_from(tensor[len - 2][tensor.line_size() - 1]),
// For any other line size > 1, scale and zero-point offset are the last two elements
_ => i32::cast_from(tensor[len - 1][tensor.line_size() - 2]),
},
// Symmetric quantization only contains the scaling factor, so we return 0 for the zero-point offset
QuantizationScheme::PerTensorSymmetric(_) => 0,
Expand Down

0 comments on commit 8f5f219

Please sign in to comment.