Skip to content

Commit

Permalink
Remove MetalSyncToCPU between Const and CPU Operator
Browse files Browse the repository at this point in the history
  • Loading branch information
hubertdelajonquieresonos committed Dec 9, 2024
1 parent 47076fe commit 752b9d5
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 7 deletions.
4 changes: 2 additions & 2 deletions metal/src/ops/sync.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::fact::MetalTypedFactExt;
pub use crate::kernels::BinOps;
use crate::tensor::MetalTensorExt;
use crate::ops::MetalEvalOp;
use crate::{ MetalFact, IntoMetal, MetalContext };
use crate::tensor::MetalTensorExt;
use crate::{IntoMetal, MetalContext, MetalFact};
use derive_new::new;
use std::fmt;
use tract_core::internal::*;
Expand Down
2 changes: 1 addition & 1 deletion metal/src/rewrite_rules/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use tract_core::ops::konst::Const;
pub use apply_rope::{as_apply_rope_rule, BasicApplyRope};
pub use fuse_axis_op::fuse_axis_op;
pub use new_gelu::{as_new_gelu_rule, BasicNewGelu};
pub use rewire_metal_sync::rewire_metal_sync;
pub use rewire_metal_sync::{rewire_metal_sync, rewire_metal_sync_after_const};
pub use rms_norm::{as_rms_norm_rule, remove_rms_norm_cast, BasicRmsNorm};
pub use rotate_half::{as_rotate_half_rule, BasicRotateHalf};
pub use silu::{as_silu_rule, BasicSilu};
Expand Down
30 changes: 29 additions & 1 deletion metal/src/rewrite_rules/rewire_metal_sync.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::ops::{MetalSync, MetalSyncKind};
use crate::rewrite_rules::previous_node;
use crate::rewrite_rules::{next_node, previous_node};
use crate::rule_ensure;
use crate::tensor::MetalTensorExt;
use tract_core::internal::*;
use tract_core::ops::konst::Const;

pub fn rewire_metal_sync(
_ctx: &(),
Expand All @@ -25,3 +27,29 @@ pub fn rewire_metal_sync(
})?;
Ok(Some(patch))
}

/// Rewrite rule collapsing a metal-backed `Const` followed by a `MetalSync` of kind
/// `ToCpu` into a single `Const` that directly holds the CPU version of the tensor.
///
/// Searched pattern: `Const (metal tensor) => MetalSync(ToCpu)`.
///
/// # Returns
///
/// * `Ok(None)` when the constant does not hold a metal tensor, when `next_node`
///   yields no node, or when that node is not a `MetalSync`. A `MetalSync` of
///   another kind is rejected through `rule_ensure!` (presumably also a no-match
///   early return; confirm against the macro definition).
/// * `Ok(Some(patch))` with a patch that wires a replacement `Const` under the same
///   node name and shunts the output of the sync node to it.
///
/// # Errors
///
/// Propagates failures from the tensor conversion to CPU and from building the patch.
pub fn rewire_metal_sync_after_const(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    node_name: &str,
    op: &Const,
) -> TractResult<Option<TypedModelPatch>> {
    // Search pattern => Const => ToCPU

    // Only constants wrapping a metal tensor are candidates for this rule.
    let Some(gpu_const) = op.0.as_metal_tensor() else { return Ok(None) };

    // Identify successor ToCpu
    let Some(sync_cpu) = next_node(model, node) else { return Ok(None) };
    let Some(sync_cpu_op) = sync_cpu.op_as::<MetalSync>() else { return Ok(None) };
    rule_ensure!(sync_cpu_op.kind == MetalSyncKind::ToCpu);

    // Convert to a CPU tensor only once the whole pattern has matched, so constants
    // that are not followed by a ToCpu sync never pay for (or fail on) the conversion.
    let cpu_const = gpu_const.to_cpu()?;

    let mut patch = TypedModelPatch::default();

    let konst_input = patch.taps(model, &node.inputs)?;
    // The replacement keeps the original node name; the sync node's output is then
    // redirected to the new CPU constant.
    let out = patch.wire_node(node_name, Const(cpu_const.into(), None), &konst_input)?;
    patch.shunt_outside(model, sync_cpu.id.into(), out[0])?;
    Ok(Some(patch))
}
6 changes: 3 additions & 3 deletions metal/src/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use crate::ops::{self, MetalAxisOp, MetalSync, MetalSyncKind};
#[allow(unused_imports)]
use crate::rewrite_rules::{
as_apply_rope_rule, as_new_gelu_rule, as_rms_norm_rule, as_rotate_half_rule, as_silu_rule,
fuse_axis_op, remove_rms_norm_cast, rewire_metal_sync, BasicApplyRope, BasicNewGelu,
BasicRmsNorm, BasicRotateHalf, BasicSilu,
fuse_axis_op, remove_rms_norm_cast, rewire_metal_sync, rewire_metal_sync_after_const,
BasicApplyRope, BasicNewGelu, BasicRmsNorm, BasicRotateHalf, BasicSilu,
};
use crate::tensor::MetalTensorExt;
use crate::{IntoMetal, MetalFact, MetalTensor};
Expand Down Expand Up @@ -70,6 +70,7 @@ impl ModelTransform for MetalTransform {

Rewriter::default()
.with_rule_for::<MetalSync>("rewire-metal-sync", rewire_metal_sync)
.with_rule_for::<Const>("rewire-metal-sync-after-const", rewire_metal_sync_after_const)
.with_rule_for::<MetalAxisOp>("fuse_axis_op", fuse_axis_op)
.rewrite(&(), &mut new)?;
*model = new;
Expand Down Expand Up @@ -365,7 +366,6 @@ fn convert_const(op: &Const) -> TractResult<Option<Const>> {
Ok(Some(Const::new_with_opaque_fact(metal_const, Box::new(metal_fact))))
}


fn convert_element_wise_ops_to_metal(op: &ElementWiseOp) -> Option<ops::MetalElementWiseOp> {
map_element_wise_ops!([
(tract_core::ops::math::Abs, Abs),
Expand Down

0 comments on commit 752b9d5

Please sign in to comment.