From 05d0919216e62ce3a5aa8229105d3849a93f9b44 Mon Sep 17 00:00:00 2001 From: Hubert de La Jonquiere Date: Fri, 13 Dec 2024 10:04:03 +0100 Subject: [PATCH] Fix scaled masked softmax --- metal/src/kernels/nn/scaled_masked_softmax.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/metal/src/kernels/nn/scaled_masked_softmax.rs b/metal/src/kernels/nn/scaled_masked_softmax.rs index b38e6ec0e4..67e2cb3b50 100644 --- a/metal/src/kernels/nn/scaled_masked_softmax.rs +++ b/metal/src/kernels/nn/scaled_masked_softmax.rs @@ -218,11 +218,12 @@ mod tests { where F: Datum + Float + std::ops::AddAssign, usize: AsPrimitive, + f32: AsPrimitive, { pub fn reference(&self) -> Result { let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?; let mask = Tensor::from_shape(self.mask_shape.as_slice(), &self.mask)?; - let scale: Arc<_> = tensor0(0.125f32).into(); + let scale: Arc<_> = tensor0::(0.125f32.as_()).into(); let cpu_output = BasicScaledMaskedSoftmax { scale } .eval(tvec![a.into_tvalue(), mask.into_tvalue()])?[0] @@ -237,7 +238,7 @@ mod tests { let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?.into_metal()?; let mask = Tensor::from_shape(self.mask_shape.as_slice(), &self.mask)?.into_metal()?; - let scale: Arc<_> = tensor0(0.125f32).into(); + let scale: Arc<_> = tensor0::(0.125f32.as_()).into(); let metal_output = ScaledMaskedSoftmax.eval(context, &a, &scale, &mask)?; metal_output.to_cpu() })