feat: Wasm SIMD implementation of MatMatMulKer<f32>
#1420
@kali: could you please take an initial look? I am happy to revert changes in |
MatMatMulKer<f32>
MatMatMulKer<f32>
Here is the documentation of Wasm SIMD intrinsics: https://doc.rust-lang.org/core/arch/wasm32/index.html
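For readers unfamiliar with these intrinsics, here is a minimal sketch of a 4x4 `f32` matrix product written with them. It is not the kernel from this PR: the names `Mat4` and `matmul_4x4` are hypothetical, and the sketch gates on `target_arch = "wasm32"` (an assumption, chosen because the intrinsics live in `core::arch::wasm32`) and falls back to scalar loops on every other target so it can be compiled and tested natively.

```rust
/// Row-major 4x4 matrix of f32 (hypothetical helper type for this sketch).
pub type Mat4 = [[f32; 4]; 4];

/// Portable fallback: plain scalar loops, compiled when Wasm SIMD is unavailable.
#[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
pub fn matmul_4x4(a: &Mat4, b: &Mat4) -> Mat4 {
    let mut c = [[0.0f32; 4]; 4];
    for i in 0..4 {
        for k in 0..4 {
            for j in 0..4 {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    c
}

/// Wasm SIMD path: each row of the result is accumulated in one 128-bit register.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
pub fn matmul_4x4(a: &Mat4, b: &Mat4) -> Mat4 {
    use core::arch::wasm32::{f32x4, f32x4_add, f32x4_extract_lane, f32x4_mul, f32x4_splat};

    // Pack each row of B into a v128 holding four f32 lanes.
    let rows_b = b.map(|r| f32x4(r[0], r[1], r[2], r[3]));

    let mut c = [[0.0f32; 4]; 4];
    for i in 0..4 {
        let mut acc = f32x4_splat(0.0);
        for k in 0..4 {
            // Broadcast a[i][k] to all lanes, multiply by row k of B, accumulate.
            acc = f32x4_add(acc, f32x4_mul(f32x4_splat(a[i][k]), rows_b[k]));
        }
        // Unpack the four lanes back into row i of the result.
        c[i] = [
            f32x4_extract_lane::<0>(acc),
            f32x4_extract_lane::<1>(acc),
            f32x4_extract_lane::<2>(acc),
            f32x4_extract_lane::<3>(acc),
        ];
    }
    c
}
```

Both paths accumulate over `k` in the same order, so they should produce identical results for the same inputs. Building for a Wasm target with `-C target-feature=+simd128` selects the SIMD path.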
That's a very big first step. Thanks a lot for going through with this.
Cargo.toml
Outdated
@@ -103,7 +103,7 @@ num-integer = "0.1.44"
 num-traits = "0.2.14"
 openblas-src = { version = "0.10", features = ["static"] }
 paste = "1.0.5"
-proptest = "1.0.0"
+proptest = { version = "1.0.0", default-features = false, features = ["std"] }
What features do we lose here? Is that going to be a problem?
I pushed a new change to enable some of the default features that are compatible with Wasm.
After that change, for criterion we will lose only rayon: https://github.com/bheisler/criterion.rs/blob/f1ea31a92ff919a455f36b13c9a45fd74559d0fe/Cargo.toml#L74C12-L74C54
For proptest we will lose fork and timeout: https://github.com/proptest-rs/proptest/blob/a62a348b59f422161cbc5c6910f83f1b3c3e67e5/proptest/Cargo.toml#L22
These are multi-threading features that affect performance.
If that's unacceptable, I can revert all Cargo.toml changes and keep the current status quo of not running Wasm tests. I think that would be okay because we can require people to run those tests locally when making Wasm-related changes.
We really need to keep proptest fork around (that's super useful to test kernel under address sanitizing, as the sanitizer will abort the process at the first issue). Can we achieve keeping proptest whole on every non-wasm platform by some cfg() tricks in Cargo.toml ?
I was also thinking in that direction initially. I couldn't find a way to have a platform-specific dependency in the workspace: rust-lang/cargo#5220
I could explore moving proptest into a crate-local dependency (instead of a workspace one), and then it should be possible to have platform-specific rules.
Worth a try. If it solves the problem, I'm ok with proptest becoming a crate-level dep.
This should be fixed in the latest commit.
linalg/src/wasm/tests.rs
Outdated
((r1, actual), (r2, expected))
}

proptest! {
what's the logic to add these extra tests ? Are they covering more than the regular test_mmm_kernel_f32 ? I'm not asking for their removal, just trying to understand if there is something missing in the regular standard kernel test suite.
Yes, that was exactly the motivation for adding these tests. I wanted to get some confidence that I didn't introduce bugs. I noticed that test_mmm_kernel_f32 was not covering everything. For example, in leaky relu, removing the last line still passes test_mmm_kernel_f32 (IIRC).
Can we make the generic test cases cover these instead ? That would benefit all the kernels, including the future wasm ones that will use a different geometry than 4x4.
I'll try to add a test that covers the missing leaky relu case that I know of. I don't know how to find all missing cases though.
Sure, the generic kernel test suite is a work in progress; it does not cover every corner case and never will. But it is very beneficial to augment it instead of adding tests that are specialized for a given kernel. Now that you have set up the basic building blocks for wasm kernels, I expect we will pretty soon have half a dozen or more wasm kernels, so believe me, being able to rely on the generic kernel test suite will make a huge lot of sense...
I fixed the existing leaky relu test to multiply by 1 instead of 0 (which was causing the test to be ineffective because multiplying by 0 always gives 0). I also removed the custom Wasm tests.
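To illustrate why a test built on zeros cannot catch this kind of bug, here is a minimal sketch under one plausible reading of the fix, namely that the test inputs were being scaled by 0. All names (`leaky_relu`, `buggy_leaky_relu`, `agrees_with_reference`, `scaled_inputs`) are hypothetical and do not come from the tract test suite.

```rust
/// Reference leaky ReLU: negative inputs are scaled by `alpha`.
pub fn leaky_relu(x: f32, alpha: f32) -> f32 {
    if x < 0.0 { alpha * x } else { x }
}

/// Hypothetical buggy kernel that forgets the negative branch entirely.
pub fn buggy_leaky_relu(x: f32, _alpha: f32) -> f32 {
    x
}

/// Returns true when `kernel` matches the reference on every input.
pub fn agrees_with_reference(kernel: fn(f32, f32) -> f32, inputs: &[f32], alpha: f32) -> bool {
    inputs.iter().all(|&x| kernel(x, alpha) == leaky_relu(x, alpha))
}

/// Builds test data by scaling a fixed pattern (scale 0 collapses it to zeros).
pub fn scaled_inputs(scale: f32) -> Vec<f32> {
    [-2.0f32, -0.5, 0.0, 1.5].iter().map(|v| v * scale).collect()
}
```

With `scaled_inputs(0.0)` every input is zero, so the buggy kernel and the reference agree and the bug goes unnoticed. With `scaled_inputs(1.0)` the negative inputs survive and the two disagree.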
What should we add to the CI to get the new tests to run?
Good question! I think we would need to install
I am not good with GitHub CI, but I can take a look to see if it is possible to add such a step.
I can try and do that. Working on the CI from a fork is cumbersome.
So the tests are in place, but as you can see, we have an issue here. (Ignore the problem on nightly, this is a cargo bug that is being fixed). https://github.com/sonos/tract/actions/runs/9264996443/job/25486167296?pr=1420 I checked, and the issue was there without the kernel, so I'm not sure how to proceed. It looks like we're getting into an unreachable!() in wasmtime. Any chance you can help? (If this is non-trivial, we could move this investigation to a separate PR, fix the tests and rebase this PR.)
Thanks a lot for adding the CI jobs! Let me try to reproduce the failure locally.
The tests should pass now. The wasmtime unreachable was due to a panic in rust code. (I saw that after running with
@kali: please take another look when you have time. The PR is ready to go from my side.
It does look pretty good. Thanks a lot for contributing this!
To simplify code review, this PR implements operations only for a 4x4 matrix of f32. This implementation will be generalized to support more operations and types in follow-up PRs.
Issue: #1361
Changes
criterion and proptest crates because they are not compatible with Wasm. This allows running cargo test to test Wasm changes.
wasm module gated by target_family = "wasm" and target_feature = "simd128".
WasmMmm4x4 using std::arch::wasm32 intrinsics.
test_mmm_kernel_f32! tests for the new implementation.
Benchmarking
The onnx-mobilenet-v2 example was used to benchmark the new implementation.
The example was modified to use a PNG image instead of a JPG image because jpeg-decoder returns different image bytes depending on whether Wasm SIMD is enabled or not. See: https://github.com/image-rs/jpeg-decoder/blob/c1a1fe04cc54a5446e57a71ea856afd07cd374b2/src/arch/wasm.rs#L9
The example was also modified to measure only the duration of inference (model.run()).
These changes were later reverted; see the first commit in this PR.
Results of running with wasmtime -O opt-level=0:
624ms
523ms (1.2x faster)
235ms (2.6x faster)
Results of running with wasmtime:
338ms
321ms (1.05x faster)
206ms (1.6x faster)
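As a rough illustration of the measurement approach, the sketch below times a single call and computes a speedup ratio. The helper names `time_once` and `speedup` are hypothetical, the closure stands in for the call to model.run(), and the sketch assumes the first figure in each group above is the baseline.

```rust
use std::time::{Duration, Instant};

/// Runs `f` once and returns its result together with the elapsed wall-clock time.
pub fn time_once<T>(f: impl FnOnce() -> T) -> (T, Duration) {
    let start = Instant::now();
    let out = f();
    (out, start.elapsed())
}

/// Speedup of `candidate_ms` relative to `baseline_ms` (values above 1.0 mean faster).
pub fn speedup(baseline_ms: f64, candidate_ms: f64) -> f64 {
    baseline_ms / candidate_ms
}
```

Wrapping only the inference call in `time_once` keeps model loading and image decoding out of the measurement. Under the baseline assumption, `speedup(624.0, 235.0)` is about 2.66 and `speedup(338.0, 206.0)` is about 1.64, consistent with the rounded figures reported above.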