Skip to content

Commit

Permalink
make q40f16 work on intel
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Oct 22, 2024
1 parent 9a07f0c commit 285d198
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 20 deletions.
93 changes: 74 additions & 19 deletions linalg/src/frame/mmm/panel_extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,58 @@ pub mod test {
}
assert!(extractor.from.r() == extractor.to.r());
assert!(extractor.to.dt == f32::datum_type() || extractor.to.dt == f16::datum_type());
let from = extractor.from.downcast_ref::<PackedBlockQuantFormat>().unwrap();
if let Some(from) = extractor.from.downcast_ref::<PackedBlockQuantFormat>() {
test_packing_bq(extractor, from, blocks, panels)
} else if let Some(from) = extractor.from.downcast_ref() {
test_packing_plain(extractor, &from, blocks, panels)
} else {
todo!()
}
}

pub fn test_packing_plain(
extractor: &PanelExtractor,
from: &PackedFormat,
blocks: usize,
panels: usize,
) -> TractResult<()> {
let m = from.r * panels;
let k = 8 * blocks; // 8 is arbitrary
let to = &extractor.to;
let weights_orig =
Array2::from_shape_fn((m, k), |(m, k)| ((m * 31 + k * 17) % 20) as f32 - 10.)
.into_tensor()
.cast_to_dt(from.dt)?
.into_owned();
let packed_orig = from.prepare_tensor(&weights_orig, 1, 0)?;
let packed_orig = packed_orig.downcast_ref::<EagerPackedInput>().unwrap();

for panel in 0..panels {
let orig_panel =
&packed_orig.packed[packed_orig.panel_bytes * panel..][..k * from.r * from.dt.size_of()];
let mut reference_panel = Tensor::zero_dt(from.dt, &[k, from.r])?;
reference_panel.as_bytes_mut().copy_from_slice(&orig_panel);
reference_panel = reference_panel.cast_to_dt(to.dt)?.into_owned();

let mut tested_panel = Tensor::zero_dt(to.dt, &[k, from.r])?;
unsafe {
(extractor.kernel)(
orig_panel.as_ptr(),
tested_panel.as_bytes_mut().as_mut_ptr(),
k,
);
}
compare_panels(&tested_panel, &reference_panel, from.r, k);
}
Ok(())
}

pub fn test_packing_bq(
extractor: &PanelExtractor,
from: &PackedBlockQuantFormat,
blocks: usize,
panels: usize,
) -> TractResult<()> {
let m = from.r * panels;
let k = from.bq.block_len() * blocks;
let to = &extractor.to;
Expand Down Expand Up @@ -185,26 +236,30 @@ pub mod test {
let source =
packed_block_quant.packed.as_ptr().add(packed_block_quant.panel_bytes * panel);
(extractor.kernel)(source, tested_panel.as_bytes_mut().as_mut_ptr(), k);
if tested_panel != reference_panel {
if to.dt == f32::datum_type() {
crate::frame::mmm::tests::display_error(
tested_panel.as_slice::<f32>().unwrap(),
reference_panel.as_slice::<f32>().unwrap(),
from.r,
k,
);
} else {
crate::frame::mmm::tests::display_error(
tested_panel.as_slice::<f16>().unwrap(),
reference_panel.as_slice::<f16>().unwrap(),
from.r,
k,
);
}
}
assert_eq!(tested_panel, reference_panel);
}
compare_panels(&tested_panel, &reference_panel, from.r, k);
}
Ok(())
}

fn compare_panels(tested_panel: &Tensor, reference_panel: &Tensor, r: usize, k: usize) {
if tested_panel != reference_panel {
if reference_panel.datum_type() == f32::datum_type() {
crate::frame::mmm::tests::display_error(
tested_panel.as_slice::<f32>().unwrap(),
reference_panel.as_slice::<f32>().unwrap(),
r,
k,
);
} else {
crate::frame::mmm::tests::display_error(
tested_panel.as_slice::<f16>().unwrap(),
reference_panel.as_slice::<f16>().unwrap(),
r,
k,
);
}
}
assert_eq!(tested_panel, reference_panel);
}
}
6 changes: 6 additions & 0 deletions linalg/src/x86_64_fma/mmm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::frame::PackedFormat;
use crate::mmm::MMMKit;
use crate::mmm::MatMatMulKer;
use crate::Ops;
use panel_extract::packed_32_f16_to_f32;
use panel_extract::packed_32_q40_to_f32;
use tract_data::internal::*;
use DatumType::*;
Expand Down Expand Up @@ -59,5 +60,10 @@ pub fn plug(ops: &mut Ops) {
.with_native(fma_mmm_f32_32x1.mmm(), 2)
.with_extracting(fma_mmm_f32_32x3.mmm(), 1, packed_32_q40_to_f32.clone()),
);
ops.mmm_kits.push(MMMKit::new(F16, F32, F16, &PQ40_R32).with_extracting(
fma_mmm_f32_32x3.mmm(),
1,
packed_32_f16_to_f32.clone(),
));
}
}
40 changes: 39 additions & 1 deletion linalg/src/x86_64_fma/panel_extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@ use crate::Ops;
use tract_data::internal::*;

pub fn plug(ops: &mut Ops) {
ops.panel_extractors.push(packed_32_q40_to_f32.clone())
ops.panel_extractors.extend([packed_32_q40_to_f32.clone(), packed_32_f16_to_f32.clone()]);
}

panel_extractor!(kernel_packed_32_q40_to_f32 as packed_32_q40_to_f32(
Box::new(super::mmm::PQ40_R32),
PackedFormat::new(f32::datum_type(), 32, 32)
) where(AVX2));

panel_extractor!(kernel_packed_32_f16_to_f32 as packed_32_f16_to_f32(
Box::new(PackedFormat::new(f16::datum_type(), 32, 32)),
PackedFormat::new(f32::datum_type(), 32, 32)
) where(AVX2));

#[target_feature(enable = "avx2")]
unsafe fn kernel_packed_32_q40_to_f32(input: *const u8, output: *mut u8, k: usize) {
debug_assert!(k % 32 == 0);
Expand Down Expand Up @@ -86,3 +91,36 @@ unsafe fn kernel_packed_32_q40_to_f32(input: *const u8, output: *mut u8, k: usiz
out("ymm12") _, out("ymm13") _, out("ymm14") _, out("ymm15") _
);
}

#[target_feature(enable = "avx2")]
unsafe fn kernel_packed_32_f16_to_f32(input: *const u8, output: *mut u8, k: usize) {
debug_assert!(output as usize % 32 == 0);
std::arch::asm!("
2:
vmovaps xmm4, [{i}]
vmovaps xmm5, [{i} + 16]
vmovaps xmm6, [{i} + 32]
vmovaps xmm7, [{i} + 48]
vcvtph2ps ymm4, xmm4
vcvtph2ps ymm5, xmm5
vcvtph2ps ymm6, xmm6
vcvtph2ps ymm7, xmm7
vmovaps [{o}], ymm4
vmovaps [{o}+32], ymm5
vmovaps [{o}+64], ymm6
vmovaps [{o}+96], ymm7
add {i}, 64
add {o}, 128
sub {k}, 1
jnz 2b;
",
k = inout(reg) k => _,
i = inout(reg) input => _,
o = inout(reg) output => _,
out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _,
);
}

0 comments on commit 285d198

Please sign in to comment.