Skip to content

Commit

Permalink
fix q40f16
Browse files (browse the repository at this point in the history)
  • Loading branch information
kali committed Oct 23, 2024
1 parent 285d198 commit 576134a
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 17 deletions.
12 changes: 9 additions & 3 deletions core/src/ops/einsum/kernel_selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,20 @@ pub fn wire_linear(
} else {
KitDatumType::F32
};
let activation = match b_fact.datum_type {
DatumType::F16 => KitDatumType::F16,
DatumType::F32 => KitDatumType::F32,
_ => todo!(),
};
let kit = tract_linalg::ops()
.mmm_kits()
.iter()
.filter(|kit| kit.weight == weight && kit.accumulator == accumulator)
.filter(|kit| {
kit.weight == weight && kit.accumulator == accumulator && kit.activation == activation
})
.min_by_key(|kit| kit.generic_fallback as usize)
.with_context(|| format!("No kit found for matmul {a:?} • {b_fact:?}"))?;
let configs = [kit.item_for_mv(), kit.item_for_squarish()];
let packed: Box<dyn MMMInputValue> = if let Some(a_payload) = a_as_bqv {
let packed = kit
.static_packer
Expand All @@ -117,8 +125,6 @@ pub fn wire_linear(
let konst = tensor0(Opaque::from(packed));
let pa = patch.add_const(format!("{prefix}.pack_a"), konst)?;

let configs = [kit.item_for_mv(), kit.item_for_squarish()];

let packers = configs
.iter()
.map(|conf| {
Expand Down
58 changes: 49 additions & 9 deletions linalg/src/frame/mmm/kit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use std::fmt::Debug;

use tract_data::prelude::DatumType;

use crate::frame::block_quant::BlockQuant;
use crate::frame::block_quant::{BlockQuant, PackedBlockQuantFormat};

use super::pack::PackedFormat;
use super::panel_extract::PanelExtractor;
use super::{MMMInputFormat, MatMatMul};

Expand Down Expand Up @@ -98,19 +99,44 @@ impl MMMKit {
activation: impl Into<KitDatumType>,
static_packer: &dyn MMMInputFormat,
) -> MMMKit {
MMMKit {
weight: weight.into(),
accumulator: accumulator.into(),
activation: activation.into(),
let (weight, accumulator, activation) =
(weight.into(), accumulator.into(), activation.into());
let kit = MMMKit {
weight,
accumulator,
activation,
static_packer: dyn_clone::clone_box(static_packer),
items: vec![],
generic_fallback: false,
}
};
match &kit.weight {
WeightType::Plain(p) => {
debug_assert!(
kit.static_packer.downcast_ref::<PackedFormat>().is_some_and(|pf| pf.dt == *p),
"Static packer not compatible with weight format {kit:?}"
)
}
WeightType::BlockQuant(bq) => debug_assert!(
kit.static_packer
.downcast_ref::<PackedBlockQuantFormat>()
.is_some_and(|pbqf| pbqf.bq.same_as(&**bq)),
"Static packer not compatible with weight format {kit:?}"
),
};
kit
}

pub(crate) fn with_native(mut self, mmm: Box<dyn MatMatMul>, packing: usize) -> Self {
assert!(mmm.packings()[packing].0.same_as(&*self.static_packer));
assert!(self.accumulator == mmm.internal_type().into());
debug_assert!(
mmm.packings()[packing].0.same_as(&*self.static_packer),
"Weight packing mismatch {self:?} {mmm:?}/{packing} {:?}",
mmm.packings()[packing].0
);
debug_assert!(
self.accumulator == mmm.internal_type().into(),
"Accumulator mismatch {self:?} {mmm:?}/{packing} {:?}",
mmm.packings()[packing].0
);
self.items.push(MMMKitItem { mmm, packing, weight_panel_extractor: None });
self
}
Expand All @@ -121,7 +147,21 @@ impl MMMKit {
packing: usize,
weight_panel_extractor: PanelExtractor,
) -> Self {
assert!(self.accumulator == mmm.internal_type().into());
debug_assert!(
self.accumulator == mmm.internal_type().into(),
"Accumulator mismatch {self:?} {mmm:?}/{packing} {:?}",
mmm.packings()[packing].0
);
debug_assert!(
self.static_packer.same_as(&*weight_panel_extractor.from),
"Static weight packing/extractor mismatch {self:?} {mmm:?}/{packing} {:?} {weight_panel_extractor:?}",
mmm.packings()[packing].0
);
debug_assert!(
weight_panel_extractor.to.same_as(&*mmm.packings()[packing].0),
"Extractor/kernel packing mismatch {self:?} {mmm:?}/{packing} {:?} {weight_panel_extractor:?}",
mmm.packings()[packing].0
);
self.items.push(MMMKitItem {
mmm,
packing,
Expand Down
12 changes: 7 additions & 5 deletions linalg/src/x86_64_fma/mmm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ pub fn plug(ops: &mut Ops) {
.with_native(fma_mmm_f32_32x1.mmm(), 2)
.with_extracting(fma_mmm_f32_32x3.mmm(), 1, packed_32_q40_to_f32.clone()),
);
ops.mmm_kits.push(MMMKit::new(F16, F32, F16, &PQ40_R32).with_extracting(
fma_mmm_f32_32x3.mmm(),
1,
packed_32_f16_to_f32.clone(),
));
ops.mmm_kits.push(
MMMKit::new(F16, F32, F16, &PackedFormat::new(F16, 32, 32)).with_extracting(
fma_mmm_f32_32x3.mmm(),
1,
packed_32_f16_to_f32.clone(),
),
);
}
}

0 comments on commit 576134a

Please sign in to comment.