Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Wasm SIMD implementation of MatMatMulKer<f32> #1420

Merged
merged 12 commits into from
Jun 2, 2024
Merged

Conversation

ulan
Copy link
Contributor

@ulan ulan commented May 27, 2024

To simplify code review, this PR implements operations only for a 4x4 matrix of f32. This implementation will be generalized to support more operations and types in follow-up PRs.

Issue: #1361

Changes

  • Disable the default features of thecriterion and proptest crates because they are not compatible with Wasm. This allows running cargo test to test Wasm changes.
  • Add a new wasm module gated by target_family = "wasm" and target_feature = "simd128".
  • Implement WasmMmm4x4 using std::arch::wasm32 intrinsics.
  • Add test_mmm_kernel_f32! tests for the new implementation.
  • Add prop tests that compare the new implementation against the existing generic implementation to gain additional test coverage.

Benchmarking

The onnx-mobilenet-v2 example was used to benchmark the new implementation.

The example was modified to use a PNG image instead of a JPG image because jpeg-decoder returns different image bytes depending on whether Wasm SIMD is enabled or not. See: https://github.com/image-rs/jpeg-decoder/blob/c1a1fe04cc54a5446e57a71ea856afd07cd374b2/src/arch/wasm.rs#L9
The example was also modified to measure only the duration of inference (model.run()).
See the first commit in this PR for these changes that were reverted.

Results of running with wasmtime -O opt-level=0:

  • Baseline: 624ms
  • Wasm SIMD with Rust auto-vectorization: 523ms (1.2x faster)
  • Wasm SIMD with this implementation: 235ms (2.6x faster)

Results of running with wasmtime:

  • Baseline: 338ms
  • Wasm SIMD with Rust auto-vectorization: 321ms (1.05x faster)
  • Wasm SIMD with this implementation: 206ms (1.6x faster)

@ulan
Copy link
Contributor Author

ulan commented May 27, 2024

@kali: could you please take an initial look? I am happy to revert changes in Cargo.toml or remove the new prop tests if you prefer that.

@ulan ulan changed the title Draft: feat: Wasm SIMD implementation of MatMatMulKer<f32> feat: Wasm SIMD implementation of MatMatMulKer<f32> May 27, 2024
@ulan
Copy link
Contributor Author

ulan commented May 27, 2024

Here is the documentation of Wasm SIMD intrinsics: https://doc.rust-lang.org/core/arch/wasm32/index.html
They support 128-bit operations.

Copy link
Collaborator

@kali kali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a very big first step. Thanks a lot for going through with this.

Cargo.toml Outdated
@@ -103,7 +103,7 @@ num-integer = "0.1.44"
num-traits = "0.2.14"
openblas-src = { version = "0.10", features = ["static"] }
paste = "1.0.5"
proptest = "1.0.0"
proptest = { version = "1.0.0", default-features = false, features = ["std"] }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what feature do we loose here ? is that gonna be a problem ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pushed a new change to enable some of the default features that are compatible with Wasm.

After that change, for criterion we will lose only rayon:
https://github.com/bheisler/criterion.rs/blob/f1ea31a92ff919a455f36b13c9a45fd74559d0fe/Cargo.toml#L74C12-L74C54

For proptest we will lose fork and timeout:
https://github.com/proptest-rs/proptest/blob/a62a348b59f422161cbc5c6910f83f1b3c3e67e5/proptest/Cargo.toml#L22

These are multi-threading features that affect performance.

If that's unacceptable, I can revert all Cargo.toml changes and keep the current status quo of not running Wasm tests. I think that would be okay because we can require people to run those tests locally when making Wasm-related changes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We really need to keep proptest fork around (that's super useful to test kernel under address sanitizing, as the sanitizer will abort the process at the first issue). Can we achieve keeping proptest whole on every non-wasm platform by some cfg() tricks in Cargo.toml ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was also thinking in that direction initially. I couldn't find a way to have a platform-specific dependency in the workspace: rust-lang/cargo#5220

