Skip to content

Commit

Permalink
fix kits and mmm defs
Browse files Browse the repository at this point in the history
  • Loading branch information
kali authored and mathieupoumeyrolsonos committed Dec 9, 2024
1 parent 12c4169 commit 79b8479
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 13 deletions.
2 changes: 1 addition & 1 deletion linalg/src/frame/mmm/kit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ impl MMMKit {
.1
.downcast_ref::<PackedFormat>()
.is_some_and(|pf| KitDatumType::from(pf.dt) == self.activation),
"Activation packecd mismatch {self:?} {:?}",
"Activation packed dt mismatch {self:?} {:?}",
mmm.packings()[packing].1
);
self.items.push(MMMKitItem { mmm, packing, weight_panel_extractor });
Expand Down
23 changes: 14 additions & 9 deletions linalg/src/generic/mmm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,16 +334,21 @@ where
0
}

const PQ40_R4: PackedBlockQuantFormat = PackedBlockQuantFormat::new(&Q4_0, 4, 0, false);
const PQ40_R4_SE: PackedBlockQuantFormat = PackedBlockQuantFormat::new(&Q4_0, 4, 0, true);
fn pq40_r4() -> PackedBlockQuantFormat {
PackedBlockQuantFormat::new(&Q4_0, 4, 0, false)
}

fn pq40_r4_se() -> PackedBlockQuantFormat {
PackedBlockQuantFormat::new(&Q4_0, 4, 0, true)
}

// f16 kernels
MMMRustKernel!(kernel::<f16, 4, 4> => generic_f16_4x4<f16>(4,4) store(f32, f64));
MMMRustKernel! {kernel::<f16, 4, 1> => generic_f16_4x1<f16>(4,1)
packing[1] = f32f32 => |k| k.with_packing(f32::packing(4), f32::packing(1));
packing[2] = q40f16 => |k| k.with_packing_a(PQ40_R4);
packing[3] = q40f16se => |k| k.with_packing_a(PQ40_R4_SE);
packing[4] = q40f32 => |k| k.with_packing(PQ40_R4, f32::packing(1));
packing[2] = q40f16 => |k| k.with_packing_a(pq40_r4());
packing[3] = q40f16se => |k| k.with_packing_a(pq40_r4_se());
packing[4] = q40f32 => |k| k.with_packing(pq40_r4(), f32::packing(1));
store(f32, f64)
}

Expand All @@ -354,9 +359,9 @@ MMMRustKernel!(kernel::<f32, 4, 4> => generic_f32_4x4<f32>(4,4)
);
MMMRustKernel! {kernel::<f32, 4, 1> => generic_f32_4x1<f32>(4,1)
packing[1] = f16f16 => |k| k.with_packing(f16::packing(4), f16::packing(1));
packing[2] = q40f16 => |k| k.with_packing(PQ40_R4, f16::packing(1));
packing[3] = q40f16se => |k| k.with_packing(PQ40_R4_SE, f16::packing(1));
packing[4] = q40f32 => |k| k.with_packing_a(PQ40_R4);
packing[2] = q40f16 => |k| k.with_packing(pq40_r4(), f16::packing(1));
packing[3] = q40f16se => |k| k.with_packing(pq40_r4_se(), f16::packing(1));
packing[4] = q40f32 => |k| k.with_packing_a(pq40_r4());
store(f16, f64)
}

Expand Down Expand Up @@ -387,7 +392,7 @@ MMMRustKernel! {kernel::<i32, 3, 2> => generic_i32_3x2<i32>(3,2)

pub fn plug(ops: &mut Ops) {
ops.mmm_kits.push(
MMMKit::new(Q4_0, F32, F32, &PQ40_R4)
MMMKit::new(Q4_0, F32, F32, &pq40_r4())
.with_native(generic_f32_4x1.mmm(), 4)
.with_generic_fallback(true),
);
Expand Down
2 changes: 1 addition & 1 deletion linalg/src/x86_64_fma/mmm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ MMMExternKernel! {fma_mmm_f32_32x1<f32>(32,1)@(32,4) where(FMA)
store(f16)
}
MMMExternKernel!(fma_mmm_f32_32x3<f32>(32,3)@(32,4) where(FMA)
packing[1] = f32f16 => |k| k.with_packing(PackedFormat::new(F32, 32, 32), PackedFormat::new(F16, 3, 2));
packing[1] = f32f16 => |k| k.with_packing(f32::packing(32).align(32), f16::packing(3));
store(f16)
);

Expand Down
5 changes: 3 additions & 2 deletions linalg/src/x86_64_fma/panel_extract.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::*;
use crate::frame::PackedFormat;
use crate::frame::mmm::Packing;
use crate::Ops;
use tract_data::internal::*;

Expand All @@ -9,12 +10,12 @@ pub fn plug(ops: &mut Ops) {

panel_extractor!(kernel_packed_32_q40_to_f32 as packed_32_q40_to_f32(
Box::new(super::mmm::pq40_r32()),
PackedFormat::new(f32::datum_type(), 32, 32)
f32::packing(32).align(32)
) where(AVX2));

panel_extractor!(kernel_packed_32_f16_to_f32 as packed_32_f16_to_f32(
Box::new(PackedFormat::new(f16::datum_type(), 32, 32)),
PackedFormat::new(f32::datum_type(), 32, 32)
f32::packing(32).align(32)
) where(AVX2));

#[target_feature(enable = "avx2")]
Expand Down

0 comments on commit 79b8479

Please sign in to comment.