-
Notifications
You must be signed in to change notification settings - Fork 215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Wasm SIMD implementation of MatMatMulKer<f32>
#1420
Merged
Merged
Changes from 2 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
2651a27
feat: Wasm SIMD implementation of `MatMatMulKer<f32>`
ulan 05f63f7
Undo onnx demo changes
ulan 4e9d623
Enable Wasm-compatible default features of criterion and proptest
ulan 1fb428c
setup some tests for wasm
kali b39983e
wasm ci test, take 2
kali 5634b48
path to wasmtime
kali 5d12443
path to wasmtime, again
kali ba5690b
path to wasmtime, again, again
kali 48a4e03
Fix tests that are failing on Wasm
ulan b594433
Use Wasm-specific criterion/proptest when needed
ulan 080b6e3
Improve LeakyRelu test
ulan db14e6e
Remove custom Wasm tests
ulan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,338 @@ | ||
/// Wasm SIMD implementation of `MatMatMulKer<f32>` | ||
/// | ||
/// To run test, you need to install `wasmtime` | ||
/// and export the following environment variables: | ||
/// ``` | ||
/// > export RUSTFLAGS='-C target-feature=+simd128' | ||
/// > export CARGO_TARGET_WASM32_WASI_RUNNER=wasmtime | ||
/// > cargo test --target=wasm32-wasi | ||
/// ``` | ||
use crate::{ | ||
mmm::{FusedKerSpec, MatMatMulKer}, | ||
Ops, Scaler, | ||
}; | ||
|
||
#[derive(Copy, Clone, Debug)] | ||
pub struct WasmMmm4x4(); | ||
|
||
unsafe impl Send for WasmMmm4x4 {} | ||
unsafe impl Sync for WasmMmm4x4 {} | ||
|
||
impl MatMatMulKer<f32> for WasmMmm4x4 { | ||
#[inline(always)] | ||
fn name() -> &'static str { | ||
"wasm_f32_4x4" | ||
} | ||
|
||
#[inline(always)] | ||
fn mr() -> usize { | ||
4 | ||
} | ||
|
||
#[inline(always)] | ||
fn nr() -> usize { | ||
4 | ||
} | ||
|
||
fn end_padding_packed_a() -> usize { | ||
0 | ||
} | ||
|
||
fn end_padding_packed_b() -> usize { | ||
0 | ||
} | ||
|
||
#[inline(always)] | ||
fn alignment_bytes_packed_a() -> usize { | ||
std::mem::size_of::<f32>() | ||
} | ||
#[inline(always)] | ||
fn alignment_bytes_packed_b() -> usize { | ||
std::mem::size_of::<f32>() | ||
} | ||
|
||
#[inline(never)] | ||
fn kernel(spec: &[FusedKerSpec<f32>]) -> isize { | ||
unsafe { kernel_f32_4x4(spec) } | ||
} | ||
} | ||
|
||
pub fn plug(ops: &mut Ops) { | ||
let impls = vec![WasmMmm4x4::mmm()]; | ||
ops.mmm_f32_impls = impls.clone(); | ||
ops.mmm_f32 = Box::new(|_m, _k, _n| WasmMmm4x4::mmm()); | ||
} | ||
|
||
unsafe fn kernel_f32_4x4(spec: &[FusedKerSpec<f32>]) -> isize { | ||
use std::arch::wasm32::*; | ||
|
||
// Each of these variables stores a row of the matrix, | ||
// consisting of four packed `f32` numbers. | ||
let mut ab0 = f32x4_splat(0.0); | ||
let mut ab1 = f32x4_splat(0.0); | ||
let mut ab2 = f32x4_splat(0.0); | ||
let mut ab3 = f32x4_splat(0.0); | ||
|
||
let mut pnl = spec.as_ptr(); | ||
|
||
while !pnl.is_null() { | ||
match *pnl { | ||
FusedKerSpec::Done => break, | ||
FusedKerSpec::Clear => { | ||
let a = f32x4_splat(0.0); | ||
ab0 = a; | ||
ab1 = a; | ||
ab2 = a; | ||
ab3 = a; | ||
} | ||
FusedKerSpec::ScalarMin(a) => { | ||
let a = f32x4_splat(a); | ||
ab0 = f32x4_min(a, ab0); | ||
ab1 = f32x4_min(a, ab1); | ||
ab2 = f32x4_min(a, ab2); | ||
ab3 = f32x4_min(a, ab3); | ||
} | ||
FusedKerSpec::ScalarMax(a) => { | ||
let a = f32x4_splat(a); | ||
ab0 = f32x4_max(a, ab0); | ||
ab1 = f32x4_max(a, ab1); | ||
ab2 = f32x4_max(a, ab2); | ||
ab3 = f32x4_max(a, ab3); | ||
} | ||
FusedKerSpec::ScalarAdd(a) => { | ||
let a = f32x4_splat(a); | ||
ab0 = f32x4_add(a, ab0); | ||
ab1 = f32x4_add(a, ab1); | ||
ab2 = f32x4_add(a, ab2); | ||
ab3 = f32x4_add(a, ab3); | ||
} | ||
FusedKerSpec::ScalarMul(a) => { | ||
let a = f32x4_splat(a); | ||
ab0 = f32x4_mul(a, ab0); | ||
ab1 = f32x4_mul(a, ab1); | ||
ab2 = f32x4_mul(a, ab2); | ||
ab3 = f32x4_mul(a, ab3); | ||
} | ||
FusedKerSpec::ScalarSub(a) => { | ||
let a = f32x4_splat(a); | ||
ab0 = f32x4_sub(a, ab0); | ||
ab1 = f32x4_sub(a, ab1); | ||
ab2 = f32x4_sub(a, ab2); | ||
ab3 = f32x4_sub(a, ab3); | ||
} | ||
FusedKerSpec::ScalarSubF(a) => { | ||
let a = f32x4_splat(a); | ||
ab0 = f32x4_sub(ab0, a); | ||
ab1 = f32x4_sub(ab1, a); | ||
ab2 = f32x4_sub(ab2, a); | ||
ab3 = f32x4_sub(ab3, a); | ||
} | ||
FusedKerSpec::LeakyRelu(a) => { | ||
let a = f32x4_splat(a); | ||
let zero = f32x4_splat(0.0); | ||
|
||
let mask0 = f32x4_gt(ab0, zero); | ||
ab0 = v128_bitselect(ab0, f32x4_mul(a, ab0), mask0); | ||
|
||
let mask1 = f32x4_gt(ab1, zero); | ||
ab1 = v128_bitselect(ab1, f32x4_mul(a, ab1), mask1); | ||
|
||
let mask2 = f32x4_gt(ab2, zero); | ||
ab2 = v128_bitselect(ab2, f32x4_mul(a, ab2), mask2); | ||
|
||
let mask3 = f32x4_gt(ab3, zero); | ||
ab3 = v128_bitselect(ab3, f32x4_mul(a, ab3), mask3); | ||
} | ||
FusedKerSpec::PerRowMin(row) => { | ||
let row = std::slice::from_raw_parts(row, 4); | ||
ab0 = f32x4_min(f32x4_splat(row[0]), ab0); | ||
ab1 = f32x4_min(f32x4_splat(row[1]), ab1); | ||
ab2 = f32x4_min(f32x4_splat(row[2]), ab2); | ||
ab3 = f32x4_min(f32x4_splat(row[3]), ab3); | ||
} | ||
FusedKerSpec::PerRowMax(row) => { | ||
let row = std::slice::from_raw_parts(row, 4); | ||
ab0 = f32x4_max(f32x4_splat(row[0]), ab0); | ||
ab1 = f32x4_max(f32x4_splat(row[1]), ab1); | ||
ab2 = f32x4_max(f32x4_splat(row[2]), ab2); | ||
ab3 = f32x4_max(f32x4_splat(row[3]), ab3); | ||
} | ||
FusedKerSpec::PerRowAdd(row) => { | ||
let row = std::slice::from_raw_parts(row, 4); | ||
ab0 = f32x4_add(f32x4_splat(row[0]), ab0); | ||
ab1 = f32x4_add(f32x4_splat(row[1]), ab1); | ||
ab2 = f32x4_add(f32x4_splat(row[2]), ab2); | ||
ab3 = f32x4_add(f32x4_splat(row[3]), ab3); | ||
} | ||
FusedKerSpec::PerRowMul(row) => { | ||
let row = std::slice::from_raw_parts(row, 4); | ||
ab0 = f32x4_mul(f32x4_splat(row[0]), ab0); | ||
ab1 = f32x4_mul(f32x4_splat(row[1]), ab1); | ||
ab2 = f32x4_mul(f32x4_splat(row[2]), ab2); | ||
ab3 = f32x4_mul(f32x4_splat(row[3]), ab3); | ||
} | ||
FusedKerSpec::PerRowSub(row) => { | ||
let row = std::slice::from_raw_parts(row, 4); | ||
ab0 = f32x4_sub(f32x4_splat(row[0]), ab0); | ||
ab1 = f32x4_sub(f32x4_splat(row[1]), ab1); | ||
ab2 = f32x4_sub(f32x4_splat(row[2]), ab2); | ||
ab3 = f32x4_sub(f32x4_splat(row[3]), ab3); | ||
} | ||
FusedKerSpec::PerRowSubF(row) => { | ||
let row = std::slice::from_raw_parts(row, 4); | ||
ab0 = f32x4_sub(ab0, f32x4_splat(row[0])); | ||
ab1 = f32x4_sub(ab1, f32x4_splat(row[1])); | ||
ab2 = f32x4_sub(ab2, f32x4_splat(row[2])); | ||
ab3 = f32x4_sub(ab3, f32x4_splat(row[3])); | ||
} | ||
FusedKerSpec::PerColMin(cols) => { | ||
let cols = v128_load(cols as *const v128); | ||
ab0 = f32x4_min(cols, ab0); | ||
ab1 = f32x4_min(cols, ab1); | ||
ab2 = f32x4_min(cols, ab2); | ||
ab3 = f32x4_min(cols, ab3); | ||
} | ||
FusedKerSpec::PerColMax(cols) => { | ||
let cols = v128_load(cols as *const v128); | ||
ab0 = f32x4_max(cols, ab0); | ||
ab1 = f32x4_max(cols, ab1); | ||
ab2 = f32x4_max(cols, ab2); | ||
ab3 = f32x4_max(cols, ab3); | ||
} | ||
FusedKerSpec::PerColAdd(cols) => { | ||
let cols = v128_load(cols as *const v128); | ||
ab0 = f32x4_add(cols, ab0); | ||
ab1 = f32x4_add(cols, ab1); | ||
ab2 = f32x4_add(cols, ab2); | ||
ab3 = f32x4_add(cols, ab3); | ||
} | ||
FusedKerSpec::PerColMul(cols) => { | ||
let cols = v128_load(cols as *const v128); | ||
ab0 = f32x4_mul(cols, ab0); | ||
ab1 = f32x4_mul(cols, ab1); | ||
ab2 = f32x4_mul(cols, ab2); | ||
ab3 = f32x4_mul(cols, ab3); | ||
} | ||
FusedKerSpec::PerColSub(cols) => { | ||
let cols = v128_load(cols as *const v128); | ||
ab0 = f32x4_sub(cols, ab0); | ||
ab1 = f32x4_sub(cols, ab1); | ||
ab2 = f32x4_sub(cols, ab2); | ||
ab3 = f32x4_sub(cols, ab3); | ||
} | ||
FusedKerSpec::PerColSubF(cols) => { | ||
let cols = v128_load(cols as *const v128); | ||
ab0 = f32x4_sub(ab0, cols); | ||
ab1 = f32x4_sub(ab1, cols); | ||
ab2 = f32x4_sub(ab2, cols); | ||
ab3 = f32x4_sub(ab3, cols); | ||
} | ||
FusedKerSpec::QScale(shift, rp, mult) => { | ||
let scaler = Scaler::from_fuse_params(shift, rp, mult); | ||
let scale = f32x4_splat(scaler.scale); | ||
ab0 = f32x4_mul(scale, ab0); | ||
ab1 = f32x4_mul(scale, ab1); | ||
ab2 = f32x4_mul(scale, ab2); | ||
ab3 = f32x4_mul(scale, ab3); | ||
} | ||
FusedKerSpec::RoundingShiftRight(shift, _rp) => { | ||
let shift = f32x4_splat(2f32.powi(-(shift as i32))); | ||
ab0 = f32x4_mul(shift, ab0); | ||
ab1 = f32x4_mul(shift, ab1); | ||
ab2 = f32x4_mul(shift, ab2); | ||
ab3 = f32x4_mul(shift, ab3); | ||
} | ||
FusedKerSpec::ShiftLeft(shift) => { | ||
let shift = f32x4_splat(2f32.powi(shift as i32)); | ||
ab0 = f32x4_mul(shift, ab0); | ||
ab1 = f32x4_mul(shift, ab1); | ||
ab2 = f32x4_mul(shift, ab2); | ||
ab3 = f32x4_mul(shift, ab3); | ||
} | ||
FusedKerSpec::AddUnicast(tile) => { | ||
let mut ptr: *const u8 = tile.ptr; | ||
|
||
let m0 = *(ptr as *const f32); | ||
let m1 = *(ptr.offset(tile.col_byte_stride) as *const f32); | ||
let m2 = *(ptr.offset(tile.col_byte_stride * 2) as *const f32); | ||
let m3 = *(ptr.offset(tile.col_byte_stride * 3) as *const f32); | ||
ab0 = f32x4_add(ab0, f32x4(m0, m1, m2, m3)); | ||
ptr = ptr.add(tile.row_byte_stride as usize); | ||
|
||
let m0 = *(ptr as *const f32); | ||
let m1 = *(ptr.offset(tile.col_byte_stride) as *const f32); | ||
let m2 = *(ptr.offset(tile.col_byte_stride * 2) as *const f32); | ||
let m3 = *(ptr.offset(tile.col_byte_stride * 3) as *const f32); | ||
ab1 = f32x4_add(ab1, f32x4(m0, m1, m2, m3)); | ||
ptr = ptr.add(tile.row_byte_stride as usize); | ||
|
||
let m0 = *(ptr as *const f32); | ||
let m1 = *(ptr.offset(tile.col_byte_stride) as *const f32); | ||
let m2 = *(ptr.offset(tile.col_byte_stride * 2) as *const f32); | ||
let m3 = *(ptr.offset(tile.col_byte_stride * 3) as *const f32); | ||
ab2 = f32x4_add(ab2, f32x4(m0, m1, m2, m3)); | ||
ptr = ptr.add(tile.row_byte_stride as usize); | ||
|
||
let m0 = *(ptr as *const f32); | ||
let m1 = *(ptr.offset(tile.col_byte_stride) as *const f32); | ||
let m2 = *(ptr.offset(tile.col_byte_stride * 2) as *const f32); | ||
let m3 = *(ptr.offset(tile.col_byte_stride * 3) as *const f32); | ||
ab3 = f32x4_add(ab3, f32x4(m0, m1, m2, m3)); | ||
} | ||
FusedKerSpec::AddRowColProducts(rows, cols) => { | ||
let cols = v128_load(cols as *const v128); | ||
ab0 = f32x4_add(ab0, f32x4_mul(f32x4_splat(*rows.add(0)), cols)); | ||
ab1 = f32x4_add(ab1, f32x4_mul(f32x4_splat(*rows.add(1)), cols)); | ||
ab2 = f32x4_add(ab2, f32x4_mul(f32x4_splat(*rows.add(2)), cols)); | ||
ab3 = f32x4_add(ab3, f32x4_mul(f32x4_splat(*rows.add(3)), cols)); | ||
} | ||
FusedKerSpec::Store(tile) => { | ||
let mut ptr: *mut u8 = tile.ptr; | ||
|
||
*(ptr as *mut f32) = f32x4_extract_lane::<0>(ab0); | ||
*(ptr.offset(tile.col_byte_stride) as *mut f32) = f32x4_extract_lane::<1>(ab0); | ||
*(ptr.offset(tile.col_byte_stride * 2) as *mut f32) = f32x4_extract_lane::<2>(ab0); | ||
*(ptr.offset(tile.col_byte_stride * 3) as *mut f32) = f32x4_extract_lane::<3>(ab0); | ||
ptr = ptr.add(tile.row_byte_stride as usize); | ||
|
||
*(ptr as *mut f32) = f32x4_extract_lane::<0>(ab1); | ||
*(ptr.offset(tile.col_byte_stride) as *mut f32) = f32x4_extract_lane::<1>(ab1); | ||
*(ptr.offset(tile.col_byte_stride * 2) as *mut f32) = f32x4_extract_lane::<2>(ab1); | ||
*(ptr.offset(tile.col_byte_stride * 3) as *mut f32) = f32x4_extract_lane::<3>(ab1); | ||
ptr = ptr.add(tile.row_byte_stride as usize); | ||
|
||
*(ptr as *mut f32) = f32x4_extract_lane::<0>(ab2); | ||
*(ptr.offset(tile.col_byte_stride) as *mut f32) = f32x4_extract_lane::<1>(ab2); | ||
*(ptr.offset(tile.col_byte_stride * 2) as *mut f32) = f32x4_extract_lane::<2>(ab2); | ||
*(ptr.offset(tile.col_byte_stride * 3) as *mut f32) = f32x4_extract_lane::<3>(ab2); | ||
ptr = ptr.add(tile.row_byte_stride as usize); | ||
|
||
*(ptr as *mut f32) = f32x4_extract_lane::<0>(ab3); | ||
*(ptr.offset(tile.col_byte_stride) as *mut f32) = f32x4_extract_lane::<1>(ab3); | ||
*(ptr.offset(tile.col_byte_stride * 2) as *mut f32) = f32x4_extract_lane::<2>(ab3); | ||
*(ptr.offset(tile.col_byte_stride * 3) as *mut f32) = f32x4_extract_lane::<3>(ab3); | ||
} | ||
FusedKerSpec::AddMatMul { k, pa, pb, cpu_variant: _ } => { | ||
let a = pa as *const f32; | ||
let b = pb as *const v128; | ||
for i in 0..k { | ||
let a = std::slice::from_raw_parts(a.offset(4 * i as isize), 4); | ||
let b = v128_load(b.offset(i as isize)); | ||
ab0 = f32x4_add(ab0, f32x4_mul(f32x4_splat(a[0]), b)); | ||
ab1 = f32x4_add(ab1, f32x4_mul(f32x4_splat(a[1]), b)); | ||
ab2 = f32x4_add(ab2, f32x4_mul(f32x4_splat(a[2]), b)); | ||
ab3 = f32x4_add(ab3, f32x4_mul(f32x4_splat(a[3]), b)); | ||
} | ||
} | ||
} | ||
pnl = pnl.add(1); | ||
} | ||
0 | ||
} | ||
|
||
#[allow(non_camel_case_types)] | ||
pub type wasm_f32_4x4 = WasmMmm4x4; | ||
test_mmm_kernel_f32!(wasm_f32_4x4, true); | ||
|
||
#[cfg(test)] | ||
mod tests; |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what feature do we loose here ? is that gonna be a problem ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I pushed a new change to enable some of the default features that are compatible with Wasm.
After that change, for
criterion
we will lose onlyrayon
:https://github.com/bheisler/criterion.rs/blob/f1ea31a92ff919a455f36b13c9a45fd74559d0fe/Cargo.toml#L74C12-L74C54
For
proptest
we will losefork
andtimeout
:https://github.com/proptest-rs/proptest/blob/a62a348b59f422161cbc5c6910f83f1b3c3e67e5/proptest/Cargo.toml#L22
These are multi-threading features that affect performance.
If that's unacceptable, I can revert all
Cargo.toml
changes and keep the current status quo of not running Wasm tests. I think that would be okay because we can require people to run those tests locally when making Wasm-related changes.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We really need to keep proptest fork around (that's super useful to test kernel under address sanitizing, as the sanitizer will abort the process at the first issue). Can we achieve keeping proptest whole on every non-wasm platform by some cfg() tricks in Cargo.toml ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was also thinking in that direction initially. I couldn't find a way to have a platform-specific dependency in the workspace: rust-lang/cargo#5220
I could explore moving
proptest
into a create-local dependency (instead of workspace) and then it should be possible to have platform-specific rules.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Worth a try. If it solves the problem, I'm ok with proptest becoming a crate-level dep.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be fixed in the latest commit.