Skip to content

Commit

Permalink
add opaque fact in packed tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Oct 15, 2024
1 parent 062722f commit 0835c29
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 17 deletions.
4 changes: 2 additions & 2 deletions cli/src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use fs_err as fs;
use ndarray_npy::NpzWriter;
use nu_ansi_term::Color::*;
use tract_core::ops::cnn::conv::Im2Col;
use tract_core::ops::matmul::pack::MatMatMulPack;
use tract_core::ops::matmul::pack::OptMatMulPack;
use tract_core::tract_data::itertools::izip;
use tract_hir::internal::*;
use tract_libcli::tensor::RunParams;
Expand Down Expand Up @@ -213,7 +213,7 @@ fn run_regular(
}
if assert_sane_floats {
for (ix, o) in clarified_r.iter().enumerate() {
if node.op_is::<Im2Col>() || node.op_is::<MatMatMulPack>() {
if node.op_is::<Im2Col>() || node.op_is::<OptMatMulPack>() {
continue;
}
if let Ok(floats) = o.as_slice::<f32>() {
Expand Down
4 changes: 2 additions & 2 deletions core/src/ops/cnn/conv/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::ops::math::{add, div, mul, sub};
use crate::ops::math::{Add, Div, Mul, Sub};
use crate::ops::matmul::optimized::AddMatMulGeometry;
use crate::ops::matmul::optimized::MapOutputAxisToInput;
use crate::ops::matmul::pack::MatMatMulPack;
use crate::ops::matmul::pack::OptMatMulPack;
use crate::ops::matmul::quant::wire_ensure_q8_flavour;
use crate::ops::nn::Reduce;

Expand Down Expand Up @@ -79,7 +79,7 @@ impl Conv {
) -> TractResult<OutletId> {
Ok(model.wire_node(
format!("{name}.prep_kernel.pack"),
MatMatMulPack { packers: vec![format], k_axis: 2, mn_axis: 1 },
OptMatMulPack { packers: vec![format], k_axis: 2, mn_axis: 1 },
&[kernel],
)?[0])
}
Expand Down
12 changes: 6 additions & 6 deletions core/src/ops/einsum/kernel_selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use tract_linalg::mmm::{MMMInputValue, MatMatMul};

use crate::internal::*;
use crate::ops::matmul::de_block_quant::{BlockQuantFact, BlockQuantValue};
use crate::ops::matmul::pack::MatMatMulPack;
use crate::ops::matmul::pack::OptMatMulPack;

use super::optimize::EinSumAnnotatedAsMatMul;

Expand Down Expand Up @@ -47,13 +47,13 @@ pub fn wire_packing(
.with_context(|| format!("No packing for {mmm:?} with inputs {a_dt:?} and {b_dt:?}"))?;
let pa = patch.wire_node(
format!("{prefix}.pack_a"),
MatMatMulPack { k_axis: op.a_k(), mn_axis: op.a_m(), packers: vec![pa.clone()] },
OptMatMulPack { k_axis: op.a_k(), mn_axis: op.a_m(), packers: vec![pa.clone()] },
&[operands[0]],
)?[0];

let pb = patch.wire_node(
format!("{prefix}.pack_b"),
MatMatMulPack { k_axis: op.b_k(), mn_axis: op.b_n(), packers: vec![pb.clone()] },
OptMatMulPack { k_axis: op.b_k(), mn_axis: op.b_n(), packers: vec![pb.clone()] },
&[operands[1]],
)?[0];

Expand Down Expand Up @@ -99,7 +99,7 @@ fn with_block_quant(
.clone();
patch
.node_mut(pb.node)
.op_as_mut::<MatMatMulPack>()
.op_as_mut::<OptMatMulPack>()
.context("Expected MatMatMulPack on B")?
.packers
.push(alternative_b_packing);
Expand Down Expand Up @@ -149,7 +149,7 @@ fn with_block_quant_matmat(

let pb = patch.wire_node(
format!("{prefix}.pack_b"),
MatMatMulPack { k_axis: op.b_k(), mn_axis: op.b_n(), packers: vec![pb.clone()] },
OptMatMulPack { k_axis: op.b_k(), mn_axis: op.b_n(), packers: vec![pb.clone()] },
&[operands[1]],
)?[0];

Expand Down Expand Up @@ -198,7 +198,7 @@ fn with_block_quant_matvec(

let pb = patch.wire_node(
format!("{prefix}.pack_b"),
MatMatMulPack { k_axis: op.b_k(), mn_axis: op.b_n(), packers: vec![pb.clone()] },
OptMatMulPack { k_axis: op.b_k(), mn_axis: op.b_n(), packers: vec![pb.clone()] },
&[operands[1]],
)?[0];

Expand Down
32 changes: 25 additions & 7 deletions core/src/ops/matmul/pack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ use tract_data::TooEarly;
use tract_linalg::frame::PackedFormat;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MatMatMulPack {
pub struct OptMatMulPack {
pub(crate) packers: Vec<PackedFormat>,
pub(crate) k_axis: usize,
pub(crate) mn_axis: usize,
}

impl Op for MatMatMulPack {
impl Op for OptMatMulPack {
fn name(&self) -> Cow<str> {
"MatMatMulPack".into()
"OptMatMulPack".into()
}

fn info(&self) -> TractResult<Vec<String>> {
Expand All @@ -24,7 +24,7 @@ impl Op for MatMatMulPack {
impl_op_same_as!();
}

impl EvalOp for MatMatMulPack {
impl EvalOp for OptMatMulPack {
fn is_stateless(&self) -> bool {
true
}
Expand All @@ -38,9 +38,14 @@ impl EvalOp for MatMatMulPack {
}
}

impl TypedOp for MatMatMulPack {
impl TypedOp for OptMatMulPack {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
Ok(tvec!(Opaque::datum_type().fact(self.output_shape(&inputs[0].shape))))
let k = inputs[0].shape[self.k_axis].clone();
let mn = inputs[0].shape[self.mn_axis].clone();
let opaque_fact = PackedOpaqueFact { k, mn, packers: self.packers.clone() };
Ok(tvec!(Opaque::datum_type()
.fact(self.output_shape(&inputs[0].shape))
.with_opaque_fact(opaque_fact)))
}

fn axes_mapping(
Expand All @@ -63,7 +68,7 @@ impl TypedOp for MatMatMulPack {
as_op!();
}

impl MatMatMulPack {
impl OptMatMulPack {
fn do_eval(&self, session: &SessionState, input: TValue) -> TractResult<TVec<TValue>> {
unsafe {
let packer = if self.packers.len() == 1 {
Expand Down Expand Up @@ -118,3 +123,16 @@ impl MatMatMulPack {
packed_shape
}
}

#[derive(Hash, Clone, Debug, PartialEq, Eq)]
pub struct PackedOpaqueFact {
pub k: TDim,
pub mn: TDim,
pub packers: Vec<PackedFormat>,
}

impl OpaqueFact for PackedOpaqueFact {
fn mem_size(&self) -> TDim {
self.k.clone() * &self.mn * self.packers[0].dt.size_of()
}
}

0 comments on commit 0835c29

Please sign in to comment.