[Metal] Scaled masked softmax #1602

Merged: 6 commits, Dec 11, 2024
8 changes: 8 additions & 0 deletions metal/src/kernels/nn/mod.rs
@@ -2,13 +2,15 @@ pub mod apply_rope;
pub mod new_gelu;
pub mod reduce;
pub mod rms_norm;
pub mod scaled_masked_softmax;
pub mod silu;
pub mod softmax;

pub use apply_rope::ApplyRope;
pub use new_gelu::NewGelu;
pub use reduce::Reducer;
pub use rms_norm::RmsNorm;
pub use scaled_masked_softmax::ScaledMaskedSoftmax;
pub use silu::Silu;
pub use softmax::Softmax;

@@ -28,5 +30,11 @@ pub fn all_functions() -> Vec<String> {
            .flat_map(|dt| Softmax.kernel_name(dt).into_iter()),
    );

    functions.extend(
        crate::MetalTensor::SUPPORTED_DT
            .into_iter()
            .flat_map(|dt| ScaledMaskedSoftmax.kernel_name(dt).into_iter()),
    );

    functions.into_iter().collect()
}
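
As a cross-check (commentary, not part of the diff): the names this registration contributes are the ones `kernel_name()` produces, and they have to match the `[[host_name(...)]]` attributes declared in `nn_ops.metal` below. A small hedged sketch, assuming the re-export path `crate::kernels::nn::ScaledMaskedSoftmax` and tract's usual `TractResult` alias:

```rust
use tract_core::internal::*; // DatumType, TractResult
use crate::kernels::nn::ScaledMaskedSoftmax;

// Sanity check: Rust-side kernel names line up with the Metal [[host_name]]s.
fn check_kernel_names() -> TractResult<()> {
    assert_eq!(
        ScaledMaskedSoftmax.kernel_name(DatumType::F32)?,
        "nn_ops::scaled_masked_softmax_nd3_f32"
    );
    assert_eq!(
        ScaledMaskedSoftmax.kernel_name(DatumType::F16)?,
        "nn_ops::scaled_masked_softmax_nd3_f16"
    );
    Ok(())
}
```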
64 changes: 64 additions & 0 deletions metal/src/kernels/nn/nn_ops.metal
@@ -276,6 +276,70 @@ typedef decltype(softmax_nd3<float>) softmax_nd3_t;
template [[host_name("nn_ops::softmax_nd3_f32")]] [[kernel]] softmax_nd3_t softmax_nd3<float>;
template [[host_name("nn_ops::softmax_nd3_f16")]] [[kernel]] softmax_nd3_t softmax_nd3<half>;

template<typename F>
[[kernel]] void scaled_masked_softmax_nd3(
        device const void *input_b,
        device const void *mask_b,
        constant void *scale_b,
        device void *output_b,
        constant const size_t shape[3],
        constant const size_t strides[3],
        constant const size_t mask_strides[3],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint tpsg[[threads_per_simdgroup]]
) {

    device const F *input = (device const F *)input_b;
    device const F *mask = (device const F *)mask_b;
    F scale = ((constant F *)scale_b)[0];
    device F *output = (device F *)output_b;

    size_t reduce_dim = shape[2];

    size_t base_idx = tgpig.y * strides[1]
        + tgpig.z * strides[0];

    size_t mask_base_idx = tgpig.y * mask_strides[1]
        + tgpig.z * mask_strides[0];

    // Get max value on softmax reduce_dim after applying scale and mask
    float partial_max = -INFINITY;
    for (size_t i = tiisg; i < reduce_dim; i += tpsg) {
        auto idx = base_idx + i * strides[2];
        auto mask_idx = mask_base_idx + i * mask_strides[2];
        output[idx] = input[idx] * scale + mask[mask_idx];
        float el = static_cast<float>(output[idx]);
        partial_max = max(partial_max, el);
    }

    float axis_max = simd_max(partial_max);

    // Compute Sum(exp(x - max))
    float partial_norm = 0;
    for (size_t i = tiisg; i < reduce_dim; i += tpsg) {
        auto idx = base_idx + i * strides[2];
        float el = static_cast<float>(output[idx]);
        float exp_el = fast::exp(el - axis_max);
        partial_norm += exp_el;
    }

    float axis_norm = simd_sum(partial_norm);
    float inv_axis_norm = 1.0 / axis_norm;

    for (size_t i = tiisg; i < reduce_dim; i += tpsg) {
        auto idx = base_idx + i * strides[2];
        float el = static_cast<float>(output[idx]);
        float exp_el = fast::exp(el - axis_max);
        output[idx] = static_cast<F>(exp_el * inv_axis_norm);
    }
}

typedef decltype(scaled_masked_softmax_nd3<float>) scaled_masked_softmax_nd3_t;

template [[host_name("nn_ops::scaled_masked_softmax_nd3_f32")]] [[kernel]] scaled_masked_softmax_nd3_t scaled_masked_softmax_nd3<float>;
template [[host_name("nn_ops::scaled_masked_softmax_nd3_f16")]] [[kernel]] scaled_masked_softmax_nd3_t scaled_masked_softmax_nd3<half>;

constant float GELU_COEF_A = 0.044715f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

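For orientation (commentary, not part of the diff): the kernel above computes softmax(input * scale + mask) along the last axis of a rank-3 tensor, one simdgroup per (batch, row) pair, with the mask broadcast through its strides. Below is a minimal plain-Rust reference of the same computation; the function name and the flat row-major layout are illustrative assumptions, not tract API.

```rust
// Reference sketch of what scaled_masked_softmax_nd3 computes; `shape` is
// [d0, d1, d2] and the mask has shape [1, d1, d2] (broadcast over axis 0).
fn scaled_masked_softmax_ref(input: &[f32], mask: &[f32], scale: f32, shape: [usize; 3]) -> Vec<f32> {
    let [d0, d1, d2] = shape;
    let mut out = vec![0f32; d0 * d1 * d2];
    for b in 0..d0 {
        for r in 0..d1 {
            let row = (b * d1 + r) * d2;
            let mask_row = r * d2;
            // Pass 1: apply scale and mask, track the row max for numerical stability.
            let mut max = f32::NEG_INFINITY;
            for i in 0..d2 {
                out[row + i] = input[row + i] * scale + mask[mask_row + i];
                max = max.max(out[row + i]);
            }
            // Pass 2: normalizing constant sum(exp(x - max)).
            let norm: f32 = (0..d2).map(|i| (out[row + i] - max).exp()).sum();
            // Pass 3: write exp(x - max) / norm back.
            for i in 0..d2 {
                out[row + i] = (out[row + i] - max).exp() / norm;
            }
        }
    }
    out
}
```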
247 changes: 247 additions & 0 deletions metal/src/kernels/nn/scaled_masked_softmax.rs
@@ -0,0 +1,247 @@
use crate::encoder::EncoderExt;
use crate::{LibraryName, MetalContext, MetalTensor};
use anyhow::Result;
use metal::MTLSize;
use tract_core::internal::*;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ScaledMaskedSoftmax;

impl ScaledMaskedSoftmax {
    pub fn is_supported_dt(dt: DatumType) -> bool {
        matches!(dt, DatumType::F32 | DatumType::F16)
    }

    pub fn kernel_name(&self, dt: DatumType) -> Result<String> {
        ensure!(
            Self::is_supported_dt(dt),
            "Unsupported dt {:?} for metal scaled masked softmax op",
            dt
        );
        let tname = MetalTensor::tname(dt)?;
        Ok(format!("nn_ops::scaled_masked_softmax_nd3_{tname}"))
    }

    pub fn eval(
        &self,
        context: &MetalContext,
        input: &MetalTensor,
        scale: &Tensor,
        mask: &MetalTensor,
    ) -> Result<MetalTensor> {
        let output = unsafe { MetalTensor::uninitialized_dt(input.datum_type(), input.shape())? };
        self.dispatch_eval(context, input, scale, mask, &output)?;
        context.wait_until_completed()?;
        Ok(output)
    }

    pub fn dispatch_eval(
        &self,
        context: &MetalContext,
        input: &MetalTensor,
        scale: &Tensor,
        mask: &MetalTensor,
        output: &MetalTensor,
    ) -> Result<()> {
        input.retained_until_completion();
        mask.retained_until_completion();
        output.retained_until_completion();

        ensure!(output.shape() == input.shape());
        ensure!(mask.rank() == 3 && input.rank() == 3);
        ensure!(output.datum_type() == input.datum_type());

        let shape = input.shape();
        let strides = input.strides();
        let mask_strides_nd3 =
            crate::utils::compute_broadcast_strides::<usize>(mask.shape(), mask.strides())?;

        let pipeline = context
            .shared_context()
            .load_pipeline(LibraryName::NNOps, &self.kernel_name(input.datum_type())?)?;

        let command_buffer = context.command_buffer();
        command_buffer.encode(|encoder| {
            encoder.set_compute_pipeline_state(&pipeline);
            encoder.set_metal_tensor(0, input, metal::MTLResourceUsage::Read);
            encoder.set_metal_tensor(1, mask, metal::MTLResourceUsage::Read);
            encoder.set_tensor(2, scale);
            encoder.set_metal_tensor(3, output, metal::MTLResourceUsage::Write);
            encoder.set_slice(4, shape);
            encoder.set_slice(5, strides);
            encoder.set_slice(6, &mask_strides_nd3);
            let grid_size = MTLSize { width: 1 as _, height: shape[1] as _, depth: shape[0] as _ };
            let group_size = MTLSize { width: usize::min(32, shape[2]) as _, height: 1, depth: 1 };
            encoder.dispatch_thread_groups(grid_size, group_size);
            encoder.end_encoding();
        });
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::rewrite_rules::BasicScaledMaskedSoftmax;
    use crate::IntoMetal;
    use derive_new::new;
    use num_traits::AsPrimitive;
    use num_traits::Float;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest::strategy::Strategy;
    use tract_core::internal::Tensor;

    #[test]
    fn test_scaled_masked_softmax_f32() -> Result<()> {
        objc::rc::autoreleasepool(|| {
            crate::METAL_CONTEXT.with_borrow(|context| {
                let m = 4;
                let n = 4;
                let scale: Arc<_> = tensor0(0.125f32).into();
                let mask = Tensor::from_shape(&[1, m, n], &vec![-1000f32; m * n])?.into_metal()?;

                let a = Tensor::from_shape(
                    &[1, m, n],
                    &(0..m * n).map(|f| f as f32).collect::<Vec<_>>(),
                )?
                .into_metal()?;

                let cpu = BasicScaledMaskedSoftmax { scale: scale.clone() };

                let cpu_output = cpu
                    .eval(tvec![a.to_cpu()?.into_tvalue(), mask.to_cpu()?.into_tvalue()])?[0]
                    .clone()
                    .into_tensor();
                let metal_output = ScaledMaskedSoftmax.eval(context, &a, &scale, &mask)?;
                cpu_output.close_enough(&metal_output.to_cpu()?, Approximation::Approximate)?;
                Ok(())
            })
        })
    }

    #[test]
    fn test_scaled_masked_softmax_f32_2() -> Result<()> {
        objc::rc::autoreleasepool(|| {
            crate::METAL_CONTEXT.with_borrow(|context| {
                let m = 4;
                let n = 1024;
                let scale: Arc<_> = tensor0(0.125f32).into();
                let mask = Tensor::from_shape(&[1, m, n], &vec![-1000f32; m * n])?.into_metal()?;

                let a = Tensor::from_shape(
                    &[1, m, n],
                    &(0..m * n).map(|f| f as f32).collect::<Vec<_>>(),
                )?
                .into_metal()?;

                let cpu = BasicScaledMaskedSoftmax { scale: scale.clone() };

                let cpu_output = cpu
                    .eval(tvec![a.to_cpu()?.into_tvalue(), mask.to_cpu()?.into_tvalue()])?[0]
                    .clone()
                    .into_tensor();
                let metal_output = ScaledMaskedSoftmax.eval(context, &a, &scale, &mask)?;
                cpu_output.close_enough(&metal_output.to_cpu()?, Approximation::Approximate)?;
                Ok(())
            })
        })
    }

    proptest::proptest! {
        #[test]
        fn scaled_masked_softmax_prop_f32(pb in any::<ScaledMaskedSoftmaxProblem<f32>>()) {
            fn run(pb: ScaledMaskedSoftmaxProblem<f32>) -> TractResult<()> {
                let out = pb.run()?;
                let reference = pb.reference()?;

                out.close_enough(&reference, Approximation::Approximate)
                    .with_context(|| anyhow!("Cpu: {:?}, Metal: {:?}", reference.dump(true), out.dump(true)))
            }
            run(pb).map_err(|e| TestCaseError::Fail(format!("{:?}", e).into()))?;
        }

        #[test]
        fn scaled_masked_softmax_prop_f16(pb in any::<ScaledMaskedSoftmaxProblem<f16>>()) {
            fn run(pb: ScaledMaskedSoftmaxProblem<f16>) -> TractResult<()> {
                let out = pb.run()?;
                let reference = pb.reference()?;

                out.close_enough(&reference, Approximation::Approximate)
                    .with_context(|| anyhow!("Cpu: {:?}, Metal: {:?}", reference.dump(true), out.dump(true)))
            }

            run(pb).map_err(|e| TestCaseError::Fail(format!("{:?}", e).into()))?;
        }
    }

    #[derive(Debug, new)]
    pub struct ScaledMaskedSoftmaxProblem<F: Datum + Float>
    where
        F: Datum + Float,
        usize: AsPrimitive<F>,
    {
        pub shape: Vec<usize>,
        pub mask_shape: Vec<usize>,
        pub input: Vec<F>,
        pub mask: Vec<F>,
    }

    impl<F> Arbitrary for ScaledMaskedSoftmaxProblem<F>
    where
        F: Datum + Float,
        usize: AsPrimitive<F>,
    {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_: ()) -> Self::Strategy {
            vec(1usize..10, 3..=3)
                .prop_map(|shape| {
                    let mut mask_shape = shape.clone();
                    mask_shape[0] = 1;

                    let input = (0..shape.iter().product::<usize>())
                        .map(|f| f.as_() / 1000.as_())
                        .collect::<Vec<_>>();

                    let mask = (0..mask_shape.iter().product::<usize>())
                        .map(|f| f.as_() / 1000.as_())
                        .collect::<Vec<_>>();
                    Self { shape, input, mask_shape, mask }
                })
                .boxed()
        }
    }

    impl<F> ScaledMaskedSoftmaxProblem<F>
    where
        F: Datum + Float + std::ops::AddAssign,
        usize: AsPrimitive<F>,
    {
        pub fn reference(&self) -> Result<Tensor> {
            let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?;
            let mask = Tensor::from_shape(self.mask_shape.as_slice(), &self.mask)?;
            let scale: Arc<_> = tensor0(0.125f32).into();

            let cpu_output = BasicScaledMaskedSoftmax { scale }
                .eval(tvec![a.into_tvalue(), mask.into_tvalue()])?[0]
                .clone()
                .into_tensor();
            Ok(cpu_output)
        }

        pub fn run(&self) -> Result<Tensor> {
            objc::rc::autoreleasepool(|| {
                crate::METAL_CONTEXT.with_borrow(|context| {
                    let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?.into_metal()?;
                    let mask =
                        Tensor::from_shape(self.mask_shape.as_slice(), &self.mask)?.into_metal()?;
                    let scale: Arc<_> = tensor0(0.125f32).into();
                    let metal_output = ScaledMaskedSoftmax.eval(context, &a, &scale, &mask)?;
                    metal_output.to_cpu()
                })
            })
        }
    }
}
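
Two host-side details worth noting (commentary, not part of the diff): `dispatch_eval` launches a 1 × shape[1] × shape[0] grid of threadgroups, so each (batch, row) pair gets one simdgroup of at most 32 threads that strides along the last axis, and the mask is made broadcastable purely through its strides. Here is a hedged sketch of that stride trick, assuming `crate::utils::compute_broadcast_strides` has the conventional zero-stride-on-size-1-axes behaviour:

```rust
// Hypothetical stand-in for crate::utils::compute_broadcast_strides: any axis
// of size 1 gets stride 0, so the kernel re-reads the same mask row for every
// batch instead of needing a materialized [b, m, n] mask.
fn broadcast_strides(shape: &[usize], strides: &[usize]) -> Vec<usize> {
    shape
        .iter()
        .zip(strides)
        .map(|(&dim, &stride)| if dim == 1 { 0 } else { stride })
        .collect()
}

// e.g. a [1, m, n] mask with strides [m * n, n, 1] is read with [0, n, 1]
// when broadcast against a [b, m, n] input.
```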
1 change: 0 additions & 1 deletion metal/src/kernels/nn/softmax.rs
@@ -63,7 +63,6 @@ impl Softmax {
                MTLSize { width: shape_nd3[2] as _, height: 1, depth: shape_nd3[0] as _ };
            let group_size =
                MTLSize { width: usize::min(32, shape_nd3[1]) as _, height: 1, depth: 1 };

            encoder.dispatch_thread_groups(grid_size, group_size);
            encoder.end_encoding();
        });
2 changes: 2 additions & 0 deletions metal/src/ops/mod.rs
@@ -11,6 +11,7 @@ pub mod new_gelu;
pub mod reduce;
pub mod rms_norm;
pub mod rotate_half;
pub mod scaled_masked_softmax;
pub mod silu;
pub mod slice;
pub mod softmax;
@@ -29,6 +30,7 @@ pub use new_gelu::MetalNewGelu;
pub use reduce::MetalReduce;
pub use rms_norm::MetalRmsNorm;
pub use rotate_half::MetalRotateHalf;
pub use scaled_masked_softmax::MetalScaledMaskedSoftmax;
pub use silu::MetalSilu;
pub use slice::MetalSlice;
pub use softmax::MetalSoftmax;