diff --git a/metal/src/ops/sync.rs b/metal/src/ops/sync.rs index 8da17aae3b..35f3a61aa0 100644 --- a/metal/src/ops/sync.rs +++ b/metal/src/ops/sync.rs @@ -1,8 +1,8 @@ use crate::fact::MetalTypedFactExt; pub use crate::kernels::BinOps; -use crate::tensor::MetalTensorExt; use crate::ops::MetalEvalOp; -use crate::{ MetalFact, IntoMetal, MetalContext }; +use crate::tensor::MetalTensorExt; +use crate::{IntoMetal, MetalContext, MetalFact}; use derive_new::new; use std::fmt; use tract_core::internal::*; diff --git a/metal/src/rewrite_rules/mod.rs b/metal/src/rewrite_rules/mod.rs index 3e0997808d..73e7d05ec2 100644 --- a/metal/src/rewrite_rules/mod.rs +++ b/metal/src/rewrite_rules/mod.rs @@ -12,7 +12,7 @@ use tract_core::ops::konst::Const; pub use apply_rope::{as_apply_rope_rule, BasicApplyRope}; pub use fuse_axis_op::fuse_axis_op; pub use new_gelu::{as_new_gelu_rule, BasicNewGelu}; -pub use rewire_metal_sync::rewire_metal_sync; +pub use rewire_metal_sync::{rewire_metal_sync, rewire_metal_sync_after_const}; pub use rms_norm::{as_rms_norm_rule, remove_rms_norm_cast, BasicRmsNorm}; pub use rotate_half::{as_rotate_half_rule, BasicRotateHalf}; pub use silu::{as_silu_rule, BasicSilu}; diff --git a/metal/src/rewrite_rules/rewire_metal_sync.rs b/metal/src/rewrite_rules/rewire_metal_sync.rs index 7547b79370..f53c4a92ab 100644 --- a/metal/src/rewrite_rules/rewire_metal_sync.rs +++ b/metal/src/rewrite_rules/rewire_metal_sync.rs @@ -1,7 +1,9 @@ use crate::ops::{MetalSync, MetalSyncKind}; -use crate::rewrite_rules::previous_node; +use crate::rewrite_rules::{next_node, previous_node}; use crate::rule_ensure; +use crate::tensor::MetalTensorExt; use tract_core::internal::*; +use tract_core::ops::konst::Const; pub fn rewire_metal_sync( _ctx: &(), @@ -25,3 +27,29 @@ pub fn rewire_metal_sync( })?; Ok(Some(patch)) } + +pub fn rewire_metal_sync_after_const( + _ctx: &(), + model: &TypedModel, + node: &TypedNode, + node_name: &str, + op: &Const, +) -> TractResult> { + // Search pattern => Const => ToCPU + + let Some(gpu_const) = op.0.as_metal_tensor() else { return Ok(None) }; + let cpu_const = gpu_const.to_cpu()?; + + // Identify precessor ToCpu + let Some(sync_cpu) = next_node(model, node) else { return Ok(None) }; + let Some(sync_cpu_op) = sync_cpu.op_as::() else { return Ok(None) }; + rule_ensure!(sync_cpu_op.kind == MetalSyncKind::ToCpu); + + let mut patch = TypedModelPatch::default(); + + let konst_input = patch.taps(model, &node.inputs)?; + let out = + patch.wire_node(format!("{node_name}"), Const(cpu_const.into(), None), &konst_input)?; + patch.shunt_outside(model, sync_cpu.id.into(), out[0])?; + Ok(Some(patch)) +} diff --git a/metal/src/transform.rs b/metal/src/transform.rs index ba990475c3..e98e23b15b 100644 --- a/metal/src/transform.rs +++ b/metal/src/transform.rs @@ -7,8 +7,8 @@ use crate::ops::{self, MetalAxisOp, MetalSync, MetalSyncKind}; #[allow(unused_imports)] use crate::rewrite_rules::{ as_apply_rope_rule, as_new_gelu_rule, as_rms_norm_rule, as_rotate_half_rule, as_silu_rule, - fuse_axis_op, remove_rms_norm_cast, rewire_metal_sync, BasicApplyRope, BasicNewGelu, - BasicRmsNorm, BasicRotateHalf, BasicSilu, + fuse_axis_op, remove_rms_norm_cast, rewire_metal_sync, rewire_metal_sync_after_const, + BasicApplyRope, BasicNewGelu, BasicRmsNorm, BasicRotateHalf, BasicSilu, }; use crate::tensor::MetalTensorExt; use crate::{IntoMetal, MetalFact, MetalTensor}; @@ -70,6 +70,7 @@ impl ModelTransform for MetalTransform { Rewriter::default() .with_rule_for::("rewire-metal-sync", rewire_metal_sync) + .with_rule_for::("rewire-metal-sync-after-const", rewire_metal_sync_after_const) .with_rule_for::("fuse_axis_op", fuse_axis_op) .rewrite(&(), &mut new)?; *model = new; @@ -365,7 +366,6 @@ fn convert_const(op: &Const) -> TractResult> { Ok(Some(Const::new_with_opaque_fact(metal_const, Box::new(metal_fact)))) } - fn convert_element_wise_ops_to_metal(op: &ElementWiseOp) -> Option { map_element_wise_ops!([ (tract_core::ops::math::Abs, Abs),