diff --git a/metal/src/kernels/nn/scaled_masked_softmax.rs b/metal/src/kernels/nn/scaled_masked_softmax.rs index b38e6ec0e4..67e2cb3b50 100644 --- a/metal/src/kernels/nn/scaled_masked_softmax.rs +++ b/metal/src/kernels/nn/scaled_masked_softmax.rs @@ -218,11 +218,12 @@ mod tests { where F: Datum + Float + std::ops::AddAssign, usize: AsPrimitive, + f32: AsPrimitive, { pub fn reference(&self) -> Result { let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?; let mask = Tensor::from_shape(self.mask_shape.as_slice(), &self.mask)?; - let scale: Arc<_> = tensor0(0.125f32).into(); + let scale: Arc<_> = tensor0::(0.125f32.as_()).into(); let cpu_output = BasicScaledMaskedSoftmax { scale } .eval(tvec![a.into_tvalue(), mask.into_tvalue()])?[0] @@ -237,7 +238,7 @@ mod tests { let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?.into_metal()?; let mask = Tensor::from_shape(self.mask_shape.as_slice(), &self.mask)?.into_metal()?; - let scale: Arc<_> = tensor0(0.125f32).into(); + let scale: Arc<_> = tensor0::(0.125f32.as_()).into(); let metal_output = ScaledMaskedSoftmax.eval(context, &a, &scale, &mask)?; metal_output.to_cpu() })