diff --git a/core/src/ops/einsum/kernel_selection.rs b/core/src/ops/einsum/kernel_selection.rs index dd655a3f1f..99732acd14 100644 --- a/core/src/ops/einsum/kernel_selection.rs +++ b/core/src/ops/einsum/kernel_selection.rs @@ -98,12 +98,20 @@ pub fn wire_linear( } else { KitDatumType::F32 }; + let activation = match b_fact.datum_type { + DatumType::F16 => KitDatumType::F16, + DatumType::F32 => KitDatumType::F32, + _ => todo!(), + }; let kit = tract_linalg::ops() .mmm_kits() .iter() - .filter(|kit| kit.weight == weight && kit.accumulator == accumulator) + .filter(|kit| { + kit.weight == weight && kit.accumulator == accumulator && kit.activation == activation + }) .min_by_key(|kit| kit.generic_fallback as usize) .with_context(|| format!("No kit found for matmul {a:?} • {b_fact:?}"))?; + let configs = [kit.item_for_mv(), kit.item_for_squarish()]; let packed: Box = if let Some(a_payload) = a_as_bqv { let packed = kit .static_packer @@ -117,8 +125,6 @@ pub fn wire_linear( let konst = tensor0(Opaque::from(packed)); let pa = patch.add_const(format!("{prefix}.pack_a"), konst)?; - let configs = [kit.item_for_mv(), kit.item_for_squarish()]; - let packers = configs .iter() .map(|conf| { diff --git a/linalg/src/frame/mmm/kit.rs b/linalg/src/frame/mmm/kit.rs index 70038578c1..8cdd996527 100644 --- a/linalg/src/frame/mmm/kit.rs +++ b/linalg/src/frame/mmm/kit.rs @@ -2,8 +2,9 @@ use std::fmt::Debug; use tract_data::prelude::DatumType; -use crate::frame::block_quant::BlockQuant; +use crate::frame::block_quant::{BlockQuant, PackedBlockQuantFormat}; +use super::pack::PackedFormat; use super::panel_extract::PanelExtractor; use super::{MMMInputFormat, MatMatMul}; @@ -98,19 +99,44 @@ impl MMMKit { activation: impl Into, static_packer: &dyn MMMInputFormat, ) -> MMMKit { - MMMKit { - weight: weight.into(), - accumulator: accumulator.into(), - activation: activation.into(), + let (weight, accumulator, activation) = + (weight.into(), accumulator.into(), activation.into()); + let kit = MMMKit { + weight, + accumulator, + activation, static_packer: dyn_clone::clone_box(static_packer), items: vec![], generic_fallback: false, - } + }; + match &kit.weight { + WeightType::Plain(p) => { + debug_assert!( + kit.static_packer.downcast_ref::().is_some_and(|pf| pf.dt == *p), + "Static packer not compatible with weight format {kit:?}" + ) + } + WeightType::BlockQuant(bq) => debug_assert!( + kit.static_packer + .downcast_ref::() + .is_some_and(|pbqf| pbqf.bq.same_as(&**bq)), + "Static packer not compatible with weight format {kit:?}" + ), + }; + kit } pub(crate) fn with_native(mut self, mmm: Box, packing: usize) -> Self { - assert!(mmm.packings()[packing].0.same_as(&*self.static_packer)); - assert!(self.accumulator == mmm.internal_type().into()); + debug_assert!( + mmm.packings()[packing].0.same_as(&*self.static_packer), + "Weight packing mismatch {self:?} {mmm:?}/{packing} {:?}", + mmm.packings()[packing].0 + ); + debug_assert!( + self.accumulator == mmm.internal_type().into(), + "Accumulator mismatch {self:?} {mmm:?}/{packing} {:?}", + mmm.packings()[packing].0 + ); self.items.push(MMMKitItem { mmm, packing, weight_panel_extractor: None }); self } @@ -121,7 +147,21 @@ impl MMMKit { packing: usize, weight_panel_extractor: PanelExtractor, ) -> Self { - assert!(self.accumulator == mmm.internal_type().into()); + debug_assert!( + self.accumulator == mmm.internal_type().into(), + "Accumulator mismatch {self:?} {mmm:?}/{packing} {:?}", + mmm.packings()[packing].0 + ); + debug_assert!( + self.static_packer.same_as(&*weight_panel_extractor.from), + "Static weight packing/extractor mismatch {self:?} {mmm:?}/{packing} {:?} {weight_panel_extractor:?}", + mmm.packings()[packing].0 + ); + debug_assert!( + weight_panel_extractor.to.same_as(&*mmm.packings()[packing].0), + "Extractor/kernel packing mismatch {self:?} {mmm:?}/{packing} {:?} {weight_panel_extractor:?}", + mmm.packings()[packing].0 + ); self.items.push(MMMKitItem { mmm, packing, diff --git a/linalg/src/x86_64_fma/mmm.rs b/linalg/src/x86_64_fma/mmm.rs index 1664fc8d16..0b0bc3181f 100644 --- a/linalg/src/x86_64_fma/mmm.rs +++ b/linalg/src/x86_64_fma/mmm.rs @@ -60,10 +60,12 @@ pub fn plug(ops: &mut Ops) { .with_native(fma_mmm_f32_32x1.mmm(), 2) .with_extracting(fma_mmm_f32_32x3.mmm(), 1, packed_32_q40_to_f32.clone()), ); - ops.mmm_kits.push(MMMKit::new(F16, F32, F16, &PQ40_R32).with_extracting( - fma_mmm_f32_32x3.mmm(), - 1, - packed_32_f16_to_f32.clone(), - )); + ops.mmm_kits.push( + MMMKit::new(F16, F32, F16, &PackedFormat::new(F16, 32, 32)).with_extracting( + fma_mmm_f32_32x3.mmm(), + 1, + packed_32_f16_to_f32.clone(), + ), + ); } }