From 8ed325a23b6fa64052b79e7032311a6aa65b6cc2 Mon Sep 17 00:00:00 2001 From: Hubert de La Jonquiere Date: Tue, 10 Dec 2024 12:34:45 +0100 Subject: [PATCH 1/6] Metal Implementation of Scaled Masked Softmax --- metal/src/kernels/nn/mod.rs | 8 + metal/src/kernels/nn/nn_ops.metal | 58 ++++ metal/src/kernels/nn/scaled_masked_softmax.rs | 273 ++++++++++++++++++ metal/src/rewrite_rules/mod.rs | 2 + .../rewrite_rules/scaled_masked_softmax.rs | 52 ++++ 5 files changed, 393 insertions(+) create mode 100644 metal/src/kernels/nn/scaled_masked_softmax.rs create mode 100644 metal/src/rewrite_rules/scaled_masked_softmax.rs diff --git a/metal/src/kernels/nn/mod.rs b/metal/src/kernels/nn/mod.rs index 61c170af42..4a1ed6b582 100644 --- a/metal/src/kernels/nn/mod.rs +++ b/metal/src/kernels/nn/mod.rs @@ -2,6 +2,7 @@ pub mod apply_rope; pub mod new_gelu; pub mod reduce; pub mod rms_norm; +pub mod scaled_masked_softmax; pub mod silu; pub mod softmax; @@ -9,6 +10,7 @@ pub use apply_rope::ApplyRope; pub use new_gelu::NewGelu; pub use reduce::Reducer; pub use rms_norm::RmsNorm; +pub use scaled_masked_softmax::ScaledMaskedSoftmax; pub use silu::Silu; pub use softmax::Softmax; @@ -28,5 +30,11 @@ pub fn all_functions() -> Vec { .flat_map(|dt| Softmax.kernel_name(dt).into_iter()), ); + functions.extend( + crate::MetalTensor::SUPPORTED_DT + .into_iter() + .flat_map(|dt| ScaledMaskedSoftmax.kernel_name(dt).into_iter()), + ); + functions.into_iter().collect() } diff --git a/metal/src/kernels/nn/nn_ops.metal b/metal/src/kernels/nn/nn_ops.metal index fcf836b80f..299fce485a 100644 --- a/metal/src/kernels/nn/nn_ops.metal +++ b/metal/src/kernels/nn/nn_ops.metal @@ -276,6 +276,64 @@ typedef decltype(softmax_nd3) softmax_nd3_t; template [[host_name("nn_ops::softmax_nd3_f32")]] [[kernel]] softmax_nd3_t softmax_nd3; template [[host_name("nn_ops::softmax_nd3_f16")]] [[kernel]] softmax_nd3_t softmax_nd3; +template +[[kernel]] void scaled_masked_softmax_nd3( + device const void *input_b, + device const void *mask_b, + constant void *scale_b, + device void *output_b, + constant const size_t shape[3], + constant const size_t strides[3], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint tpsg[[threads_per_simdgroup]] + ) { + + device const F *input = (device const F *)input_b; + device const F *mask = (device const F *)mask_b; + constant F scale = ((constant F *)scale_b)[0]; + device F *output = (device F *)output_b; + + size_t dim = shape[1]; + + size_t base_idx = tgpig.x * strides[2] + + tgpig.z * strides[0]; + + // Get max value on softmax dim after apply + float partial_max = -INFINITY; + for (size_t i = tiisg; i < dim; i += tpsg) { + auto idx = base_idx + i * strides[1]; + float el = static_cast(input[idx] * scale + mask[idx]); + partial_max = max(partial_max, el); + } + + float axis_max = simd_max(partial_max); + + // Compute Sum(exp(x - max)) + float partial_norm = 0; + for (size_t i = tiisg; i < dim; i += tpsg) { + auto idx = base_idx + i * strides[1]; + float el = static_cast(input[idx] * scale + mask[idx]); + float exp_el = fast::exp(el - axis_max); + partial_norm += exp_el; + output[idx] = static_cast(exp_el); + } + + float axis_norm = simd_sum(partial_norm); + float inv_axis_norm = 1.0 / axis_norm; + + for (size_t i = tiisg; i < dim; i += tpsg) { + auto idx = base_idx + i * strides[1]; + float exp_el = static_cast(output[idx]); + output[idx] = static_cast(exp_el * inv_axis_norm); + } +} + +typedef decltype(scaled_masked_softmax_nd3) scaled_masked_softmax_nd3_t; + +template 
[[host_name("nn_ops::scaled_masked_softmax_nd3_f32")]] [[kernel]] scaled_masked_softmax_nd3_t scaled_masked_softmax_nd3; +template [[host_name("nn_ops::scaled_masked_softmax_nd3_f16")]] [[kernel]] scaled_masked_softmax_nd3_t scaled_masked_softmax_nd3; + constant float GELU_COEF_A = 0.044715f; constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; diff --git a/metal/src/kernels/nn/scaled_masked_softmax.rs b/metal/src/kernels/nn/scaled_masked_softmax.rs new file mode 100644 index 0000000000..3e6d44bdde --- /dev/null +++ b/metal/src/kernels/nn/scaled_masked_softmax.rs @@ -0,0 +1,273 @@ +use crate::encoder::EncoderExt; +use crate::kernels::utils; +use crate::{LibraryName, MetalContext, MetalTensor}; +use anyhow::Result; +use metal::MTLSize; +use tract_core::internal::*; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ScaledMaskedSoftmax; + +impl ScaledMaskedSoftmax { + pub fn is_supported_dt(dt: DatumType) -> bool { + matches!(dt, DatumType::F32 | DatumType::F16) + } + + pub fn kernel_name(&self, dt: DatumType) -> Result { + ensure!( + Self::is_supported_dt(dt), + "Unsupport dt {:?} for metal scaled masked softmax op", + dt + ); + let tname = MetalTensor::tname(dt)?; + Ok(format!("nn_ops::scaled_masked_softmax_nd3_{tname}")) + } + + pub fn eval( + &self, + context: &MetalContext, + input: &MetalTensor, + axis: usize, + ) -> Result { + let output = unsafe { MetalTensor::uninitialized_dt(input.datum_type(), input.shape())? }; + self.dispatch_eval(context, input, axis, &output)?; + context.wait_until_completed()?; + Ok(output) + } + + pub fn dispatch_eval( + &self, + context: &MetalContext, + input: &MetalTensor, + scale: &Tensor, + mask: &MetalTensor, + axis: usize, + output: &MetalTensor, + ) -> Result<()> { + input.retained_until_completion(); + mask.retained_until_completion(); + output.retained_until_completion(); + + ensure!(output.shape() == input.shape()); + ensure!(input.shape() == mask.shape()); + ensure!(output.datum_type() == input.datum_type()); + + let shape_nd3 = utils::reshape_to_rank_3(input.shape(), axis); + let strides_nd3 = Tensor::natural_strides(&shape_nd3); + + let pipeline = context + .shared_context() + .load_pipeline(LibraryName::NNOps, &self.kernel_name(input.datum_type())?)?; + + let command_buffer = context.command_buffer(); + command_buffer.encode(|encoder| { + encoder.set_compute_pipeline_state(&pipeline); + encoder.set_metal_tensor(0, input, metal::MTLResourceUsage::Read); + encoder.set_metal_tensor(1, mask, metal::MTLResourceUsage::Read); + encoder.set_tensor(2, eps); + encoder.set_metal_tensor(3, output, metal::MTLResourceUsage::Write); + encoder.set_slice(4, &shape_nd3); + encoder.set_slice(5, &strides_nd3); + + let grid_size = + MTLSize { width: shape_nd3[2] as _, height: 1, depth: shape_nd3[0] as _ }; + let group_size = + MTLSize { width: usize::min(32, shape_nd3[1]) as _, height: 1, depth: 1 }; + + encoder.dispatch_thread_groups(grid_size, group_size); + encoder.end_encoding(); + }); + Ok(()) + } +} + +// #[cfg(test)] +// mod tests { +// use super::*; +// use crate::IntoMetal; +// use derive_new::new; +// use num_traits::AsPrimitive; +// use num_traits::Float; +// use proptest::collection::vec; +// use proptest::prelude::*; +// use tract_core::internal::Tensor; +// use tract_core::ops::nn::Softmax as TractSoftmax; +// use tract_core::ops::nn::SoftmaxExp; + +// #[test] +// fn test_softmax_f32() -> Result<()> { +// objc::rc::autoreleasepool(|| { +// crate::METAL_CONTEXT.with_borrow(|context| { +// let m = 4; +// let k = 4; +// 
let axis = 1; + +// let a = +// Tensor::from_shape(&[m, k], &(0..m * k).map(|f| f as f32).collect::>())? +// .into_metal()?; + +// let cpu_softmax = TractSoftmax { +// axes: tvec![axis], +// quant_output_dt: None, +// exp: SoftmaxExp::Libc, +// }; + +// let cpu_output = +// cpu_softmax.eval(tvec![a.to_cpu()?.into_tvalue()])?[0].clone().into_tensor(); +// let metal_output = Softmax.eval(context, &a, axis)?; +// cpu_output.close_enough(&metal_output.to_cpu()?, Approximation::Approximate)?; +// Ok(()) +// }) +// }) +// } + +// #[test] +// fn test_softmax_f32_2() -> Result<()> { +// objc::rc::autoreleasepool(|| { +// crate::METAL_CONTEXT.with_borrow(|context| { +// let shape = [8, 4, 3]; +// let num_elements = shape.iter().product(); +// let axis = 0; + +// let a = Tensor::from_shape( +// &shape, +// &(0..num_elements).map(|f| f as f32 / 1000.0).collect::>(), +// )? +// .into_metal()?; + +// let cpu_softmax = TractSoftmax { +// axes: tvec![axis], +// quant_output_dt: None, +// exp: SoftmaxExp::Libc, +// }; + +// let cpu_output = +// cpu_softmax.eval(tvec![a.to_cpu()?.into_tvalue()])?[0].clone().into_tensor(); +// let metal_output = Softmax.eval(context, &a, axis)?; +// cpu_output.close_enough(&metal_output.to_cpu()?, Approximation::Approximate)?; +// Ok(()) +// }) +// }) +// } + +// #[test] +// fn test_softmax_f16() -> Result<()> { +// objc::rc::autoreleasepool(|| { +// crate::METAL_CONTEXT.with_borrow(|context| { +// let m = 4; +// let k = 4; +// let axis = 1; + +// let a = Tensor::from_shape( +// &[m, k], +// &(0..m * k).map(|f| -> f16 { f.as_() }).collect::>(), +// )? +// .into_metal()?; + +// let cpu_softmax = TractSoftmax { +// axes: tvec![axis], +// quant_output_dt: None, +// exp: SoftmaxExp::Libc, +// }; + +// let cpu_output = +// cpu_softmax.eval(tvec![a.to_cpu()?.into_tvalue()])?[0].clone().into_tensor(); +// let metal_output = Softmax.eval(context, &a, axis)?; +// cpu_output.close_enough(&metal_output.to_cpu()?, Approximation::Approximate)?; +// Ok(()) +// }) +// }) +// } + +// proptest::proptest! 
{ +// #[test] +// fn softmax_prop_f32(pb in any::>()) { +// fn run(pb: SoftmaxProblem) -> TractResult<()> { +// let out = pb.run()?; +// let reference = pb.reference()?; + +// out.close_enough(&reference, Approximation::Approximate) +// .with_context(|| anyhow!("Cpu: {:?}, Metal: {:?}", reference.dump(true), out.dump(true))) +// } +// run(pb).map_err(|e| TestCaseError::Fail(format!("{:?}", e).into()))?; +// } + +// #[test] +// fn softmax_prop_f16(pb in any::>()) { +// fn run(pb: SoftmaxProblem) -> TractResult<()> { +// let out = pb.run()?; +// let reference = pb.reference()?; + +// out.close_enough(&reference, Approximation::Approximate) +// .with_context(|| anyhow!("Cpu: {:?}, Metal: {:?}", reference.dump(true), out.dump(true))) +// } + +// run(pb).map_err(|e| TestCaseError::Fail(format!("{:?}", e).into()))?; +// } +// } + +// #[derive(Debug, new)] +// pub struct SoftmaxProblem +// where +// F: Datum + Float, +// usize: AsPrimitive, +// { +// pub shape: Vec, +// pub axis: usize, +// pub input: Vec, +// } + +// impl Arbitrary for SoftmaxProblem +// where +// F: Datum + Float, +// usize: AsPrimitive, +// { +// type Parameters = (); +// type Strategy = BoxedStrategy; + +// fn arbitrary_with(_: ()) -> Self::Strategy { +// (0usize..3, 0usize..3) +// .prop_flat_map(|(left, right)| { +// let axis = left; +// let shape_len = usize::min(left + right + 1, 4); +// let shape = 1usize..10; +// (vec(shape, shape_len..=shape_len), Just(axis)) +// }) +// .prop_map(|(shape, axis)| { +// let input = (0..shape.iter().product::()) +// .map(|f| f.as_() / 1000.as_()) +// .collect::>(); +// Self { shape, axis, input } +// }) +// .boxed() +// } +// } + +// impl SoftmaxProblem +// where +// F: Datum + Float + std::ops::AddAssign, +// usize: AsPrimitive, +// { +// pub fn reference(&self) -> Result { +// let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?; + +// let cpu_softmax = TractSoftmax { +// axes: tvec![self.axis], +// quant_output_dt: None, +// exp: SoftmaxExp::Libc, +// }; +// let cpu_output = cpu_softmax.eval(tvec![a.into_tvalue()])?[0].clone().into_tensor(); +// Ok(cpu_output) +// } + +// pub fn run(&self) -> Result { +// objc::rc::autoreleasepool(|| { +// crate::METAL_CONTEXT.with_borrow(|context| { +// let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?.into_metal()?; +// let metal_output = Softmax.eval(context, &a, self.axis)?; +// metal_output.to_cpu() +// }) +// }) +// } +// } +// } diff --git a/metal/src/rewrite_rules/mod.rs b/metal/src/rewrite_rules/mod.rs index 73e7d05ec2..5e7d0b4d5a 100644 --- a/metal/src/rewrite_rules/mod.rs +++ b/metal/src/rewrite_rules/mod.rs @@ -4,6 +4,7 @@ mod new_gelu; mod rewire_metal_sync; mod rms_norm; mod rotate_half; +mod scaled_masked_softmax; mod silu; use tract_core::internal::*; @@ -15,6 +16,7 @@ pub use new_gelu::{as_new_gelu_rule, BasicNewGelu}; pub use rewire_metal_sync::{rewire_metal_sync, rewire_metal_sync_after_const}; pub use rms_norm::{as_rms_norm_rule, remove_rms_norm_cast, BasicRmsNorm}; pub use rotate_half::{as_rotate_half_rule, BasicRotateHalf}; +pub use scaled_masked_softmax::BasicScaledMaskedSoftmax; pub use silu::{as_silu_rule, BasicSilu}; use tract_core::ops::binary::TypedBinOp; diff --git a/metal/src/rewrite_rules/scaled_masked_softmax.rs b/metal/src/rewrite_rules/scaled_masked_softmax.rs new file mode 100644 index 0000000000..17127b70bf --- /dev/null +++ b/metal/src/rewrite_rules/scaled_masked_softmax.rs @@ -0,0 +1,52 @@ +use std::sync::Arc; +use tract_core::internal::*; +use tract_core::ops::binary::BinMiniOp; +use 
tract_core::ops::math::{Add, Mul}; +use tract_core::ops::nn::{Softmax, SoftmaxExp}; + +#[derive(Clone, Debug, Hash)] +pub struct BasicScaledMaskedSoftmax { + pub axis: usize, + pub scale: Arc, +} + +impl Op for BasicScaledMaskedSoftmax { + fn name(&self) -> Cow { + "BasicScaledMaskedSoftmax".to_string().into() + } + fn info(&self) -> TractResult> { + Ok(vec![format!("axis: {:?}, scale: {:?}", self.axis, self.scale)]) + } + op_as_typed_op!(); +} + +impl EvalOp for BasicScaledMaskedSoftmax { + fn is_stateless(&self) -> bool { + true + } + + fn eval(&self, inputs: TVec) -> TractResult> { + let (input, mask) = args_2!(inputs); + let dt = input.datum_type(); + ensure!(input.shape() == mask.shape()); + + let scaled_input = Mul.eval(input, self.scale.clone().into_tvalue(), dt)?; + let masked_input = Add.eval(scaled_input.into(), mask, dt)?; + let softmax = Softmax::new(tvec![self.axis], None, SoftmaxExp::Libc) + .eval(tvec![masked_input.into()])?[0]; + Ok(tvec![softmax.into()]) + } +} + +impl TypedOp for BasicScaledMaskedSoftmax { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + ensure!(inputs.len() == 2); + let (input, mask) = (inputs[0], inputs[1]); + ensure!(input.datum_type == mask.datum_type); + let dt = input.datum_type; + let fact = dt.fact(input.shape.clone()); + Ok(tvec!(fact)) + } + + as_op!(); +} From b13545568cd65ee1ecc6e21850c7d1ba6b762d9e Mon Sep 17 00:00:00 2001 From: Hubert de La Jonquiere Date: Tue, 10 Dec 2024 14:18:39 +0100 Subject: [PATCH 2/6] WIP --- metal/src/kernels/nn/scaled_masked_softmax.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/metal/src/kernels/nn/scaled_masked_softmax.rs b/metal/src/kernels/nn/scaled_masked_softmax.rs index 3e6d44bdde..eba9f47ed3 100644 --- a/metal/src/kernels/nn/scaled_masked_softmax.rs +++ b/metal/src/kernels/nn/scaled_masked_softmax.rs @@ -14,11 +14,7 @@ impl ScaledMaskedSoftmax { } pub fn kernel_name(&self, dt: DatumType) -> Result { - ensure!( - Self::is_supported_dt(dt), - "Unsupport dt {:?} for metal scaled masked softmax op", - dt - ); + ensure!(Self::is_supported_dt(dt), "Unsupport dt {:?} for metal scaled masked softmax op", dt); let tname = MetalTensor::tname(dt)?; Ok(format!("nn_ops::scaled_masked_softmax_nd3_{tname}")) } From dab1e947bfb0d5643b1d6f93b2e9e13d8d0aaa3b Mon Sep 17 00:00:00 2001 From: Hubert de La Jonquiere Date: Tue, 10 Dec 2024 17:39:50 +0100 Subject: [PATCH 3/6] Scaled Masked Softmax integration WIP --- metal/src/kernels/nn/nn_ops.metal | 56 +++++---- metal/src/kernels/nn/scaled_masked_softmax.rs | 112 +++++++++--------- metal/src/ops/mod.rs | 2 + metal/src/ops/scaled_masked_softmax.rs | 57 +++++++++ metal/src/rewrite_rules/fuse_axis_op.rs | 1 + metal/src/rewrite_rules/mod.rs | 2 +- .../rewrite_rules/scaled_masked_softmax.rs | 81 +++++++++++-- metal/src/transform.rs | 16 ++- 8 files changed, 235 insertions(+), 92 deletions(-) create mode 100644 metal/src/ops/scaled_masked_softmax.rs diff --git a/metal/src/kernels/nn/nn_ops.metal b/metal/src/kernels/nn/nn_ops.metal index 299fce485a..c38be59a40 100644 --- a/metal/src/kernels/nn/nn_ops.metal +++ b/metal/src/kernels/nn/nn_ops.metal @@ -284,6 +284,7 @@ template device void *output_b, constant const size_t shape[3], constant const size_t strides[3], + constant const size_t mask_strides[3], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint tpsg[[threads_per_simdgroup]] @@ -291,42 +292,47 @@ template device const F *input = (device const F *)input_b; device const F *mask = (device const F 
*)mask_b; - constant F scale = ((constant F *)scale_b)[0]; + F scale = ((constant F *)scale_b)[0]; device F *output = (device F *)output_b; - size_t dim = shape[1]; + size_t dim = shape[2]; - size_t base_idx = tgpig.x * strides[2] + size_t base_idx = tgpig.y * strides[1] + tgpig.z * strides[0]; + size_t mask_base_idx = tgpig.y * mask_strides[1] + + tgpig.z * mask_strides[0]; + // Get max value on softmax dim after apply float partial_max = -INFINITY; for (size_t i = tiisg; i < dim; i += tpsg) { - auto idx = base_idx + i * strides[1]; - float el = static_cast(input[idx] * scale + mask[idx]); + auto idx = base_idx + i * strides[2]; + auto mask_idx = mask_base_idx + i * mask_strides[2]; + output[idx] = input[idx] * scale + mask[mask_idx]; + float el = static_cast(output[idx]); partial_max = max(partial_max, el); } - float axis_max = simd_max(partial_max); - - // Compute Sum(exp(x - max)) - float partial_norm = 0; - for (size_t i = tiisg; i < dim; i += tpsg) { - auto idx = base_idx + i * strides[1]; - float el = static_cast(input[idx] * scale + mask[idx]); - float exp_el = fast::exp(el - axis_max); - partial_norm += exp_el; - output[idx] = static_cast(exp_el); - } - - float axis_norm = simd_sum(partial_norm); - float inv_axis_norm = 1.0 / axis_norm; - - for (size_t i = tiisg; i < dim; i += tpsg) { - auto idx = base_idx + i * strides[1]; - float exp_el = static_cast(output[idx]); - output[idx] = static_cast(exp_el * inv_axis_norm); - } + float axis_max = simd_max(partial_max); + + // Compute Sum(exp(x - max)) + float partial_norm = 0; + for (size_t i = tiisg; i < dim; i += tpsg) { + auto idx = base_idx + i * strides[2]; + float el = static_cast(output[idx]); + float exp_el = fast::exp(el - axis_max); + partial_norm += exp_el; + } + + float axis_norm = simd_sum(partial_norm); + float inv_axis_norm = 1.0 / axis_norm; + + for (size_t i = tiisg; i < dim; i += tpsg) { + auto idx = base_idx + i * strides[2]; + float el = static_cast(output[idx]); + float exp_el = fast::exp(el - axis_max); + output[idx] = static_cast(exp_el * inv_axis_norm); + } } typedef decltype(scaled_masked_softmax_nd3) scaled_masked_softmax_nd3_t; diff --git a/metal/src/kernels/nn/scaled_masked_softmax.rs b/metal/src/kernels/nn/scaled_masked_softmax.rs index eba9f47ed3..a22059ba3e 100644 --- a/metal/src/kernels/nn/scaled_masked_softmax.rs +++ b/metal/src/kernels/nn/scaled_masked_softmax.rs @@ -1,5 +1,4 @@ use crate::encoder::EncoderExt; -use crate::kernels::utils; use crate::{LibraryName, MetalContext, MetalTensor}; use anyhow::Result; use metal::MTLSize; @@ -14,7 +13,11 @@ impl ScaledMaskedSoftmax { } pub fn kernel_name(&self, dt: DatumType) -> Result { - ensure!(Self::is_supported_dt(dt), "Unsupport dt {:?} for metal scaled masked softmax op", dt); + ensure!( + Self::is_supported_dt(dt), + "Unsupport dt {:?} for metal scaled masked softmax op", + dt + ); let tname = MetalTensor::tname(dt)?; Ok(format!("nn_ops::scaled_masked_softmax_nd3_{tname}")) } @@ -23,10 +26,12 @@ impl ScaledMaskedSoftmax { &self, context: &MetalContext, input: &MetalTensor, - axis: usize, + scale: &Tensor, + mask: &MetalTensor, ) -> Result { let output = unsafe { MetalTensor::uninitialized_dt(input.datum_type(), input.shape())? 
}; - self.dispatch_eval(context, input, axis, &output)?; + dbg!(&output); + self.dispatch_eval(context, input, scale, mask, &output)?; context.wait_until_completed()?; Ok(output) } @@ -37,7 +42,6 @@ impl ScaledMaskedSoftmax { input: &MetalTensor, scale: &Tensor, mask: &MetalTensor, - axis: usize, output: &MetalTensor, ) -> Result<()> { input.retained_until_completion(); @@ -45,12 +49,14 @@ impl ScaledMaskedSoftmax { output.retained_until_completion(); ensure!(output.shape() == input.shape()); - ensure!(input.shape() == mask.shape()); + ensure!(mask.rank() == 3 && input.rank() == 3); ensure!(output.datum_type() == input.datum_type()); - let shape_nd3 = utils::reshape_to_rank_3(input.shape(), axis); - let strides_nd3 = Tensor::natural_strides(&shape_nd3); - + let shape = input.shape(); + let strides = input.strides(); + let mask_strides_nd3 = + crate::utils::compute_broadcast_strides::(mask.shape(), mask.strides())?; + let pipeline = context .shared_context() .load_pipeline(LibraryName::NNOps, &self.kernel_name(input.datum_type())?)?; @@ -60,15 +66,14 @@ impl ScaledMaskedSoftmax { encoder.set_compute_pipeline_state(&pipeline); encoder.set_metal_tensor(0, input, metal::MTLResourceUsage::Read); encoder.set_metal_tensor(1, mask, metal::MTLResourceUsage::Read); - encoder.set_tensor(2, eps); + encoder.set_tensor(2, scale); encoder.set_metal_tensor(3, output, metal::MTLResourceUsage::Write); - encoder.set_slice(4, &shape_nd3); - encoder.set_slice(5, &strides_nd3); + encoder.set_slice(4, &shape); + encoder.set_slice(5, &strides); + encoder.set_slice(6, &mask_strides_nd3); - let grid_size = - MTLSize { width: shape_nd3[2] as _, height: 1, depth: shape_nd3[0] as _ }; - let group_size = - MTLSize { width: usize::min(32, shape_nd3[1]) as _, height: 1, depth: 1 }; + let grid_size = MTLSize { width: 1 as _, height: shape[1] as _, depth: shape[0] as _}; + let group_size = MTLSize { width: usize::min(32, shape[2]) as _, height: 1, depth: 1 }; encoder.dispatch_thread_groups(grid_size, group_size); encoder.end_encoding(); @@ -77,45 +82,42 @@ impl ScaledMaskedSoftmax { } } -// #[cfg(test)] -// mod tests { -// use super::*; -// use crate::IntoMetal; -// use derive_new::new; -// use num_traits::AsPrimitive; -// use num_traits::Float; -// use proptest::collection::vec; -// use proptest::prelude::*; -// use tract_core::internal::Tensor; -// use tract_core::ops::nn::Softmax as TractSoftmax; -// use tract_core::ops::nn::SoftmaxExp; - -// #[test] -// fn test_softmax_f32() -> Result<()> { -// objc::rc::autoreleasepool(|| { -// crate::METAL_CONTEXT.with_borrow(|context| { -// let m = 4; -// let k = 4; -// let axis = 1; - -// let a = -// Tensor::from_shape(&[m, k], &(0..m * k).map(|f| f as f32).collect::>())? 
-// .into_metal()?; - -// let cpu_softmax = TractSoftmax { -// axes: tvec![axis], -// quant_output_dt: None, -// exp: SoftmaxExp::Libc, -// }; - -// let cpu_output = -// cpu_softmax.eval(tvec![a.to_cpu()?.into_tvalue()])?[0].clone().into_tensor(); -// let metal_output = Softmax.eval(context, &a, axis)?; -// cpu_output.close_enough(&metal_output.to_cpu()?, Approximation::Approximate)?; -// Ok(()) -// }) -// }) -// } +#[cfg(test)] +mod tests { + use crate::rewrite_rules::BasicScaledMaskedSoftmax; +use super::*; + use crate::IntoMetal; + use derive_new::new; + + + + + use tract_core::internal::Tensor; + + #[test] + fn test_scaled_masked_softmax_f32() -> Result<()> { + objc::rc::autoreleasepool(|| { + crate::METAL_CONTEXT.with_borrow(|context| { + let m = 4; + let n = 4; + let scale: Arc<_> = tensor0(0.125f32).into(); + let mask = Tensor::from_shape(&[1, m, n], &vec![-1000f32; m*n])?.into_metal()?; + + let a = + Tensor::from_shape(&[1, m, n], &(0..m * n).map(|f| f as f32).collect::>())? + .into_metal()?; + + let cpu = BasicScaledMaskedSoftmax { scale:scale.clone() }; + + let cpu_output = + cpu.eval(tvec![a.to_cpu()?.into_tvalue(), mask.to_cpu()?.into_tvalue()])?[0].clone().into_tensor(); + let metal_output = ScaledMaskedSoftmax.eval(context, &a, &scale, &mask)?; + cpu_output.close_enough(&metal_output.to_cpu()?, Approximation::Approximate)?; + Ok(()) + }) + }) + } +} // #[test] // fn test_softmax_f32_2() -> Result<()> { diff --git a/metal/src/ops/mod.rs b/metal/src/ops/mod.rs index 2fca47f0c6..613554ccb8 100644 --- a/metal/src/ops/mod.rs +++ b/metal/src/ops/mod.rs @@ -11,6 +11,7 @@ pub mod new_gelu; pub mod reduce; pub mod rms_norm; pub mod rotate_half; +pub mod scaled_masked_softmax; pub mod silu; pub mod slice; pub mod softmax; @@ -29,6 +30,7 @@ pub use new_gelu::MetalNewGelu; pub use reduce::MetalReduce; pub use rms_norm::MetalRmsNorm; pub use rotate_half::MetalRotateHalf; +pub use scaled_masked_softmax::MetalScaledMaskedSoftmax; pub use silu::MetalSilu; pub use slice::MetalSlice; pub use softmax::MetalSoftmax; diff --git a/metal/src/ops/scaled_masked_softmax.rs b/metal/src/ops/scaled_masked_softmax.rs new file mode 100644 index 0000000000..4ac66349f5 --- /dev/null +++ b/metal/src/ops/scaled_masked_softmax.rs @@ -0,0 +1,57 @@ +use crate::kernels::nn::ScaledMaskedSoftmax; +use crate::ops::MetalEvalOp; +use crate::tensor::MetalTensorExt; +use crate::MetalContext; +use derive_new::new; +use tract_core::internal::*; + +/// A = SOFTMAX(INPUT * SCALE + MASK, AXIS=1) +/// Only input of rank of 3 is supported and softmax axis = 1 +#[derive(Clone, Debug, new, Hash)] +pub struct MetalScaledMaskedSoftmax { + pub scale: Arc, +} + +impl Op for MetalScaledMaskedSoftmax { + fn name(&self) -> Cow { + "MetalScaledMaskedSoftmax".into() + } + + op_as_typed_op!(); +} + +impl MetalEvalOp for MetalScaledMaskedSoftmax { + fn metal_eval( + &self, + context: &MetalContext, + node_id: usize, + session: &mut SessionState, + inputs: TVec, + ) -> TractResult> { + let (opaque_input, opaque_mask) = args_2!(inputs); + let input = opaque_input.to_metal_tensor()?; + let mask = opaque_mask.to_metal_tensor()?; + let output = + crate::ops::make_tensor_for_node(session, node_id, input.datum_type(), input.shape())?; + ScaledMaskedSoftmax.dispatch_eval(context, input, &self.scale, mask, &output)?; + Ok(tvec!(output.into_opaque_tensor().into_tvalue())) + } +} + +impl TypedOp for MetalScaledMaskedSoftmax { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + crate::utils::metal_facts_from_gpu(inputs, |facts| { + 
ensure!(facts.len() == 2); + let dt = facts[0].datum_type; + ensure!(dt == facts[1].datum_type); + ensure!(facts[0].rank() == 3 && facts[1].rank() == 3); + let fact = dt.fact(facts[0].shape.clone()); + Ok(tvec!(fact)) + }) + .with_context(|| anyhow::anyhow!("Error while computing facts for {:?}", self.name())) + } + + as_op!(); +} + +crate::impl_eval_op_for_metal_op!(MetalScaledMaskedSoftmax); diff --git a/metal/src/rewrite_rules/fuse_axis_op.rs b/metal/src/rewrite_rules/fuse_axis_op.rs index b68b9facd0..9495390214 100644 --- a/metal/src/rewrite_rules/fuse_axis_op.rs +++ b/metal/src/rewrite_rules/fuse_axis_op.rs @@ -109,6 +109,7 @@ pub fn fuse_axis_op( crate::ops::MetalSlice, crate::ops::MetalConcat, crate::ops::MetalCast, + crate::ops::MetalScaledMaskedSoftmax, ); // Handle AxisOp::Move operator. diff --git a/metal/src/rewrite_rules/mod.rs b/metal/src/rewrite_rules/mod.rs index 5e7d0b4d5a..93c01740dc 100644 --- a/metal/src/rewrite_rules/mod.rs +++ b/metal/src/rewrite_rules/mod.rs @@ -16,7 +16,7 @@ pub use new_gelu::{as_new_gelu_rule, BasicNewGelu}; pub use rewire_metal_sync::{rewire_metal_sync, rewire_metal_sync_after_const}; pub use rms_norm::{as_rms_norm_rule, remove_rms_norm_cast, BasicRmsNorm}; pub use rotate_half::{as_rotate_half_rule, BasicRotateHalf}; -pub use scaled_masked_softmax::BasicScaledMaskedSoftmax; +pub use scaled_masked_softmax::{as_scaled_masked_softmax_rule, BasicScaledMaskedSoftmax}; pub use silu::{as_silu_rule, BasicSilu}; use tract_core::ops::binary::TypedBinOp; diff --git a/metal/src/rewrite_rules/scaled_masked_softmax.rs b/metal/src/rewrite_rules/scaled_masked_softmax.rs index 17127b70bf..45244f206b 100644 --- a/metal/src/rewrite_rules/scaled_masked_softmax.rs +++ b/metal/src/rewrite_rules/scaled_masked_softmax.rs @@ -1,12 +1,17 @@ +use crate::rewrite_rules::{collect_node_const_inputs, previous_node, previous_nodes}; +use crate::rule_ensure; +use tract_core::ops::binary::TypedBinOp; + use std::sync::Arc; use tract_core::internal::*; use tract_core::ops::binary::BinMiniOp; use tract_core::ops::math::{Add, Mul}; use tract_core::ops::nn::{Softmax, SoftmaxExp}; +/// A = SOFTMAX(INPUT * SCALE + MASK, AXIS=1) +/// Only input of rank of 3 is supported with softmax axis = 2 #[derive(Clone, Debug, Hash)] pub struct BasicScaledMaskedSoftmax { - pub axis: usize, pub scale: Arc, } @@ -15,7 +20,7 @@ impl Op for BasicScaledMaskedSoftmax { "BasicScaledMaskedSoftmax".to_string().into() } fn info(&self) -> TractResult> { - Ok(vec![format!("axis: {:?}, scale: {:?}", self.axis, self.scale)]) + Ok(vec![format!("scale: {:?}", self.scale)]) } op_as_typed_op!(); } @@ -28,13 +33,13 @@ impl EvalOp for BasicScaledMaskedSoftmax { fn eval(&self, inputs: TVec) -> TractResult> { let (input, mask) = args_2!(inputs); let dt = input.datum_type(); - ensure!(input.shape() == mask.shape()); - let scaled_input = Mul.eval(input, self.scale.clone().into_tvalue(), dt)?; let masked_input = Add.eval(scaled_input.into(), mask, dt)?; - let softmax = Softmax::new(tvec![self.axis], None, SoftmaxExp::Libc) - .eval(tvec![masked_input.into()])?[0]; - Ok(tvec![softmax.into()]) + let softmax = Softmax::new(tvec![2], None, SoftmaxExp::Libc) + .eval(tvec![masked_input.into()])?[0] + .clone() + .into(); + Ok(tvec![softmax]) } } @@ -43,6 +48,7 @@ impl TypedOp for BasicScaledMaskedSoftmax { ensure!(inputs.len() == 2); let (input, mask) = (inputs[0], inputs[1]); ensure!(input.datum_type == mask.datum_type); + ensure!(input.rank() == 3 && mask.rank() == 3); let dt = input.datum_type; let fact = 
dt.fact(input.shape.clone()); Ok(tvec!(fact)) @@ -50,3 +56,64 @@ impl TypedOp for BasicScaledMaskedSoftmax { as_op!(); } + +/// Search pattern => A = SOFTMAX(A * SCALE + MASK) +pub fn as_scaled_masked_softmax_rule( + _ctx: &(), + model: &TypedModel, + node: &TypedNode, + node_name: &str, + op: &Softmax, +) -> TractResult> { + rule_ensure!(op.axes.as_slice() == &[2]); + + let in_fact = model.node_input_facts(node.id)?[0]; + let dt = in_fact.datum_type; + // Only F16 and F32 is supported. + rule_ensure!(matches!(dt, DatumType::F32 | DatumType::F16)); + + // Identify Add operator (Mask) + let Some(add_prev) = previous_node(model, node) else { return Ok(None) }; + let Some(add_prev_op) = add_prev.op_as::() else { return Ok(None) }; + rule_ensure!(add_prev_op.0.is::()); + + let mut in_add = previous_nodes(model, add_prev); + rule_ensure!(in_add.len() == 2); + + in_add.reverse(); + let (left, right) = (in_add.pop().unwrap(), in_add.pop().unwrap()); + + let (scale_node, mask_outlet) = if left.op_is::() { + (left, add_prev.inputs[1]) + } else { + (right, add_prev.inputs[0]) + }; + + let Some(scale_op) = scale_node.op_as::() else { return Ok(None) }; + rule_ensure!(scale_op.0.is::()); + + // Retrieve Scale + let mul_consts = collect_node_const_inputs(model, scale_node); + rule_ensure!(mul_consts.len() == 1); + let scale = mul_consts[0].0.clone(); + + rule_ensure!(scale.len() == 1); + rule_ensure!(scale.datum_type() == dt); + + // Ensure input and mask have the same rank + rule_ensure!(model.outlet_fact(scale_node.inputs[0])?.shape.rank() == 3); + rule_ensure!(model.outlet_fact(mask_outlet)?.shape.rank() == 3); + + let mut patch = TypedModelPatch::default(); + let input = patch.taps(model, &scale_node.inputs)?[0]; + let mask = patch.taps(model, &[mask_outlet])?[0]; + + let out = patch.wire_node( + format!("{node_name}.scaled_masked_softmax"), + BasicScaledMaskedSoftmax { scale }, + &[input, mask], + )?; + + patch.shunt_outside(model, node.id.into(), out[0])?; + Ok(Some(patch)) +} diff --git a/metal/src/transform.rs b/metal/src/transform.rs index aa65b5b5bf..b5a7590003 100644 --- a/metal/src/transform.rs +++ b/metal/src/transform.rs @@ -1,14 +1,17 @@ use crate::fact::MetalTypedFactExt; use crate::kernels::array::RotateHalf; use crate::kernels::matmul::{MetalGemmImplKind, MfaGemm, MlxGemm, MpsMatMul}; -use crate::kernels::nn::{ApplyRope, NewGelu, Reducer, RmsNorm, Silu, Softmax}; +use crate::kernels::nn::{ + ApplyRope, NewGelu, Reducer, RmsNorm, ScaledMaskedSoftmax, Silu, Softmax, +}; use crate::ops::{self, MetalAxisOp, MetalSync, MetalSyncKind}; #[allow(unused_imports)] use crate::rewrite_rules::{ - as_apply_rope_rule, as_new_gelu_rule, as_rms_norm_rule, as_rotate_half_rule, as_silu_rule, - fuse_axis_op, remove_rms_norm_cast, rewire_metal_sync, rewire_metal_sync_after_const, - BasicApplyRope, BasicNewGelu, BasicRmsNorm, BasicRotateHalf, BasicSilu, + as_apply_rope_rule, as_new_gelu_rule, as_rms_norm_rule, as_rotate_half_rule, + as_scaled_masked_softmax_rule, as_silu_rule, fuse_axis_op, remove_rms_norm_cast, + rewire_metal_sync, rewire_metal_sync_after_const, BasicApplyRope, BasicNewGelu, BasicRmsNorm, + BasicRotateHalf, BasicScaledMaskedSoftmax, BasicSilu, }; use crate::tensor::MetalTensorExt; use crate::{IntoMetal, MetalFact, MetalTensor}; @@ -64,6 +67,7 @@ impl ModelTransform for MetalTransform { .with_rule_for::("as-new-gelu", as_new_gelu_rule) .with_rule_for::("as-rotate-half", as_rotate_half_rule) .with_rule_for::("as-apply-rope", as_apply_rope_rule) + .with_rule_for::("as-scaled-masked-softmax", 
as_scaled_masked_softmax_rule) .rewrite(&(), model)?; let mut new = self.translate_model(model)?; @@ -209,6 +213,10 @@ impl Translate, TypedFact, Box> for Met .then(|| ops::MetalSoftmax::from_tract_core(op).ok()) .flatten() .map(|o| -> Box { Box::new(o) }) + } else if let Some(op) = node.op_as::() { + check_in_dts_are_supported(source, node.id, ScaledMaskedSoftmax::is_supported_dt)? + .then(|| ops::MetalScaledMaskedSoftmax { scale: op.scale.clone() }) + .map(|o| -> Box { Box::new(o) }) } else if let Some(op) = node.op_as::() { check_in_dts_are_supported(source, node.id, RmsNorm::is_supported_dt)? .then(|| ops::MetalRmsNorm::new(op.axis, op.eps.clone())) From c22bb429d5fb46a62857dcc4dc9bd4885146e471 Mon Sep 17 00:00:00 2001 From: Hubert de La Jonquiere Date: Wed, 11 Dec 2024 11:22:16 +0100 Subject: [PATCH 4/6] More test --- metal/src/kernels/nn/nn_ops.metal | 10 +- metal/src/kernels/nn/scaled_masked_softmax.rs | 302 ++++++++---------- metal/src/kernels/nn/softmax.rs | 1 - metal/src/ops/scaled_masked_softmax.rs | 2 +- .../rewrite_rules/scaled_masked_softmax.rs | 5 +- 5 files changed, 147 insertions(+), 173 deletions(-) diff --git a/metal/src/kernels/nn/nn_ops.metal b/metal/src/kernels/nn/nn_ops.metal index c38be59a40..8760c4dc3a 100644 --- a/metal/src/kernels/nn/nn_ops.metal +++ b/metal/src/kernels/nn/nn_ops.metal @@ -295,7 +295,7 @@ template F scale = ((constant F *)scale_b)[0]; device F *output = (device F *)output_b; - size_t dim = shape[2]; + size_t reduce_dim = shape[2]; size_t base_idx = tgpig.y * strides[1] + tgpig.z * strides[0]; @@ -303,9 +303,9 @@ template size_t mask_base_idx = tgpig.y * mask_strides[1] + tgpig.z * mask_strides[0]; - // Get max value on softmax dim after apply + // Get max value on softmax reduce_dim after apply float partial_max = -INFINITY; - for (size_t i = tiisg; i < dim; i += tpsg) { + for (size_t i = tiisg; i < reduce_dim; i += tpsg) { auto idx = base_idx + i * strides[2]; auto mask_idx = mask_base_idx + i * mask_strides[2]; output[idx] = input[idx] * scale + mask[mask_idx]; @@ -317,7 +317,7 @@ template // Compute Sum(exp(x - max)) float partial_norm = 0; - for (size_t i = tiisg; i < dim; i += tpsg) { + for (size_t i = tiisg; i < reduce_dim; i += tpsg) { auto idx = base_idx + i * strides[2]; float el = static_cast(output[idx]); float exp_el = fast::exp(el - axis_max); @@ -327,7 +327,7 @@ template float axis_norm = simd_sum(partial_norm); float inv_axis_norm = 1.0 / axis_norm; - for (size_t i = tiisg; i < dim; i += tpsg) { + for (size_t i = tiisg; i < reduce_dim; i += tpsg) { auto idx = base_idx + i * strides[2]; float el = static_cast(output[idx]); float exp_el = fast::exp(el - axis_max); diff --git a/metal/src/kernels/nn/scaled_masked_softmax.rs b/metal/src/kernels/nn/scaled_masked_softmax.rs index a22059ba3e..b38e6ec0e4 100644 --- a/metal/src/kernels/nn/scaled_masked_softmax.rs +++ b/metal/src/kernels/nn/scaled_masked_softmax.rs @@ -30,7 +30,6 @@ impl ScaledMaskedSoftmax { mask: &MetalTensor, ) -> Result { let output = unsafe { MetalTensor::uninitialized_dt(input.datum_type(), input.shape())? 
}; - dbg!(&output); self.dispatch_eval(context, input, scale, mask, &output)?; context.wait_until_completed()?; Ok(output) @@ -56,7 +55,7 @@ impl ScaledMaskedSoftmax { let strides = input.strides(); let mask_strides_nd3 = crate::utils::compute_broadcast_strides::(mask.shape(), mask.strides())?; - + let pipeline = context .shared_context() .load_pipeline(LibraryName::NNOps, &self.kernel_name(input.datum_type())?)?; @@ -68,13 +67,11 @@ impl ScaledMaskedSoftmax { encoder.set_metal_tensor(1, mask, metal::MTLResourceUsage::Read); encoder.set_tensor(2, scale); encoder.set_metal_tensor(3, output, metal::MTLResourceUsage::Write); - encoder.set_slice(4, &shape); - encoder.set_slice(5, &strides); + encoder.set_slice(4, shape); + encoder.set_slice(5, strides); encoder.set_slice(6, &mask_strides_nd3); - - let grid_size = MTLSize { width: 1 as _, height: shape[1] as _, depth: shape[0] as _}; + let grid_size = MTLSize { width: 1 as _, height: shape[1] as _, depth: shape[0] as _ }; let group_size = MTLSize { width: usize::min(32, shape[2]) as _, height: 1, depth: 1 }; - encoder.dispatch_thread_groups(grid_size, group_size); encoder.end_encoding(); }); @@ -84,14 +81,15 @@ impl ScaledMaskedSoftmax { #[cfg(test)] mod tests { + use super::*; use crate::rewrite_rules::BasicScaledMaskedSoftmax; -use super::*; use crate::IntoMetal; use derive_new::new; - - - - + use num_traits::AsPrimitive; + use num_traits::Float; + use proptest::collection::vec; + use proptest::prelude::*; + use proptest::strategy::Strategy; use tract_core::internal::Tensor; #[test] @@ -101,171 +99,149 @@ use super::*; let m = 4; let n = 4; let scale: Arc<_> = tensor0(0.125f32).into(); - let mask = Tensor::from_shape(&[1, m, n], &vec![-1000f32; m*n])?.into_metal()?; + let mask = Tensor::from_shape(&[1, m, n], &vec![-1000f32; m * n])?.into_metal()?; - let a = - Tensor::from_shape(&[1, m, n], &(0..m * n).map(|f| f as f32).collect::>())? - .into_metal()?; + let a = Tensor::from_shape( + &[1, m, n], + &(0..m * n).map(|f| f as f32).collect::>(), + )? + .into_metal()?; - let cpu = BasicScaledMaskedSoftmax { scale:scale.clone() }; + let cpu = BasicScaledMaskedSoftmax { scale: scale.clone() }; - let cpu_output = - cpu.eval(tvec![a.to_cpu()?.into_tvalue(), mask.to_cpu()?.into_tvalue()])?[0].clone().into_tensor(); + let cpu_output = cpu + .eval(tvec![a.to_cpu()?.into_tvalue(), mask.to_cpu()?.into_tvalue()])?[0] + .clone() + .into_tensor(); let metal_output = ScaledMaskedSoftmax.eval(context, &a, &scale, &mask)?; cpu_output.close_enough(&metal_output.to_cpu()?, Approximation::Approximate)?; Ok(()) }) }) } -} - -// #[test] -// fn test_softmax_f32_2() -> Result<()> { -// objc::rc::autoreleasepool(|| { -// crate::METAL_CONTEXT.with_borrow(|context| { -// let shape = [8, 4, 3]; -// let num_elements = shape.iter().product(); -// let axis = 0; - -// let a = Tensor::from_shape( -// &shape, -// &(0..num_elements).map(|f| f as f32 / 1000.0).collect::>(), -// )? 
-// .into_metal()?; - -// let cpu_softmax = TractSoftmax { -// axes: tvec![axis], -// quant_output_dt: None, -// exp: SoftmaxExp::Libc, -// }; - -// let cpu_output = -// cpu_softmax.eval(tvec![a.to_cpu()?.into_tvalue()])?[0].clone().into_tensor(); -// let metal_output = Softmax.eval(context, &a, axis)?; -// cpu_output.close_enough(&metal_output.to_cpu()?, Approximation::Approximate)?; -// Ok(()) -// }) -// }) -// } -// #[test] -// fn test_softmax_f16() -> Result<()> { -// objc::rc::autoreleasepool(|| { -// crate::METAL_CONTEXT.with_borrow(|context| { -// let m = 4; -// let k = 4; -// let axis = 1; - -// let a = Tensor::from_shape( -// &[m, k], -// &(0..m * k).map(|f| -> f16 { f.as_() }).collect::>(), -// )? -// .into_metal()?; - -// let cpu_softmax = TractSoftmax { -// axes: tvec![axis], -// quant_output_dt: None, -// exp: SoftmaxExp::Libc, -// }; - -// let cpu_output = -// cpu_softmax.eval(tvec![a.to_cpu()?.into_tvalue()])?[0].clone().into_tensor(); -// let metal_output = Softmax.eval(context, &a, axis)?; -// cpu_output.close_enough(&metal_output.to_cpu()?, Approximation::Approximate)?; -// Ok(()) -// }) -// }) -// } - -// proptest::proptest! { -// #[test] -// fn softmax_prop_f32(pb in any::>()) { -// fn run(pb: SoftmaxProblem) -> TractResult<()> { -// let out = pb.run()?; -// let reference = pb.reference()?; - -// out.close_enough(&reference, Approximation::Approximate) -// .with_context(|| anyhow!("Cpu: {:?}, Metal: {:?}", reference.dump(true), out.dump(true))) -// } -// run(pb).map_err(|e| TestCaseError::Fail(format!("{:?}", e).into()))?; -// } - -// #[test] -// fn softmax_prop_f16(pb in any::>()) { -// fn run(pb: SoftmaxProblem) -> TractResult<()> { -// let out = pb.run()?; -// let reference = pb.reference()?; - -// out.close_enough(&reference, Approximation::Approximate) -// .with_context(|| anyhow!("Cpu: {:?}, Metal: {:?}", reference.dump(true), out.dump(true))) -// } + #[test] + fn test_scaled_masked_softmax_f32_2() -> Result<()> { + objc::rc::autoreleasepool(|| { + crate::METAL_CONTEXT.with_borrow(|context| { + let m = 4; + let n = 1024; + let scale: Arc<_> = tensor0(0.125f32).into(); + let mask = Tensor::from_shape(&[1, m, n], &vec![-1000f32; m * n])?.into_metal()?; -// run(pb).map_err(|e| TestCaseError::Fail(format!("{:?}", e).into()))?; -// } -// } + let a = Tensor::from_shape( + &[1, m, n], + &(0..m * n).map(|f| f as f32).collect::>(), + )? 
+ .into_metal()?; -// #[derive(Debug, new)] -// pub struct SoftmaxProblem -// where -// F: Datum + Float, -// usize: AsPrimitive, -// { -// pub shape: Vec, -// pub axis: usize, -// pub input: Vec, -// } + let cpu = BasicScaledMaskedSoftmax { scale: scale.clone() }; -// impl Arbitrary for SoftmaxProblem -// where -// F: Datum + Float, -// usize: AsPrimitive, -// { -// type Parameters = (); -// type Strategy = BoxedStrategy; + let cpu_output = cpu + .eval(tvec![a.to_cpu()?.into_tvalue(), mask.to_cpu()?.into_tvalue()])?[0] + .clone() + .into_tensor(); + let metal_output = ScaledMaskedSoftmax.eval(context, &a, &scale, &mask)?; + cpu_output.close_enough(&metal_output.to_cpu()?, Approximation::Approximate)?; + Ok(()) + }) + }) + } -// fn arbitrary_with(_: ()) -> Self::Strategy { -// (0usize..3, 0usize..3) -// .prop_flat_map(|(left, right)| { -// let axis = left; -// let shape_len = usize::min(left + right + 1, 4); -// let shape = 1usize..10; -// (vec(shape, shape_len..=shape_len), Just(axis)) -// }) -// .prop_map(|(shape, axis)| { -// let input = (0..shape.iter().product::()) -// .map(|f| f.as_() / 1000.as_()) -// .collect::>(); -// Self { shape, axis, input } -// }) -// .boxed() -// } -// } + proptest::proptest! { + #[test] + fn scaled_masked_softmax_prop_f32(pb in any::>()) { + fn run(pb: ScaledMaskedSoftmaxProblem) -> TractResult<()> { + let out = pb.run()?; + let reference = pb.reference()?; + + out.close_enough(&reference, Approximation::Approximate) + .with_context(|| anyhow!("Cpu: {:?}, Metal: {:?}", reference.dump(true), out.dump(true))) + } + run(pb).map_err(|e| TestCaseError::Fail(format!("{:?}", e).into()))?; + } + + #[test] + fn scaled_masked_softmax_prop_f16(pb in any::>()) { + fn run(pb: ScaledMaskedSoftmaxProblem) -> TractResult<()> { + let out = pb.run()?; + let reference = pb.reference()?; + + out.close_enough(&reference, Approximation::Approximate) + .with_context(|| anyhow!("Cpu: {:?}, Metal: {:?}", reference.dump(true), out.dump(true))) + } + + run(pb).map_err(|e| TestCaseError::Fail(format!("{:?}", e).into()))?; + } + } -// impl SoftmaxProblem -// where -// F: Datum + Float + std::ops::AddAssign, -// usize: AsPrimitive, -// { -// pub fn reference(&self) -> Result { -// let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?; + #[derive(Debug, new)] + pub struct ScaledMaskedSoftmaxProblem + where + F: Datum + Float, + usize: AsPrimitive, + { + pub shape: Vec, + pub mask_shape: Vec, + pub input: Vec, + pub mask: Vec, + } -// let cpu_softmax = TractSoftmax { -// axes: tvec![self.axis], -// quant_output_dt: None, -// exp: SoftmaxExp::Libc, -// }; -// let cpu_output = cpu_softmax.eval(tvec![a.into_tvalue()])?[0].clone().into_tensor(); -// Ok(cpu_output) -// } + impl Arbitrary for ScaledMaskedSoftmaxProblem + where + F: Datum + Float, + usize: AsPrimitive, + { + type Parameters = (); + type Strategy = BoxedStrategy; + + fn arbitrary_with(_: ()) -> Self::Strategy { + vec(1usize..10, 3..=3) + .prop_map(|shape| { + let mut mask_shape = shape.clone(); + mask_shape[0] = 1; + + let input = (0..shape.iter().product::()) + .map(|f| f.as_() / 1000.as_()) + .collect::>(); + + let mask = (0..mask_shape.iter().product::()) + .map(|f| f.as_() / 1000.as_()) + .collect::>(); + Self { shape, input, mask_shape, mask } + }) + .boxed() + } + } -// pub fn run(&self) -> Result { -// objc::rc::autoreleasepool(|| { -// crate::METAL_CONTEXT.with_borrow(|context| { -// let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?.into_metal()?; -// let metal_output = Softmax.eval(context, &a, 
self.axis)?; -// metal_output.to_cpu() -// }) -// }) -// } -// } -// } + impl ScaledMaskedSoftmaxProblem + where + F: Datum + Float + std::ops::AddAssign, + usize: AsPrimitive, + { + pub fn reference(&self) -> Result { + let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?; + let mask = Tensor::from_shape(self.mask_shape.as_slice(), &self.mask)?; + let scale: Arc<_> = tensor0(0.125f32).into(); + + let cpu_output = BasicScaledMaskedSoftmax { scale } + .eval(tvec![a.into_tvalue(), mask.into_tvalue()])?[0] + .clone() + .into_tensor(); + Ok(cpu_output) + } + + pub fn run(&self) -> Result { + objc::rc::autoreleasepool(|| { + crate::METAL_CONTEXT.with_borrow(|context| { + let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?.into_metal()?; + let mask = + Tensor::from_shape(self.mask_shape.as_slice(), &self.mask)?.into_metal()?; + let scale: Arc<_> = tensor0(0.125f32).into(); + let metal_output = ScaledMaskedSoftmax.eval(context, &a, &scale, &mask)?; + metal_output.to_cpu() + }) + }) + } + } +} diff --git a/metal/src/kernels/nn/softmax.rs b/metal/src/kernels/nn/softmax.rs index 6024c1d911..34a19bd3fb 100644 --- a/metal/src/kernels/nn/softmax.rs +++ b/metal/src/kernels/nn/softmax.rs @@ -63,7 +63,6 @@ impl Softmax { MTLSize { width: shape_nd3[2] as _, height: 1, depth: shape_nd3[0] as _ }; let group_size = MTLSize { width: usize::min(32, shape_nd3[1]) as _, height: 1, depth: 1 }; - encoder.dispatch_thread_groups(grid_size, group_size); encoder.end_encoding(); }); diff --git a/metal/src/ops/scaled_masked_softmax.rs b/metal/src/ops/scaled_masked_softmax.rs index 4ac66349f5..28656059c0 100644 --- a/metal/src/ops/scaled_masked_softmax.rs +++ b/metal/src/ops/scaled_masked_softmax.rs @@ -6,7 +6,7 @@ use derive_new::new; use tract_core::internal::*; /// A = SOFTMAX(INPUT * SCALE + MASK, AXIS=1) -/// Only input of rank of 3 is supported and softmax axis = 1 +/// Only input of rank of 3 is supported and softmax axis = 2 #[derive(Clone, Debug, new, Hash)] pub struct MetalScaledMaskedSoftmax { pub scale: Arc, diff --git a/metal/src/rewrite_rules/scaled_masked_softmax.rs b/metal/src/rewrite_rules/scaled_masked_softmax.rs index 45244f206b..8e1ca4fd1f 100644 --- a/metal/src/rewrite_rules/scaled_masked_softmax.rs +++ b/metal/src/rewrite_rules/scaled_masked_softmax.rs @@ -37,8 +37,7 @@ impl EvalOp for BasicScaledMaskedSoftmax { let masked_input = Add.eval(scaled_input.into(), mask, dt)?; let softmax = Softmax::new(tvec![2], None, SoftmaxExp::Libc) .eval(tvec![masked_input.into()])?[0] - .clone() - .into(); + .clone(); Ok(tvec![softmax]) } } @@ -65,7 +64,7 @@ pub fn as_scaled_masked_softmax_rule( node_name: &str, op: &Softmax, ) -> TractResult> { - rule_ensure!(op.axes.as_slice() == &[2]); + rule_ensure!(op.axes.as_slice() == [2]); let in_fact = model.node_input_facts(node.id)?[0]; let dt = in_fact.datum_type; From 1c859391ce461cbd3561a969afed8cfcd57c4432 Mon Sep 17 00:00:00 2001 From: Hubert de La Jonquiere Date: Wed, 11 Dec 2024 11:23:37 +0100 Subject: [PATCH 5/6] Fix typo in comment --- metal/src/kernels/nn/nn_ops.metal | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metal/src/kernels/nn/nn_ops.metal b/metal/src/kernels/nn/nn_ops.metal index 8760c4dc3a..2f82533490 100644 --- a/metal/src/kernels/nn/nn_ops.metal +++ b/metal/src/kernels/nn/nn_ops.metal @@ -303,7 +303,7 @@ template size_t mask_base_idx = tgpig.y * mask_strides[1] + tgpig.z * mask_strides[0]; - // Get max value on softmax reduce_dim after apply + // Get max value on softmax reduce_dim after applying scale and 
mask float partial_max = -INFINITY; for (size_t i = tiisg; i < reduce_dim; i += tpsg) { auto idx = base_idx + i * strides[2]; From 48e3947a4559114e2d0f9b01b2d063d853730663 Mon Sep 17 00:00:00 2001 From: Hubert de La Jonquiere Date: Wed, 11 Dec 2024 11:25:30 +0100 Subject: [PATCH 6/6] Improve comments --- metal/src/ops/scaled_masked_softmax.rs | 4 ++-- metal/src/rewrite_rules/scaled_masked_softmax.rs | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/metal/src/ops/scaled_masked_softmax.rs b/metal/src/ops/scaled_masked_softmax.rs index 28656059c0..5715385ef2 100644 --- a/metal/src/ops/scaled_masked_softmax.rs +++ b/metal/src/ops/scaled_masked_softmax.rs @@ -5,8 +5,8 @@ use crate::MetalContext; use derive_new::new; use tract_core::internal::*; -/// A = SOFTMAX(INPUT * SCALE + MASK, AXIS=1) -/// Only input of rank of 3 is supported and softmax axis = 2 +/// A = SOFTMAX(INPUT * SCALE + MASK, AXIS=2) +/// Only input of rank of 3 is supported #[derive(Clone, Debug, new, Hash)] pub struct MetalScaledMaskedSoftmax { pub scale: Arc, diff --git a/metal/src/rewrite_rules/scaled_masked_softmax.rs b/metal/src/rewrite_rules/scaled_masked_softmax.rs index 8e1ca4fd1f..afdd3f8111 100644 --- a/metal/src/rewrite_rules/scaled_masked_softmax.rs +++ b/metal/src/rewrite_rules/scaled_masked_softmax.rs @@ -8,8 +8,8 @@ use tract_core::ops::binary::BinMiniOp; use tract_core::ops::math::{Add, Mul}; use tract_core::ops::nn::{Softmax, SoftmaxExp}; -/// A = SOFTMAX(INPUT * SCALE + MASK, AXIS=1) -/// Only input of rank of 3 is supported with softmax axis = 2 +/// A = SOFTMAX(INPUT * SCALE + MASK, AXIS=2) +/// Only input of rank of 3 is supported. #[derive(Clone, Debug, Hash)] pub struct BasicScaledMaskedSoftmax { pub scale: Arc, @@ -56,7 +56,7 @@ impl TypedOp for BasicScaledMaskedSoftmax { as_op!(); } -/// Search pattern => A = SOFTMAX(A * SCALE + MASK) +/// Search pattern => A = SOFTMAX(A * SCALE + MASK, AXIS=2) pub fn as_scaled_masked_softmax_rule( _ctx: &(), model: &TypedModel,
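
For reference, the fused operation implemented by the `scaled_masked_softmax_nd3` kernel and the `BasicScaledMaskedSoftmax` rewrite is `out = softmax(input * scale + mask)` over the last axis of a rank-3 input, with the mask broadcast over the leading dimension and the usual max-subtraction for numerical stability. Below is a minimal CPU sketch of that computation in Rust, assuming contiguous row-major `[b, m, n]` storage and a `[1, m, n]` mask; the function name and signature are illustrative only and are not part of the patch.

fn scaled_masked_softmax_ref(
    input: &[f32],                  // [b, m, n], contiguous row-major
    mask: &[f32],                   // [1, m, n], broadcast over b
    scale: f32,
    (b, m, n): (usize, usize, usize),
) -> Vec<f32> {
    let mut out = vec![0f32; b * m * n];
    for bi in 0..b {
        for mi in 0..m {
            let row = (bi * m + mi) * n;
            let mask_row = mi * n;
            // Scale, add the mask, and track the row maximum for numerical stability.
            let mut max = f32::NEG_INFINITY;
            for i in 0..n {
                let v = input[row + i] * scale + mask[mask_row + i];
                out[row + i] = v;
                max = max.max(v);
            }
            // exp(x - max) and the row sum.
            let mut sum = 0f32;
            for i in 0..n {
                let e = (out[row + i] - max).exp();
                out[row + i] = e;
                sum += e;
            }
            // Normalize the row.
            for i in 0..n {
                out[row + i] /= sum;
            }
        }
    }
    out
}

This mirrors the three per-row passes the Metal kernel performs (max, exp-and-sum, normalize), with the simdgroup reductions (`simd_max`, `simd_sum`) replaced by plain loops.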