Skip to content

Commit

Permalink
Replace MetalConst -> Const
Browse files Browse the repository at this point in the history
  • Loading branch information
hubertdelajonquieresonos authored and kali committed Dec 6, 2024
1 parent 3df1f74 commit 024cffc
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 49 deletions.
2 changes: 1 addition & 1 deletion metal/src/fact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl MetalFact {
}

pub fn from_cpu(fact: TypedFact) -> TractResult<Self> {
Self::new(MetalOrigin::FromGpu, fact)
Self::new(MetalOrigin::FromCpu, fact)
}

pub fn is_from_gpu(&self) -> bool {
Expand Down
45 changes: 0 additions & 45 deletions metal/src/ops/konst.rs

This file was deleted.

2 changes: 0 additions & 2 deletions metal/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ pub mod change_axes;
pub mod concat;
pub mod element_wise;
pub mod gemm;
pub mod konst;
pub mod new_gelu;
pub mod reduce;
pub mod rms_norm;
Expand All @@ -24,7 +23,6 @@ pub use change_axes::{MetalAxisOp, MetalIntoShape};
pub use concat::MetalConcat;
pub use element_wise::MetalElementWiseOp;
pub use gemm::MetalGemm;
pub use konst::MetalConst;
pub use new_gelu::MetalNewGelu;
pub use reduce::MetalReduce;
pub use rms_norm::MetalRmsNorm;
Expand Down
12 changes: 11 additions & 1 deletion metal/src/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ impl Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>> for Met
} else if let Some(op) = node.op_as::<MultiBroadcastTo>() {
Some(Box::new(ops::MetalMultiBroadcastTo::new(op.shape.clone())))
} else if let Some(op) = node.op_as::<Const>() {
ops::MetalConst::new(op.0.clone())?.map(|o| -> Box<dyn TypedOp> { Box::new(o) })
convert_const(op)?.map(|o| -> Box<dyn TypedOp> { Box::new(o) })
} else if let Some(op) = node.op_as::<Cast>() {
check_in_dts_are_supported(source, node.id, ops::MetalCast::is_supported_dt)?
.then(|| ops::MetalCast::new(op.to))
Expand Down Expand Up @@ -355,6 +355,16 @@ pub fn bin_ops_to_metal(
.transpose()
}

fn convert_const(op: &Const) -> TractResult<Option<Const>> {
if !MetalTensor::is_supported_dt(op.0.datum_type()) {
return Ok(None);
}
let metal_fact = MetalFact::from_cpu(Arc::clone(&op.0).into())?;
let metal_const = op.0.clone().into_metal()?.into_opaque_tensor().into_arc_tensor();
Ok(Some(Const::new_with_opaque_fact(metal_const, Box::new(metal_fact))))
}


fn convert_element_wise_ops_to_metal(op: &ElementWiseOp) -> Option<ops::MetalElementWiseOp> {
map_element_wise_ops!([
(tract_core::ops::math::Abs, Abs),
Expand Down

0 comments on commit 024cffc

Please sign in to comment.