I could explore moving proptest into a create-local dependency (instead of workspace) and then it should be possible to have platform-specific rules.

Copy link
Collaborator

@kali kali May 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth a try. If it solves the problem, I'm ok with proptest becoming a crate-level dep.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be fixed in the latest commit.

((r1, actual), (r2, expected))
}

proptest! {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the logic to add these extra tests ? Are they covering more than the regular test_mmm_kernel_f32 ? I'm not asking for their removal, just trying to understand if there is something missing in the regular standard kernel test suite.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that was exactly the motivation for adding these tests. I wanted to get some confidence that I didn't introduce bugs. I noticed that test_mmm_kernel_f32 was not covering everything. For example, in leaky relu removing the last line still passes test_mmm_kernel_f32 (IIRC).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make the generic test cases cover these instead ? That would benefit all the kernels, including the future wasm ones that will use a different geometry than 4x4.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll try to add a test that covers the missing leaky relu case that I know of. I don't know how to find all missing cases though.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, the generic kernel test suite is a work in progress, it does not cover corner case and will never do. But it is very beneficial to augment it instead of adding tests that are specialized for a given kernel. Now that you have setup the basic building blocks for wasm kernels, I expect we will pretty soon have a half a dozen or more wasm kernels, so believe me, being able to rely on the generic kernel test suite will make a hug lot of sense...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed the existing leaky relu test to multiply by 1 instead of 0 (which was causing the test to be ineffective because multiplying by 0 always gives 0). I also removed the custom Wasm tests.

@kali
Copy link
Collaborator

kali commented May 27, 2024

What should we add to the CI to get the new tests to run?

@ulan
Copy link
Contributor Author

ulan commented May 27, 2024

What should we add to the CI to get the new tests to run?

Good question! I think we would need to install wasmtime and run in tract/linalg:

RUSTFLAGS='-C target-feature=+simd128' CARGO_TARGET_WASM32_WASI_RUNNER=wasmtime cargo test --target=wasm32-wasi

I am not good with GitHub CI, but I can take a look to see if it is possible to add such as a step.

@kali
Copy link
Collaborator

kali commented May 28, 2024

What should we add to the CI to get the new tests to run?

Good question! I think we would need to install wasmtime and run in tract/linalg:

RUSTFLAGS='-C target-feature=+simd128' CARGO_TARGET_WASM32_WASI_RUNNER=wasmtime cargo test --target=wasm32-wasi

I am not good with GitHub CI, but I can take a look to see if it is possible to add such as a step.

I can try and do that. Working on the CI from a fork is cumbersome.

@kali
Copy link
Collaborator

kali commented May 28, 2024

So the tests are in place, but as you can see, we have an issue here. (Ignore the problem on nightly, this is a cargo bug that is being fixed).

https://github.com/sonos/tract/actions/runs/9264996443/job/25486167296?pr=1420

I checked, and the issue was there without the kernel, so I'm not sure how to proceed, it looks like we're getting into an unreachable!() in wasmtime. Any chance you can help ? (If this is non-trivial, we could move this investigation to a separate PR, fix the tests and rebase this PR.)

@ulan
Copy link
Contributor Author

ulan commented May 28, 2024

Thanks a lot for adding the CI jobs! Let me try to reproduce the failure locally.

@ulan
Copy link
Contributor Author

ulan commented May 28, 2024

The tests should pass now. The wasmtime unreachable was due to a panic in rust code. (I saw that after running with -- --nocapture).

@ulan
Copy link
Contributor Author

ulan commented Jun 2, 2024

@kali: please take another look when you have time. The PR is ready to go from my side.

@kali
Copy link
Collaborator

kali commented Jun 2, 2024

It does look pretty good. Thanks a lot for contributing this!

@kali kali merged commit 2a2914a into sonos:main Jun 2, 2024
47 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants