Skip to content

Commit

Permalink
Reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
hubertdelajonquieresonos committed Dec 12, 2024
1 parent 7e528ad commit e7f2f86
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
11 changes: 5 additions & 6 deletions metal/src/rewrite_rules/untranspose_matmul_output.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use tract_core::ops::einsum::BasicMatMul;
use crate::rule_ensure;
use tract_core::internal::*;

use tract_core::ops::einsum::BasicMatMul;

/// Rewrite BasicMatMul { .. transpose_c: true } to BasicMatMul { .. transpose_c: false}
pub fn untranspose_matmul_output(
Expand All @@ -13,13 +12,13 @@ pub fn untranspose_matmul_output(
) -> TractResult<Option<TypedModelPatch>> {
rule_ensure!(op.transpose_c);


let new_matmul = BasicMatMul {
transpose_a: !op.transpose_b,
transpose_b: !op.transpose_a,
transpose_c: false,
.. *op
..*op
};

TypedModelPatch::replace_single_op(model, node, &[node.inputs[1], node.inputs[0]], new_matmul).map(Some)
}
TypedModelPatch::replace_single_op(model, node, &[node.inputs[1], node.inputs[0]], new_matmul)
.map(Some)
}
17 changes: 11 additions & 6 deletions metal/src/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use crate::ops::{self, MetalSync, MetalSyncKind};

use crate::rewrite_rules;
use crate::rewrite_rules::{
BasicApplyRope, BasicNewGelu, BasicRmsNorm,
BasicRotateHalf, BasicScaledMaskedSoftmax, BasicSilu,
BasicApplyRope, BasicNewGelu, BasicRmsNorm, BasicRotateHalf, BasicScaledMaskedSoftmax,
BasicSilu,
};
use crate::tensor::MetalTensorExt;
use crate::{IntoMetal, MetalFact, MetalTensor};
Expand Down Expand Up @@ -62,7 +62,11 @@ impl ModelTransform for MetalTransform {
}

impl MetalTransform {
pub fn transform_up_to_phase(&self, model: &mut TypedModel, stop_at_phase: usize) -> TractResult<()> {
pub fn transform_up_to_phase(
&self,
model: &mut TypedModel,
stop_at_phase: usize,
) -> TractResult<()> {
rewrite_einsums_as_matmul(model)?;
if stop_at_phase == 0 {
return Ok(());
Expand Down Expand Up @@ -91,7 +95,10 @@ impl MetalTransform {

Rewriter::default()
.with_rule_for("rewire-metal-sync", rewrite_rules::rewire_metal_sync)
.with_rule_for("rewire-metal-sync-after-const", rewrite_rules::rewire_metal_sync_after_const)
.with_rule_for(
"rewire-metal-sync-after-const",
rewrite_rules::rewire_metal_sync_after_const,
)
.with_rule_for("fuse_axis_op", rewrite_rules::fuse_axis_op)
.rewrite(&(), model)?;
Ok(())
Expand Down Expand Up @@ -186,8 +193,6 @@ impl Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>> for Met
target: &mut TypedModel,
mapping: &HashMap<OutletId, OutletId>,
) -> TractResult<TVec<OutletId>> {


let in_dts_metal_compatible = source
.node_input_facts(node.id)?
.iter()
Expand Down

0 comments on commit e7f2f86

Please sign in to comment.