From 094c435f519d5ce54cc33d993f5bc5f1986f02f0 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 14 Nov 2024 22:48:39 -0500
Subject: [PATCH] add Metal extension for batched_mul

---
 .buildkite/pipeline.yml            | 21 ++++++++
 Project.toml                       |  3 ++
 ext/NNlibMetalExt/NNlibMetalExt.jl | 50 ++++++++++++++++++
 test/ext_metal/batched_mul.jl      | 82 ++++++++++++++++++++++++++++++
 test/ext_metal/runtests.jl         |  6 +++
 test/runtests.jl                   | 18 +++++++
 6 files changed, 180 insertions(+)
 create mode 100644 ext/NNlibMetalExt/NNlibMetalExt.jl
 create mode 100644 test/ext_metal/batched_mul.jl
 create mode 100644 test/ext_metal/runtests.jl

diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
index bc3d7d2b9..de9860532 100644
--- a/.buildkite/pipeline.yml
+++ b/.buildkite/pipeline.yml
@@ -55,6 +55,27 @@ steps:
       NNLIB_TEST_CPU: "false"
       JULIA_NUM_THREADS: 4
 
+  - label: ":julia: Julia 1 + Metal GPU"
+    plugins:
+      - JuliaCI/julia#v1:
+          version: "1"
+      - JuliaCI/julia-test#v1:
+          test_args: "--quickfail"
+      - JuliaCI/julia-coverage#v1:
+          codecov: true
+          dirs:
+            - src
+            - ext
+    agents:
+      queue: "juliaecosystem"
+      os: "macos"
+      arch: "aarch64"
+    timeout_in_minutes: 180
+    env:
+      NNLIB_TEST_METAL: "true"
+      NNLIB_TEST_CPU: "false"
+      JULIA_NUM_THREADS: 4
+
   - label: "Benchmarks"
     plugins:
       - JuliaCI/julia#v1:
diff --git a/Project.toml b/Project.toml
index 2c44f6a8c..1e277c243 100644
--- a/Project.toml
+++ b/Project.toml
@@ -19,6 +19,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
+Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 
 [extensions]
 NNlibAMDGPUExt = "AMDGPU"
@@ -27,6 +28,7 @@ NNlibCUDAExt = "CUDA"
 NNlibEnzymeCoreExt = "EnzymeCore"
 NNlibFFTWExt = "FFTW"
 NNlibForwardDiffExt = "ForwardDiff"
+NNlibMetalExt = "Metal"
 
 [compat]
 AMDGPU = "0.9.4, 1"
@@ -40,6 +42,7 @@ ForwardDiff = "0.10.36"
 GPUArraysCore = "0.1"
 KernelAbstractions = "0.9.2"
 LinearAlgebra = "<0.0.1, 1"
+Metal = "1.4.2"
 Random = "<0.0.1, 1"
 Statistics = "1"
 cuDNN = "1"
diff --git a/ext/NNlibMetalExt/NNlibMetalExt.jl b/ext/NNlibMetalExt/NNlibMetalExt.jl
new file mode 100644
index 000000000..969ef0242
--- /dev/null
+++ b/ext/NNlibMetalExt/NNlibMetalExt.jl
@@ -0,0 +1,50 @@
+module NNlibMetalExt
+
+using Metal, NNlib
+using NNlib: AbstractRNG # === Random.AbstractRNG
+
+# Random number generation: MtlArrays work only with Metal's own RNG
+NNlib._rng_from_array(::MtlArray) = Metal.MPS.default_rng()
+
+NNlib._rng_compat_array(rng::Metal.MPS.RNG, A::MtlArray) = nothing
+NNlib._rng_compat_array(rng::AbstractRNG, A::MtlArray) = throw(ArgumentError(
+    "cannot use rng::$(typeof(rng)) with array::MtlArray, only Metal's own RNG type works"))
+
+# Batched matrix multiplication
+function NNlib._batched_gemm!(::Type{<:MtlArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C)
+    eltype(C) <: Complex && @warn "don't trust this on complex arrays!" transA transB
+    Metal.MPS.matmul!(C, A, B, α, β, transA != 'N', transB != 'N')  # NB: gemm!'s argument order is (transA, transB, α, A, B, β, C)
+end
+
+#=
+
+help?> Metal.MPS.matmul!
+  matMulMPS(a::MtlMatrix, b::MtlMatrix, c::MtlMatrix, alpha=1, beta=1,
+            transpose_left=false, transpose_right=false)
+
+  A MPSMatrixMultiplication kernel that computes: c = alpha * op(a) * op(b) + beta * c
+
+  This function should not typically be used. Rather, use the normal LinearAlgebra interface with
+  any MtlArray and it should be accelerated using Metal Performance Shaders.
+
+=#
+
+using NNlib: BatchedAdjoint, BatchedTranspose, BatchedAdjOrTrans
+using Adapt
+using Adapt: WrappedArray
+
+const MetalBatchedAdjoint{T} = BatchedAdjoint{T, <: MtlArray{T}}
+const MetalBatchedTranspose{T} = BatchedTranspose{T, <: MtlArray{T}}
+const MetalBatchedAdjOrTrans{T} = Union{MetalBatchedAdjoint{T}, MetalBatchedTranspose{T}}
+const WrappedMetalBatchedAdjOrTrans{T, N} = WrappedArray{T, N, MetalBatchedAdjOrTrans{T}, MetalBatchedAdjOrTrans{T}}
+
+Base.print_array(io::IO, b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}) = Base.print_array(io, adapt(Array, b))
+Base._show_nonempty(io::IO, b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}, prefix::String) = Base._show_nonempty(io, adapt(Array, b), prefix)
+Base.show_vector(io::IO, b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}, opn, cls) = Base.show_vector(io, adapt(Array, b), opn, cls)
+
+Base.convert(::Type{T}, b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}) where {T<:Array} = Base.convert(T, adapt(Array, b))
+Base.Array{T, N}(b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}) where {T, N} = Array{T, N}(adapt(Array, b))
+Base.collect(b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}) = collect(adapt(Array, b))
+
+
+end # module NNlibMetalExt
diff --git a/test/ext_metal/batched_mul.jl b/test/ext_metal/batched_mul.jl
new file mode 100644
index 000000000..a26a748c2
--- /dev/null
+++ b/test/ext_metal/batched_mul.jl
@@ -0,0 +1,82 @@
+@testset "batched_mul" begin
+    using NNlib: batched_mul, batched_mul!, batched_vec,
+        batched_adjoint, batched_transpose
+
+    A = randn(Float32, 3,3,2);
+    B = randn(Float32, 3,3,2);
+
+    C = batched_mul(A, B)
+    @test MtlArray(C) ≈ batched_mul(MtlArray(A), MtlArray(B))
+
+    Ct = batched_mul(batched_transpose(A), B)
+    @test MtlArray(Ct) ≈ batched_mul(batched_transpose(MtlArray(A)), MtlArray(B))
+
+    Ca = batched_mul(A, batched_adjoint(B))
+    @test MtlArray(Ca) ≈ batched_mul(MtlArray(A), batched_adjoint(MtlArray(B)))
+
+    # 5-arg batched_mul!
+    C .= pi
+    batched_mul!(C, A, B, 2f0, 3f0)
+    gpuCpi = MtlArray(similar(C)) .= pi
+    @test MtlArray(C) ≈ batched_mul!(gpuCpi, MtlArray(A), MtlArray(B), 2f0, 3f0)
+
+    # PermutedDimsArray
+    @test MtlArray(Ct) ≈ batched_mul(PermutedDimsArray(MtlArray(A), (2,1,3)), MtlArray(B))
+
+    D = permutedims(B, (1,3,2))
+    Cp = batched_mul(batched_adjoint(A), B)
+    @test_broken MtlArray(Cp) ≈ batched_mul(batched_adjoint(MtlArray(A)), PermutedDimsArray(MtlArray(D), (1,3,2)))
+
+    # Methods which reshape
+    M = randn(Float32, 3,3)
+
+    Cm = batched_mul(A, M)
+    @test MtlArray(Cm) ≈ batched_mul(MtlArray(A), MtlArray(M))
+
+    Cv = batched_vec(permutedims(A,(3,1,2)), M)
+    @test_broken MtlArray(Cv) ≈ batched_vec(PermutedDimsArray(MtlArray(A),(3,1,2)), MtlArray(M))
+end
+
+function print_array_strs(x)
+    str = sprint((io, x)->show(io, MIME"text/plain"(), x), x)
+    return @view split(str, '\n')[2:end]
+end
+
+@testset "BatchedAdjOrTrans" begin
+    x = rand(Float32, 3, 4, 2)
+    y = MtlArray(x)
+
+    bax = batched_adjoint(x)
+    btx = batched_transpose(x)
+    bay = batched_adjoint(y)
+    bty = batched_transpose(y)
+
+    @test sprint(show, bax) == sprint(show, bay)
+    @test sprint(show, btx) == sprint(show, bty)
+
+    @test print_array_strs(bax) == print_array_strs(bay)
+    @test print_array_strs(btx) == print_array_strs(bty)
+
+    @test Array(bax) == Array(bay)
+    @test collect(bax) == collect(bay)
+    @test Array(btx) == Array(bty)
+    @test collect(btx) == collect(bty)
+
+    for shape in (:, (12, 2))
+        rbax = reshape(bax, shape)
+        rbtx = reshape(btx, shape)
+        rbay = reshape(bay, shape)
+        rbty = reshape(bty, shape)
+
+        @test sprint(show, rbax) == sprint(show, rbay)
+        @test sprint(show, rbtx) == sprint(show, rbty)
+
+        @test print_array_strs(rbax) == print_array_strs(rbay)
+        @test print_array_strs(rbtx) == print_array_strs(rbty)
+
+        @test Array(rbax) == Array(rbay)
+        @test collect(rbax) == collect(rbay)
+        @test Array(rbtx) == Array(rbty)
+        @test collect(rbtx) == collect(rbty)
+    end
+end
diff --git a/test/ext_metal/runtests.jl b/test/ext_metal/runtests.jl
new file mode 100644
index 000000000..bd19ee074
--- /dev/null
+++ b/test/ext_metal/runtests.jl
@@ -0,0 +1,6 @@
+
+Metal.allowscalar(false)
+
+@testset "Batched multiplication" begin
+    include("batched_mul.jl")
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 8ceafe405..26a6d8d43 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -23,6 +23,7 @@ DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursiv
 
 # ENV["NNLIB_TEST_CUDA"] = "true" # uncomment to run CUDA tests
 # ENV["NNLIB_TEST_AMDGPU"] = "true" # uncomment to run AMDGPU tests
+# ENV["NNLIB_TEST_METAL"] = "true" # uncomment to run Metal tests
 # ENV["NNLIB_TEST_CPU"] = "false" # uncomment to skip CPU tests
 
 const rng = StableRNG(123)
@@ -174,4 +175,21 @@ end
     else
         @info "Skipping AMDGPU tests, set NNLIB_TEST_AMDGPU=true to run them."
     end
+
+    if get(ENV, "NNLIB_TEST_METAL", "false") == "true"
+        Pkg.add(["Metal"])
+
+        using Metal
+        if Metal.functional()
+            @testset "Metal" begin
+                # nnlib_testsuite(MetalBackend; skip_tests=Set(("Scatter", "Gather")))
+
+                include("ext_metal/runtests.jl")
+            end
+        else
+            @info "Metal.jl package is not functional. Skipping Metal tests."
+        end
+    else
+        @info "Skipping Metal tests, set NNLIB_TEST_METAL=true to run them."
+    end
 end
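
For reference, a minimal usage sketch of what this extension enables. It assumes
an Apple-silicon Mac where Metal.functional() returns true; batched_mul and
batched_transpose are NNlib's public API, and the MtlArray path dispatches to
the NNlib._batched_gemm! method added above:

    using Metal, NNlib

    A = MtlArray(randn(Float32, 3, 4, 2))  # batch of two 3x4 matrices
    B = MtlArray(randn(Float32, 4, 5, 2))  # batch of two 4x5 matrices

    C = batched_mul(A, B)                      # 3x5x2, computed by Metal.MPS.matmul!
    Ct = batched_mul(batched_transpose(A), C)  # transA == 'T' is forwarded to MPS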
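
The _rng_from_array / _rng_compat_array methods exist so that NNlib's
random-number-using functions pick Metal's own generator for MtlArrays. A sketch
of the intended behaviour, assuming dropout (which accepts an optional leading
RNG argument) is the relevant consumer of these hooks:

    using Metal, NNlib

    x = MtlArray(ones(Float32, 4, 4))
    y = NNlib.dropout(x, 0.5f0)  # RNG picked via _rng_from_array, i.e. Metal.MPS.default_rng()
    z = NNlib.dropout(Metal.MPS.default_rng(), x, 0.5f0)  # Metal's RNG may also be passed explicitly
    # passing any other AbstractRNG should hit the ArgumentError defined above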
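
To run the new tests locally rather than on Buildkite: the gate in
test/runtests.jl reads environment variables, so from a checkout of NNlib on
macOS something like this should work (the test subprocess inherits ENV):

    ENV["NNLIB_TEST_METAL"] = "true"  # opt in to the Metal tests
    ENV["NNLIB_TEST_CPU"] = "false"   # optionally skip the CPU tests
    import Pkg
    Pkg.test("NNlib")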