
feat: Wasm SIMD implementation of MatMatMulKer<f32> #1420

Merged · 12 commits · Jun 2, 2024
4 changes: 2 additions & 2 deletions Cargo.toml
@@ -72,7 +72,7 @@ cblas = "0.4"
cc = "1.0.69"
clap = { version = "~3.1", features = [ "cargo" ] }
colorous = "1.0.5"
criterion = "0.4"
criterion = { version = "0.4", default-features = false }
derive-new = "0.5.9"
dinghy-test = "0.6"
downcast-rs = "1.2.0"
@@ -103,7 +103,7 @@ num-integer = "0.1.44"
num-traits = "0.2.14"
openblas-src = { version = "0.10", features = ["static"] }
paste = "1.0.5"
proptest = "1.0.0"
proptest = { version = "1.0.0", default-features = false, features = ["std"] }
Collaborator
What feature do we lose here? Is that gonna be a problem?

Contributor Author
I pushed a new change to enable some of the default features that are compatible with Wasm.

After that change, for criterion we will lose only rayon:
https://github.com/bheisler/criterion.rs/blob/f1ea31a92ff919a455f36b13c9a45fd74559d0fe/Cargo.toml#L74C12-L74C54

For proptest we will lose fork and timeout:
https://github.com/proptest-rs/proptest/blob/a62a348b59f422161cbc5c6910f83f1b3c3e67e5/proptest/Cargo.toml#L22

These are multi-threading features that affect performance.

If that's unacceptable, I can revert all Cargo.toml changes and keep the current status quo of not running Wasm tests. I think that would be okay because we can require people to run those tests locally when making Wasm-related changes.

Collaborator
We really need to keep proptest fork around (that's super useful to test kernel under address sanitizing, as the sanitizer will abort the process at the first issue). Can we achieve keeping proptest whole on every non-wasm platform by some cfg() tricks in Cargo.toml ?

Contributor Author
I was also thinking in that direction initially. I couldn't find a way to have a platform-specific dependency in the workspace: rust-lang/cargo#5220

I could explore moving proptest into a crate-local dependency (instead of workspace), and then it should be possible to have platform-specific rules.

Collaborator
@kali kali May 29, 2024
Worth a try. If it solves the problem, I'm ok with proptest becoming a crate-level dep.

Contributor Author
This should be fixed in the latest commit.
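For illustration, a crate-local, target-gated dev-dependency along the lines discussed could look like the fragment below. This is an editor's sketch of the approach, not the exact diff from the commit; the section placement and feature lists are assumptions:

```toml
# Hypothetical linalg/Cargo.toml fragment: keep proptest whole (incl. fork
# and timeout) on every non-Wasm platform, trim it only for Wasm targets.
[target.'cfg(not(target_family = "wasm"))'.dev-dependencies]
proptest = "1.0.0"

[target.'cfg(target_family = "wasm")'.dev-dependencies]
proptest = { version = "1.0.0", default-features = false, features = ["std"] }
```

Cargo only supports `cfg()` expressions in `[target.…]` tables at the package level, which is why the dependency has to move out of the shared workspace table.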

prost = "0.11.0"
prost-types = "0.11.0"
py_literal = "0.4.0"
12 changes: 10 additions & 2 deletions linalg/src/lib.rs
@@ -31,6 +31,9 @@ pub mod arm64;
#[cfg(any(target_arch = "arm", target_arch = "armv7"))]
pub mod arm32;

#[cfg(all(target_family = "wasm", target_feature = "simd128"))]
pub mod wasm;

pub use self::frame::{element_wise, lut, mmm};

use crate::frame::mmm::kernel::MatMatMulKer;
@@ -74,8 +74,10 @@ pub struct Ops {
pub max_f16: Box<dyn Fn() -> Box<dyn reduce::Reduce<f16>> + Send + Sync>,
pub max_f32: Box<dyn Fn() -> Box<dyn reduce::Reduce<f32>> + Send + Sync>,

pub softmax2_fastcompact_f16: Box<dyn Fn() -> Box<dyn reduce::MapReduce<f16, f16>> + Send + Sync>,
pub softmax2_fastcompact_f32: Box<dyn Fn() -> Box<dyn reduce::MapReduce<f32, f32>> + Send + Sync>,
pub softmax2_fastcompact_f16:
Box<dyn Fn() -> Box<dyn reduce::MapReduce<f16, f16>> + Send + Sync>,
pub softmax2_fastcompact_f32:
Box<dyn Fn() -> Box<dyn reduce::MapReduce<f32, f32>> + Send + Sync>,
}

impl Ops {
@@ -154,6 +159,9 @@ pub fn best() -> Ops {
arm32::plug(&mut ops);
#[cfg(target_arch = "aarch64")]
arm64::plug(&mut ops);
#[cfg(all(target_family = "wasm", target_feature = "simd128"))]
wasm::plug(&mut ops);

ops
}

338 changes: 338 additions & 0 deletions linalg/src/wasm.rs
@@ -0,0 +1,338 @@
/// Wasm SIMD implementation of `MatMatMulKer<f32>`.
///
/// To run the tests, install `wasmtime` and export the following
/// environment variables:
/// ```sh
/// export RUSTFLAGS='-C target-feature=+simd128'
/// export CARGO_TARGET_WASM32_WASI_RUNNER=wasmtime
/// cargo test --target=wasm32-wasi
/// ```
use crate::{
mmm::{FusedKerSpec, MatMatMulKer},
Ops, Scaler,
};

#[derive(Copy, Clone, Debug)]
pub struct WasmMmm4x4();

unsafe impl Send for WasmMmm4x4 {}
unsafe impl Sync for WasmMmm4x4 {}

impl MatMatMulKer<f32> for WasmMmm4x4 {
#[inline(always)]
fn name() -> &'static str {
"wasm_f32_4x4"
}

#[inline(always)]
fn mr() -> usize {
4
}

#[inline(always)]
fn nr() -> usize {
4
}

fn end_padding_packed_a() -> usize {
0
}

fn end_padding_packed_b() -> usize {
0
}

#[inline(always)]
fn alignment_bytes_packed_a() -> usize {
std::mem::size_of::<f32>()
}
#[inline(always)]
fn alignment_bytes_packed_b() -> usize {
std::mem::size_of::<f32>()
}

#[inline(never)]
fn kernel(spec: &[FusedKerSpec<f32>]) -> isize {
unsafe { kernel_f32_4x4(spec) }
}
}

pub fn plug(ops: &mut Ops) {
let impls = vec![WasmMmm4x4::mmm()];
ops.mmm_f32_impls = impls.clone();
ops.mmm_f32 = Box::new(|_m, _k, _n| WasmMmm4x4::mmm());
}

unsafe fn kernel_f32_4x4(spec: &[FusedKerSpec<f32>]) -> isize {
use std::arch::wasm32::*;

// Each of these variables stores a row of the matrix,
// consisting of four packed `f32` numbers.
let mut ab0 = f32x4_splat(0.0);
let mut ab1 = f32x4_splat(0.0);
let mut ab2 = f32x4_splat(0.0);
let mut ab3 = f32x4_splat(0.0);

let mut pnl = spec.as_ptr();

while !pnl.is_null() {
match *pnl {
FusedKerSpec::Done => break,
FusedKerSpec::Clear => {
let a = f32x4_splat(0.0);
ab0 = a;
ab1 = a;
ab2 = a;
ab3 = a;
}
FusedKerSpec::ScalarMin(a) => {
let a = f32x4_splat(a);
ab0 = f32x4_min(a, ab0);
ab1 = f32x4_min(a, ab1);
ab2 = f32x4_min(a, ab2);
ab3 = f32x4_min(a, ab3);
}
FusedKerSpec::ScalarMax(a) => {
let a = f32x4_splat(a);
ab0 = f32x4_max(a, ab0);
ab1 = f32x4_max(a, ab1);
ab2 = f32x4_max(a, ab2);
ab3 = f32x4_max(a, ab3);
}
FusedKerSpec::ScalarAdd(a) => {
let a = f32x4_splat(a);
ab0 = f32x4_add(a, ab0);
ab1 = f32x4_add(a, ab1);
ab2 = f32x4_add(a, ab2);
ab3 = f32x4_add(a, ab3);
}
FusedKerSpec::ScalarMul(a) => {
let a = f32x4_splat(a);
ab0 = f32x4_mul(a, ab0);
ab1 = f32x4_mul(a, ab1);
ab2 = f32x4_mul(a, ab2);
ab3 = f32x4_mul(a, ab3);
}
FusedKerSpec::ScalarSub(a) => {
let a = f32x4_splat(a);
ab0 = f32x4_sub(a, ab0);
ab1 = f32x4_sub(a, ab1);
ab2 = f32x4_sub(a, ab2);
ab3 = f32x4_sub(a, ab3);
}
FusedKerSpec::ScalarSubF(a) => {
let a = f32x4_splat(a);
ab0 = f32x4_sub(ab0, a);
ab1 = f32x4_sub(ab1, a);
ab2 = f32x4_sub(ab2, a);
ab3 = f32x4_sub(ab3, a);
}
FusedKerSpec::LeakyRelu(a) => {
let a = f32x4_splat(a);
let zero = f32x4_splat(0.0);

let mask0 = f32x4_gt(ab0, zero);
ab0 = v128_bitselect(ab0, f32x4_mul(a, ab0), mask0);

let mask1 = f32x4_gt(ab1, zero);
ab1 = v128_bitselect(ab1, f32x4_mul(a, ab1), mask1);

let mask2 = f32x4_gt(ab2, zero);
ab2 = v128_bitselect(ab2, f32x4_mul(a, ab2), mask2);

let mask3 = f32x4_gt(ab3, zero);
ab3 = v128_bitselect(ab3, f32x4_mul(a, ab3), mask3);
}
FusedKerSpec::PerRowMin(row) => {
let row = std::slice::from_raw_parts(row, 4);
ab0 = f32x4_min(f32x4_splat(row[0]), ab0);
ab1 = f32x4_min(f32x4_splat(row[1]), ab1);
ab2 = f32x4_min(f32x4_splat(row[2]), ab2);
ab3 = f32x4_min(f32x4_splat(row[3]), ab3);
}
FusedKerSpec::PerRowMax(row) => {
let row = std::slice::from_raw_parts(row, 4);
ab0 = f32x4_max(f32x4_splat(row[0]), ab0);
ab1 = f32x4_max(f32x4_splat(row[1]), ab1);
ab2 = f32x4_max(f32x4_splat(row[2]), ab2);
ab3 = f32x4_max(f32x4_splat(row[3]), ab3);
}
FusedKerSpec::PerRowAdd(row) => {
let row = std::slice::from_raw_parts(row, 4);
ab0 = f32x4_add(f32x4_splat(row[0]), ab0);
ab1 = f32x4_add(f32x4_splat(row[1]), ab1);
ab2 = f32x4_add(f32x4_splat(row[2]), ab2);
ab3 = f32x4_add(f32x4_splat(row[3]), ab3);
}
FusedKerSpec::PerRowMul(row) => {
let row = std::slice::from_raw_parts(row, 4);
ab0 = f32x4_mul(f32x4_splat(row[0]), ab0);
ab1 = f32x4_mul(f32x4_splat(row[1]), ab1);
ab2 = f32x4_mul(f32x4_splat(row[2]), ab2);
ab3 = f32x4_mul(f32x4_splat(row[3]), ab3);
}
FusedKerSpec::PerRowSub(row) => {
let row = std::slice::from_raw_parts(row, 4);
ab0 = f32x4_sub(f32x4_splat(row[0]), ab0);
ab1 = f32x4_sub(f32x4_splat(row[1]), ab1);
ab2 = f32x4_sub(f32x4_splat(row[2]), ab2);
ab3 = f32x4_sub(f32x4_splat(row[3]), ab3);
}
FusedKerSpec::PerRowSubF(row) => {
let row = std::slice::from_raw_parts(row, 4);
ab0 = f32x4_sub(ab0, f32x4_splat(row[0]));
ab1 = f32x4_sub(ab1, f32x4_splat(row[1]));
ab2 = f32x4_sub(ab2, f32x4_splat(row[2]));
ab3 = f32x4_sub(ab3, f32x4_splat(row[3]));
}
FusedKerSpec::PerColMin(cols) => {
let cols = v128_load(cols as *const v128);
ab0 = f32x4_min(cols, ab0);
ab1 = f32x4_min(cols, ab1);
ab2 = f32x4_min(cols, ab2);
ab3 = f32x4_min(cols, ab3);
}
FusedKerSpec::PerColMax(cols) => {
let cols = v128_load(cols as *const v128);
ab0 = f32x4_max(cols, ab0);
ab1 = f32x4_max(cols, ab1);
ab2 = f32x4_max(cols, ab2);
ab3 = f32x4_max(cols, ab3);
}
FusedKerSpec::PerColAdd(cols) => {
let cols = v128_load(cols as *const v128);
ab0 = f32x4_add(cols, ab0);
ab1 = f32x4_add(cols, ab1);
ab2 = f32x4_add(cols, ab2);
ab3 = f32x4_add(cols, ab3);
}
FusedKerSpec::PerColMul(cols) => {
let cols = v128_load(cols as *const v128);
ab0 = f32x4_mul(cols, ab0);
ab1 = f32x4_mul(cols, ab1);
ab2 = f32x4_mul(cols, ab2);
ab3 = f32x4_mul(cols, ab3);
}
FusedKerSpec::PerColSub(cols) => {
let cols = v128_load(cols as *const v128);
ab0 = f32x4_sub(cols, ab0);
ab1 = f32x4_sub(cols, ab1);
ab2 = f32x4_sub(cols, ab2);
ab3 = f32x4_sub(cols, ab3);
}
FusedKerSpec::PerColSubF(cols) => {
let cols = v128_load(cols as *const v128);
ab0 = f32x4_sub(ab0, cols);
ab1 = f32x4_sub(ab1, cols);
ab2 = f32x4_sub(ab2, cols);
ab3 = f32x4_sub(ab3, cols);
}
FusedKerSpec::QScale(shift, rp, mult) => {
let scaler = Scaler::from_fuse_params(shift, rp, mult);
let scale = f32x4_splat(scaler.scale);
ab0 = f32x4_mul(scale, ab0);
ab1 = f32x4_mul(scale, ab1);
ab2 = f32x4_mul(scale, ab2);
ab3 = f32x4_mul(scale, ab3);
}
FusedKerSpec::RoundingShiftRight(shift, _rp) => {
let shift = f32x4_splat(2f32.powi(-(shift as i32)));
ab0 = f32x4_mul(shift, ab0);
ab1 = f32x4_mul(shift, ab1);
ab2 = f32x4_mul(shift, ab2);
ab3 = f32x4_mul(shift, ab3);
}
FusedKerSpec::ShiftLeft(shift) => {
let shift = f32x4_splat(2f32.powi(shift as i32));
ab0 = f32x4_mul(shift, ab0);
ab1 = f32x4_mul(shift, ab1);
ab2 = f32x4_mul(shift, ab2);
ab3 = f32x4_mul(shift, ab3);
}
FusedKerSpec::AddUnicast(tile) => {
let mut ptr: *const u8 = tile.ptr;

let m0 = *(ptr as *const f32);
let m1 = *(ptr.offset(tile.col_byte_stride) as *const f32);
let m2 = *(ptr.offset(tile.col_byte_stride * 2) as *const f32);
let m3 = *(ptr.offset(tile.col_byte_stride * 3) as *const f32);
ab0 = f32x4_add(ab0, f32x4(m0, m1, m2, m3));
ptr = ptr.add(tile.row_byte_stride as usize);

let m0 = *(ptr as *const f32);
let m1 = *(ptr.offset(tile.col_byte_stride) as *const f32);
let m2 = *(ptr.offset(tile.col_byte_stride * 2) as *const f32);
let m3 = *(ptr.offset(tile.col_byte_stride * 3) as *const f32);
ab1 = f32x4_add(ab1, f32x4(m0, m1, m2, m3));
ptr = ptr.add(tile.row_byte_stride as usize);

let m0 = *(ptr as *const f32);
let m1 = *(ptr.offset(tile.col_byte_stride) as *const f32);
let m2 = *(ptr.offset(tile.col_byte_stride * 2) as *const f32);
let m3 = *(ptr.offset(tile.col_byte_stride * 3) as *const f32);
ab2 = f32x4_add(ab2, f32x4(m0, m1, m2, m3));
ptr = ptr.add(tile.row_byte_stride as usize);

let m0 = *(ptr as *const f32);
let m1 = *(ptr.offset(tile.col_byte_stride) as *const f32);
let m2 = *(ptr.offset(tile.col_byte_stride * 2) as *const f32);
let m3 = *(ptr.offset(tile.col_byte_stride * 3) as *const f32);
ab3 = f32x4_add(ab3, f32x4(m0, m1, m2, m3));
}
FusedKerSpec::AddRowColProducts(rows, cols) => {
let cols = v128_load(cols as *const v128);
ab0 = f32x4_add(ab0, f32x4_mul(f32x4_splat(*rows.add(0)), cols));
ab1 = f32x4_add(ab1, f32x4_mul(f32x4_splat(*rows.add(1)), cols));
ab2 = f32x4_add(ab2, f32x4_mul(f32x4_splat(*rows.add(2)), cols));
ab3 = f32x4_add(ab3, f32x4_mul(f32x4_splat(*rows.add(3)), cols));
}
FusedKerSpec::Store(tile) => {
let mut ptr: *mut u8 = tile.ptr;

*(ptr as *mut f32) = f32x4_extract_lane::<0>(ab0);
*(ptr.offset(tile.col_byte_stride) as *mut f32) = f32x4_extract_lane::<1>(ab0);
*(ptr.offset(tile.col_byte_stride * 2) as *mut f32) = f32x4_extract_lane::<2>(ab0);
*(ptr.offset(tile.col_byte_stride * 3) as *mut f32) = f32x4_extract_lane::<3>(ab0);
ptr = ptr.add(tile.row_byte_stride as usize);

*(ptr as *mut f32) = f32x4_extract_lane::<0>(ab1);
*(ptr.offset(tile.col_byte_stride) as *mut f32) = f32x4_extract_lane::<1>(ab1);
*(ptr.offset(tile.col_byte_stride * 2) as *mut f32) = f32x4_extract_lane::<2>(ab1);
*(ptr.offset(tile.col_byte_stride * 3) as *mut f32) = f32x4_extract_lane::<3>(ab1);
ptr = ptr.add(tile.row_byte_stride as usize);

*(ptr as *mut f32) = f32x4_extract_lane::<0>(ab2);
*(ptr.offset(tile.col_byte_stride) as *mut f32) = f32x4_extract_lane::<1>(ab2);
*(ptr.offset(tile.col_byte_stride * 2) as *mut f32) = f32x4_extract_lane::<2>(ab2);
*(ptr.offset(tile.col_byte_stride * 3) as *mut f32) = f32x4_extract_lane::<3>(ab2);
ptr = ptr.add(tile.row_byte_stride as usize);

*(ptr as *mut f32) = f32x4_extract_lane::<0>(ab3);
*(ptr.offset(tile.col_byte_stride) as *mut f32) = f32x4_extract_lane::<1>(ab3);
*(ptr.offset(tile.col_byte_stride * 2) as *mut f32) = f32x4_extract_lane::<2>(ab3);
*(ptr.offset(tile.col_byte_stride * 3) as *mut f32) = f32x4_extract_lane::<3>(ab3);
}
FusedKerSpec::AddMatMul { k, pa, pb, cpu_variant: _ } => {
let a = pa as *const f32;
let b = pb as *const v128;
for i in 0..k {
let a = std::slice::from_raw_parts(a.offset(4 * i as isize), 4);
let b = v128_load(b.offset(i as isize));
ab0 = f32x4_add(ab0, f32x4_mul(f32x4_splat(a[0]), b));
ab1 = f32x4_add(ab1, f32x4_mul(f32x4_splat(a[1]), b));
ab2 = f32x4_add(ab2, f32x4_mul(f32x4_splat(a[2]), b));
ab3 = f32x4_add(ab3, f32x4_mul(f32x4_splat(a[3]), b));
}
}
}
pnl = pnl.add(1);
}
0
}

#[allow(non_camel_case_types)]
pub type wasm_f32_4x4 = WasmMmm4x4;
test_mmm_kernel_f32!(wasm_f32_4x4, true);

#[cfg(test)]
mod tests;
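The heart of the kernel above is the `AddMatMul` arm: each of the `k` iterations is a rank-1 update of the 4x4 accumulator tile, broadcasting one packed-A row value across a packed-B column panel. A portable scalar model of that update (an editor's sketch for exposition, not code from the PR; `matmul_4x4` is a hypothetical helper name):

```rust
// Scalar model of the kernel's AddMatMul step. A is packed as k panels of
// 4 row values, B as k panels of 4 column values; ab[r][c] accumulates
// sum over i of a[4*i + r] * b[4*i + c], mirroring
// `f32x4_add(ab_r, f32x4_mul(f32x4_splat(a[r]), b_panel))` in the SIMD code.
fn matmul_4x4(k: usize, a: &[f32], b: &[f32]) -> [[f32; 4]; 4] {
    assert_eq!(a.len(), 4 * k);
    assert_eq!(b.len(), 4 * k);
    let mut ab = [[0f32; 4]; 4];
    for i in 0..k {
        for r in 0..4 {
            for c in 0..4 {
                ab[r][c] += a[4 * i + r] * b[4 * i + c];
            }
        }
    }
    ab
}

fn main() {
    // With k = 1, the tile is just the outer product: ab[r][c] = a[r] * b[c].
    let ab = matmul_4x4(1, &[1.0, 2.0, 3.0, 4.0], &[10.0, 20.0, 30.0, 40.0]);
    assert_eq!(ab[1], [20.0, 40.0, 60.0, 80.0]);
    println!("ok");
}
```

The SIMD version computes one whole row of this tile per `f32x4` operation, which is why the accumulators `ab0..ab3` are rows and B is loaded as a `v128` column panel.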