Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Wasm SIMD implementation of MatMatMulKer<f32> #1420

Merged
merged 12 commits into from
Jun 2, 2024
9 changes: 9 additions & 0 deletions .travis/cross.sh
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,15 @@ case "$PLATFORM" in
cargo dinghy --platform $PLATFORM build --release -p tract -p example-tensorflow-mobilenet-v2
;;

wasm32-wasi)
# Install the Rust target, then type-check the ONNX and TensorFlow crates for it.
rustup target add $PLATFORM
cargo check --target $PLATFORM --features getrandom-js -p tract-onnx -p tract-tensorflow
# Install wasmtime with its installer script; it is used below as the cargo test runner.
curl https://wasmtime.dev/install.sh -sSf | bash
WASMTIME=$HOME/.wasmtime/bin/wasmtime
$WASMTIME --version
# Run the linalg/core test suites under wasmtime with the simd128 target feature enabled.
RUSTFLAGS='-C target-feature=+simd128' CARGO_TARGET_WASM32_WASI_RUNNER=$WASMTIME \
cargo test --target=wasm32-wasi -p tract-linalg -p tract-core -p test-unit-core
;;
wasm32-*)
rustup target add $PLATFORM
cargo check --target $PLATFORM --features getrandom-js -p tract-onnx -p tract-tensorflow
Expand Down
12 changes: 10 additions & 2 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,16 @@ openblas = [ "blas", "openblas-src" ]
paranoid_assertions = []

[dev-dependencies]
criterion.workspace = true
env_logger.workspace = true
lazy_static.workspace = true
proptest.workspace = true
approx.workspace = true

# Non-Wasm targets inherit criterion and proptest from the workspace definition.
[target.'cfg(not(target_family = "wasm"))'.dev-dependencies]
criterion.workspace = true
proptest.workspace = true

# Wasm targets declare criterion and proptest explicitly, with default features turned off.
[target.'cfg(target_family = "wasm")'.dev-dependencies]
# Wasm doesn't support the `rayon` feature of criterion.
criterion = { version = "0.4", default-features = false, features = ["plotters", "cargo_bench_support"] }
# Wasm doesn't support the `fork` feature of proptest.
proptest = { version = "1.0.0", default-features = false, features = ["std", "bit-set"] }
2 changes: 2 additions & 0 deletions core/src/model/order.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ mod tests {
assert_eq!(model.eval_order_opt_ram().unwrap(), vec!(a.node, add.node));
}

// The test is disabled on Wasm because it uses threads.
#[cfg(not(target_family = "wasm"))]
#[test]
fn dodge_loop() {
let mut model = TypedModel::default();
Expand Down
8 changes: 7 additions & 1 deletion data/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,16 @@ lazy_static.workspace = true
scan_fmt.workspace = true
string-interner.workspace = true

[dev-dependencies]
# Non-Wasm targets inherit criterion and proptest from the workspace definition.
# The table key must be spelled `dev-dependencies`: the underscore form is a
# deprecated alias that Cargo warns about and rejects in the 2024 edition.
[target.'cfg(not(target_family = "wasm"))'.dev-dependencies]
criterion.workspace = true
proptest.workspace = true

# Wasm targets declare criterion and proptest explicitly, with default features turned off.
# The table key must be spelled `dev-dependencies`: the underscore form is a
# deprecated alias that Cargo warns about and rejects in the 2024 edition.
[target.'cfg(target_family = "wasm")'.dev-dependencies]
# Wasm doesn't support the `rayon` feature of criterion.
criterion = { version = "0.4", default-features = false, features = ["plotters", "cargo_bench_support"] }
# Wasm doesn't support the `fork` feature of proptest.
proptest = { version = "1.0.0", default-features = false, features = ["std", "bit-set"] }

[features]
complex = [ "num-complex" ]

Expand Down
12 changes: 10 additions & 2 deletions linalg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,20 @@ time.workspace = true
walkdir.workspace = true

[dev-dependencies]
criterion.workspace = true
env_logger.workspace = true
nu-ansi-term.workspace = true
proptest.workspace = true
core_affinity.workspace = true

# Non-Wasm targets inherit criterion and proptest from the workspace definition.
[target.'cfg(not(target_family = "wasm"))'.dev-dependencies]
criterion.workspace = true
proptest.workspace = true

# Wasm targets declare criterion and proptest explicitly, with default features turned off.
[target.'cfg(target_family = "wasm")'.dev-dependencies]
# Wasm doesn't support the `rayon` feature of criterion.
criterion = { version = "0.4", default-features = false, features = ["plotters", "cargo_bench_support"] }
# Wasm doesn't support the `fork` feature of proptest.
proptest = { version = "1.0.0", default-features = false, features = ["std", "bit-set"] }

[features]
# This feature is meant to accommodate very restrictive / legacy toolchains that do
# not have support for fp16 instructions, breaking tract compilation.
Expand Down
2 changes: 1 addition & 1 deletion linalg/src/frame/mmm/fuse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ pub mod test {
scalar,
|a, b| if b > <$ti>::zero() { b } else { a * b },
<$ker as MatMatMulKer<$ti>>::can_fuse(&FusedSpec::LeakyRelu(&tensor0(
<$ti>::zero()
<$ti>::from(1_u8)
)))
);

Expand Down
12 changes: 10 additions & 2 deletions linalg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ pub mod arm64;
#[cfg(any(target_arch = "arm", target_arch = "armv7"))]
pub mod arm32;

// Compiled only for Wasm targets built with the `simd128` target feature
// (e.g. RUSTFLAGS='-C target-feature=+simd128').
#[cfg(all(target_family = "wasm", target_feature = "simd128"))]
pub mod wasm;

pub use self::frame::{element_wise, lut, mmm};

use crate::frame::mmm::kernel::MatMatMulKer;
Expand Down Expand Up @@ -74,8 +77,10 @@ pub struct Ops {
pub max_f16: Box<dyn Fn() -> Box<dyn reduce::Reduce<f16>> + Send + Sync>,
pub max_f32: Box<dyn Fn() -> Box<dyn reduce::Reduce<f32>> + Send + Sync>,

pub softmax2_fastcompact_f16: Box<dyn Fn() -> Box<dyn reduce::MapReduce<f16, f16>> + Send + Sync>,
pub softmax2_fastcompact_f32: Box<dyn Fn() -> Box<dyn reduce::MapReduce<f32, f32>> + Send + Sync>,
pub softmax2_fastcompact_f16:
Box<dyn Fn() -> Box<dyn reduce::MapReduce<f16, f16>> + Send + Sync>,
pub softmax2_fastcompact_f32:
Box<dyn Fn() -> Box<dyn reduce::MapReduce<f32, f32>> + Send + Sync>,
}

impl Ops {
Expand Down Expand Up @@ -154,6 +159,9 @@ pub fn best() -> Ops {
arm32::plug(&mut ops);
#[cfg(target_arch = "aarch64")]
arm64::plug(&mut ops);
#[cfg(all(target_family = "wasm", target_feature = "simd128"))]
wasm::plug(&mut ops);

ops
}

Expand Down
Loading
Loading