From 801742803a81a1e1bfe735472a1ca01d525360bd Mon Sep 17 00:00:00 2001 From: Ronan Arraes Jardim Chagas Date: Wed, 29 May 2024 17:03:57 -0300 Subject: [PATCH 1/5] Add direct calls to BLAS to compute SVD --- src/StaticArrays.jl | 2 +- src/svd.jl | 171 ++++++++++++++++++++++++++++++++++++++++++++ test/svd.jl | 24 ++++++- 3 files changed, 193 insertions(+), 4 deletions(-) diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index 0e4b1afe..c2f2c9ff 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -17,7 +17,7 @@ import LinearAlgebra: transpose, adjoint, dot, eigvals, eigen, lyap, tr, kron, diag, norm, dot, diagm, lu, svd, svdvals, pinv, factorize, ishermitian, issymmetric, isposdef, issuccess, normalize, normalize!, Eigen, det, logdet, logabsdet, cross, diff, qr, \ -using LinearAlgebra: checksquare +using LinearAlgebra: BLAS, checksquare, LAPACK, libblastrampoline using PrecompileTools diff --git a/src/svd.jl b/src/svd.jl index d2af6e90..f2767119 100644 --- a/src/svd.jl +++ b/src/svd.jl @@ -32,6 +32,88 @@ function svdvals(A::StaticMatrix) similar_type(A, T2, Size(diagsize(A)))(sv) end +# Implement direct call to BLAS functions that computes the SVD values for `SMatrix` and +# `MMatrix` reducing allocations. In this case, we use `MMatrix` to call the library and +# convert the result back to the input type. Since the former does not exit this scope, we +# can reduce allocations. +# +# We are implementing here the following functions: +# +# svdvals(A::SMatrix{M, N, Float64}) where {M, N} +# svdvals(A::SMatrix{M, N, Float32}) where {M, N} +# svdvals(A::MMatrix{M, N, Float64}) where {M, N} +# svdvals(A::MMatrix{M, N, Float32}) where {M, N} +# +for (gesdd, elty) in ((:dgesdd_, :Float64), (:sgesdd_, :Float32)), + (mtype, vtype) in ((SMatrix, SVector), (MMatrix, MVector)) + + @eval begin + function svdvals(A::$mtype{M, N, $elty}) where {M, N} + K = min(M, N) + + # Convert the input to a `MMatrix` and allocate the required arrays. + Am = MMatrix{M, N, $elty}(A) + Sm = MVector{K, $elty}(undef) + + # We compute the `lwork` (size of the work array) by obtaining the maximum value + # from the possibilities shown in: + # https://docs.oracle.com/cd/E19422-01/819-3691/dgesdd.html + lwork = max(8N, 3N + max(M, 7N), 8M, 3M + max(N, 7M)) + work = MVector{lwork, $elty}(undef) + iwork = MVector{8min(M, N), BLAS.BlasInt}(undef) + info = Ref(1) + + ccall( + (BLAS.@blasfunc($gesdd), libblastrampoline), + Cvoid, + ( + Ref{UInt8}, + Ref{BLAS.BlasInt}, + Ref{BLAS.BlasInt}, + Ptr{$elty}, + Ref{BLAS.BlasInt}, + Ptr{$elty}, + Ptr{C_NULL}, + Ref{BLAS.BlasInt}, + Ptr{C_NULL}, + Ref{BLAS.BlasInt}, + Ptr{$elty}, + Ref{BLAS.BlasInt}, + Ptr{BLAS.BlasInt}, + Ptr{BLAS.BlasInt}, + Clong + ), + 'N', + M, + N, + Am, + M, + Sm, + C_NULL, + M, + C_NULL, + K, + work, + lwork, + iwork, + info, + 1 + ) + + # Check if the return result of the function. + LAPACK.chklapackerror(info.x) + + # Convert the vector to static arrays and return. + S = $vtype{K, $elty}(Sm) + + return S + end + end +end + +# For matrices with interger numbers, we should promote them to float and call `svdvals`. +@inline svdvals(A::StaticMatrix{<: Any, <: Any, <: Integer}) = svdvals(float(A)) + # `@inline` annotation is required to propagate `full` as constant to `_svd` @inline svd(A::StaticMatrix; full=Val(false)) = _svd(A, full) @@ -56,6 +138,94 @@ function _svd(A, full::Val{true}) SVD(U,S,Vt) end +# Implement direct call to BLAS functions that computes the SVD for `SMatrix` and `MMatrix` +# reducing allocations. 
In this case, we use `MMatrix` to call the library and convert the +# result back to the input type. Since the former does not exit this scope, we can reduce +# allocations. +# +# We are implementing here the following functions: +# +# _svd(A::SMatrix{M, N, Float64}, full::Val{false}) where {M, N} +# _svd(A::SMatrix{M, N, Float64}, full::Val{true}) where {M, N} +# _svd(A::SMatrix{M, N, Float32}, full::Val{false}) where {M, N} +# _svd(A::SMatrix{M, N, Float32}, full::Val{true}) where {M, N} +# _svd(A::MMatrix{M, N, Float64}, full::Val{false}) where {M, N} +# _svd(A::MMatrix{M, N, Float64}, full::Val{true}) where {M, N} +# _svd(A::MMatrix{M, N, Float32}, full::Val{false}) where {M, N} +# _svd(A::MMatrix{M, N, Float32}, full::Val{true}) where {M, N} +# +for (gesvd, elty) in ((:dgesvd_, :Float64), (:sgesvd_, :Float32)), + full in (false, true), + (mtype, vtype) in ((SMatrix, SVector), (MMatrix, MVector)) + + @eval begin + function _svd(A::$mtype{M, N, $elty}, full::Val{$full}) where {M, N} + K = min(M, N) + + # Convert the input to a `MMatrix` and allocate the required arrays. + Am = MMatrix{M, N, $elty}(A) + Um = MMatrix{M, $(full ? :M : :K), $elty}(undef) + Sm = MVector{K, $elty}(undef) + Vtm = MMatrix{$(full ? :N : :K), N, $elty}(undef) + lwork = max(3min(M, N) + max(M, N), 5min(M, N)) + work = MVector{lwork, $elty}(undef) + info = Ref(1) + + ccall( + (BLAS.@blasfunc($gesvd), libblastrampoline), + Cvoid, + ( + Ref{UInt8}, + Ref{UInt8}, + Ref{BLAS.BlasInt}, + Ref{BLAS.BlasInt}, + Ptr{$elty}, + Ref{BLAS.BlasInt}, + Ptr{$elty}, + Ptr{$elty}, + Ref{BLAS.BlasInt}, + Ptr{$elty}, + Ref{BLAS.BlasInt}, + Ptr{$elty}, + Ref{BLAS.BlasInt}, + Ptr{BLAS.BlasInt}, + Clong, + Clong + ), + $(full ? 'A' : 'S'), + $(full ? 'A' : 'S'), + M, + N, + Am, + M, + Sm, + Um, + M, + Vtm, + $(full ? :N : :K), + work, + lwork, + info, + 1, + 1 + ) + + # Check if the return result of the function. + LAPACK.chklapackerror(info.x) + + # Convert the matrices to the correct type and return. + U = $mtype{M, $(full ? :M : :K), $elty}(Um) + S = $vtype{K, $elty}(Sm) + Vt = $mtype{$(full ? :N : :K), N, $elty}(Vtm) + + return SVD(U, S, Vt) + end + end +end + +# For matrices with interger numbers, we should promote them to float and call `svd`. +@inline svd(A::StaticMatrix{<: Any, <: Any, <: Integer}) = svd(float(A)) + function \(F::SVD, B::StaticVecOrMat) sthresh = eps(F.S[1]) Sinv = map(s->s < sthresh ? zero(1/sthresh) : 1/s, F.S) @@ -73,3 +243,4 @@ function diagmult(sd, sB, d, B) ind = SOneTo(sd[1]) return isa(B, AbstractVector) ? 
Diagonal(d)*B[ind] : Diagonal(d)*B[ind,:] end + diff --git a/test/svd.jl b/test/svd.jl index 0df776f0..d377dc3e 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -2,8 +2,10 @@ using StaticArrays, Test, LinearAlgebra @testset "SVD factorization" begin m3 = @SMatrix Float64[3 9 4; 6 6 2; 3 7 9] + m3_f32 = @SMatrix Float32[3 9 4; 6 6 2; 3 7 9] m3c = ComplexF64.(m3) m23 = @SMatrix Float64[3 9 4; 6 6 2] + m23_f32 = @SMatrix Float32[3 9 4; 6 6 2] m_sing = @SMatrix [2.0 3.0 5.0; 4.0 6.0 10.0; 1.0 1.0 1.0] m_sing2 = @SMatrix [1 1; 1 0; 0 1] v = @SVector [1, 2, 3] @@ -18,6 +20,7 @@ using StaticArrays, Test, LinearAlgebra @testinf svdvals((@SMatrix [2 -2; 1 1]) / sqrt(2)) ≊ [2, 1] @testinf svdvals(m3) ≊ svdvals(Matrix(m3)) + @testinf svdvals(m3_f32) ≊ svdvals(Matrix(m3_f32)) @testinf svdvals(m3c) isa SVector{3,Float64} @testinf svd(m3).U::StaticMatrix ≊ svd(Matrix(m3)).U @@ -25,9 +28,14 @@ using StaticArrays, Test, LinearAlgebra @testinf svd(m3).V::StaticMatrix ≊ svd(Matrix(m3)).V @testinf svd(m3).Vt::StaticMatrix ≊ svd(Matrix(m3)).Vt - @testinf svd(@SMatrix [2 0; 0 0]).U === one(SMatrix{2,2}) - @testinf svd(@SMatrix [2 0; 0 0]).S === SVector(2.0, 0.0) - @testinf svd(@SMatrix [2 0; 0 0]).Vt === one(SMatrix{2,2}) + @test svd(m3_f32).U::StaticMatrix ≈ svd(Matrix(m3_f32)).U atol = 5e-7 + @test svd(m3_f32).S::StaticVector ≈ svd(Matrix(m3_f32)).S atol = 5e-7 + @test svd(m3_f32).V::StaticMatrix ≈ svd(Matrix(m3_f32)).V atol = 5e-7 + @test svd(m3_f32).Vt::StaticMatrix ≈ svd(Matrix(m3_f32)).Vt atol = 5e-7 + + @testinf svd(@SMatrix [2 0; 0 0]).U ≊ one(SMatrix{2,2}) + @testinf svd(@SMatrix [2 0; 0 0]).S ≊ SVector(2.0, 0.0) + @testinf svd(@SMatrix [2 0; 0 0]).Vt ≊ one(SMatrix{2,2}) @testinf svd((@SMatrix [2 -2; 1 1]) / sqrt(2)).U ≊ [-1 0; 0 1] @testinf svd((@SMatrix [2 -2; 1 1]) / sqrt(2)).S ≊ [2, 1] @@ -41,6 +49,16 @@ using StaticArrays, Test, LinearAlgebra @testinf svd(m23').S ≊ svd(Matrix(m23')).S @testinf svd(m23').Vt ≊ svd(Matrix(m23')).Vt + @test svd(m23_f32).U::StaticMatrix ≈ svd(Matrix(m23_f32)).U atol = 5e-7 + @test svd(m23_f32).S::StaticVector ≈ svd(Matrix(m23_f32)).S atol = 5e-7 + @test svd(m23_f32).V::StaticMatrix ≈ svd(Matrix(m23_f32)).V atol = 5e-7 + @test svd(m23_f32).Vt::StaticMatrix ≈ svd(Matrix(m23_f32)).Vt atol = 5e-7 + + @test svd(m23_f32').U::StaticMatrix ≈ svd(Matrix(m23_f32')).U atol = 5e-7 + @test svd(m23_f32').S::StaticVector ≈ svd(Matrix(m23_f32')).S atol = 5e-7 + @test svd(m23_f32').V::StaticMatrix ≈ svd(Matrix(m23_f32')).V atol = 5e-7 + @test svd(m23_f32').Vt::StaticMatrix ≈ svd(Matrix(m23_f32')).Vt atol = 5e-7 + @testinf svd(m23, full=true).U::StaticMatrix ≊ svd(Matrix(m23), full=true).U @testinf svd(m23, full=true).S::StaticVector ≊ svd(Matrix(m23), full=true).S @testinf svd(m23, full=true).Vt::StaticMatrix ≊ svd(Matrix(m23), full=true).Vt From 3f563297902f4a999db16277ed06cbf560aa9c90 Mon Sep 17 00:00:00 2001 From: Ronan Arraes Jardim Chagas Date: Wed, 29 May 2024 18:01:32 -0300 Subject: [PATCH 2/5] Move funcs with direct BLAS interface to new file --- src/StaticArrays.jl | 1 + src/blas.jl | 174 ++++++++++++++++++++++++++++++++++++++++++++ src/svd.jl | 170 ------------------------------------------- 3 files changed, 175 insertions(+), 170 deletions(-) create mode 100644 src/blas.jl diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index c2f2c9ff..8f2d80d4 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -109,6 +109,7 @@ include("convert.jl") include("abstractarray.jl") include("indexing.jl") +include("blas.jl") include("broadcast.jl") include("mapreduce.jl") 
include("sort.jl") diff --git a/src/blas.jl b/src/blas.jl new file mode 100644 index 00000000..36c31389 --- /dev/null +++ b/src/blas.jl @@ -0,0 +1,174 @@ +# This file contains funtions that uses a direct interface to BLAS library. We use this +# approach to reduce allocations. + +# == Singular Value Decomposition ========================================================== + +# Implement direct call to BLAS functions that computes the SVD values for `SMatrix` and +# `MMatrix` reducing allocations. In this case, we use `MMatrix` to call the library and +# convert the result back to the input type. Since the former does not exit this scope, we +# can reduce allocations. +# +# We are implementing here the following functions: +# +# svdvals(A::SMatrix{M, N, Float64}) where {M, N} +# svdvals(A::SMatrix{M, N, Float32}) where {M, N} +# svdvals(A::MMatrix{M, N, Float64}) where {M, N} +# svdvals(A::MMatrix{M, N, Float32}) where {M, N} +# +for (gesdd, elty) in ((:dgesdd_, :Float64), (:sgesdd_, :Float32)), + (mtype, vtype) in ((SMatrix, SVector), (MMatrix, MVector)) + + @eval begin + function svdvals(A::$mtype{M, N, $elty}) where {M, N} + K = min(M, N) + + # Convert the input to a `MMatrix` and allocate the required arrays. + Am = MMatrix{M, N, $elty}(A) + Sm = MVector{K, $elty}(undef) + + # We compute the `lwork` (size of the work array) by obtaining the maximum value + # from the possibilities shown in: + # https://docs.oracle.com/cd/E19422-01/819-3691/dgesdd.html + lwork = max(8N, 3N + max(M, 7N), 8M, 3M + max(N, 7M)) + work = MVector{lwork, $elty}(undef) + iwork = MVector{8min(M, N), BLAS.BlasInt}(undef) + info = Ref(1) + + ccall( + (BLAS.@blasfunc($gesdd), libblastrampoline), + Cvoid, + ( + Ref{UInt8}, + Ref{BLAS.BlasInt}, + Ref{BLAS.BlasInt}, + Ptr{$elty}, + Ref{BLAS.BlasInt}, + Ptr{$elty}, + Ptr{C_NULL}, + Ref{BLAS.BlasInt}, + Ptr{C_NULL}, + Ref{BLAS.BlasInt}, + Ptr{$elty}, + Ref{BLAS.BlasInt}, + Ptr{BLAS.BlasInt}, + Ptr{BLAS.BlasInt}, + Clong + ), + 'N', + M, + N, + Am, + M, + Sm, + C_NULL, + M, + C_NULL, + K, + work, + lwork, + iwork, + info, + 1 + ) + + # Check if the return result of the function. + LAPACK.chklapackerror(info.x) + + # Convert the vector to static arrays and return. + S = $vtype{K, $elty}(Sm) + + return S + end + end +end + +# For matrices with interger numbers, we should promote them to float and call `svdvals`. +@inline svdvals(A::StaticMatrix{<: Any, <: Any, <: Integer}) = svdvals(float(A)) + +# Implement direct call to BLAS functions that computes the SVD for `SMatrix` and `MMatrix` +# reducing allocations. In this case, we use `MMatrix` to call the library and convert the +# result back to the input type. Since the former does not exit this scope, we can reduce +# allocations. 
+# +# We are implementing here the following functions: +# +# _svd(A::SMatrix{M, N, Float64}, full::Val{false}) where {M, N} +# _svd(A::SMatrix{M, N, Float64}, full::Val{true}) where {M, N} +# _svd(A::SMatrix{M, N, Float32}, full::Val{false}) where {M, N} +# _svd(A::SMatrix{M, N, Float32}, full::Val{true}) where {M, N} +# _svd(A::MMatrix{M, N, Float64}, full::Val{false}) where {M, N} +# _svd(A::MMatrix{M, N, Float64}, full::Val{true}) where {M, N} +# _svd(A::MMatrix{M, N, Float32}, full::Val{false}) where {M, N} +# _svd(A::MMatrix{M, N, Float32}, full::Val{true}) where {M, N} +# +for (gesvd, elty) in ((:dgesvd_, :Float64), (:sgesvd_, :Float32)), + full in (false, true), + (mtype, vtype) in ((SMatrix, SVector), (MMatrix, MVector)) + + @eval begin + function _svd(A::$mtype{M, N, $elty}, full::Val{$full}) where {M, N} + K = min(M, N) + + # Convert the input to a `MMatrix` and allocate the required arrays. + Am = MMatrix{M, N, $elty}(A) + Um = MMatrix{M, $(full ? :M : :K), $elty}(undef) + Sm = MVector{K, $elty}(undef) + Vtm = MMatrix{$(full ? :N : :K), N, $elty}(undef) + lwork = max(3min(M, N) + max(M, N), 5min(M, N)) + work = MVector{lwork, $elty}(undef) + info = Ref(1) + + ccall( + (BLAS.@blasfunc($gesvd), libblastrampoline), + Cvoid, + ( + Ref{UInt8}, + Ref{UInt8}, + Ref{BLAS.BlasInt}, + Ref{BLAS.BlasInt}, + Ptr{$elty}, + Ref{BLAS.BlasInt}, + Ptr{$elty}, + Ptr{$elty}, + Ref{BLAS.BlasInt}, + Ptr{$elty}, + Ref{BLAS.BlasInt}, + Ptr{$elty}, + Ref{BLAS.BlasInt}, + Ptr{BLAS.BlasInt}, + Clong, + Clong + ), + $(full ? 'A' : 'S'), + $(full ? 'A' : 'S'), + M, + N, + Am, + M, + Sm, + Um, + M, + Vtm, + $(full ? :N : :K), + work, + lwork, + info, + 1, + 1 + ) + + # Check if the return result of the function. + LAPACK.chklapackerror(info.x) + + # Convert the matrices to the correct type and return. + U = $mtype{M, $(full ? :M : :K), $elty}(Um) + S = $vtype{K, $elty}(Sm) + Vt = $mtype{$(full ? :N : :K), N, $elty}(Vtm) + + return SVD(U, S, Vt) + end + end +end + +# For matrices with interger numbers, we should promote them to float and call `svd`. +@inline svd(A::StaticMatrix{<: Any, <: Any, <: Integer}) = svd(float(A)) diff --git a/src/svd.jl b/src/svd.jl index f2767119..b8d1e0e2 100644 --- a/src/svd.jl +++ b/src/svd.jl @@ -32,88 +32,6 @@ function svdvals(A::StaticMatrix) similar_type(A, T2, Size(diagsize(A)))(sv) end -# Implement direct call to BLAS functions that computes the SVD values for `SMatrix` and -# `MMatrix` reducing allocations. In this case, we use `MMatrix` to call the library and -# convert the result back to the input type. Since the former does not exit this scope, we -# can reduce allocations. -# -# We are implementing here the following functions: -# -# svdvals(A::SMatrix{M, N, Float64}) where {M, N} -# svdvals(A::SMatrix{M, N, Float32}) where {M, N} -# svdvals(A::MMatrix{M, N, Float64}) where {M, N} -# svdvals(A::MMatrix{M, N, Float32}) where {M, N} -# -for (gesdd, elty) in ((:dgesdd_, :Float64), (:sgesdd_, :Float32)), - (mtype, vtype) in ((SMatrix, SVector), (MMatrix, MVector)) - - @eval begin - function svdvals(A::$mtype{M, N, $elty}) where {M, N} - K = min(M, N) - - # Convert the input to a `MMatrix` and allocate the required arrays. 
- Am = MMatrix{M, N, $elty}(A) - Sm = MVector{K, $elty}(undef) - - # We compute the `lwork` (size of the work array) by obtaining the maximum value - # from the possibilities shown in: - # https://docs.oracle.com/cd/E19422-01/819-3691/dgesdd.html - lwork = max(8N, 3N + max(M, 7N), 8M, 3M + max(N, 7M)) - work = MVector{lwork, $elty}(undef) - iwork = MVector{8min(M, N), BLAS.BlasInt}(undef) - info = Ref(1) - - ccall( - (BLAS.@blasfunc($gesdd), libblastrampoline), - Cvoid, - ( - Ref{UInt8}, - Ref{BLAS.BlasInt}, - Ref{BLAS.BlasInt}, - Ptr{$elty}, - Ref{BLAS.BlasInt}, - Ptr{$elty}, - Ptr{C_NULL}, - Ref{BLAS.BlasInt}, - Ptr{C_NULL}, - Ref{BLAS.BlasInt}, - Ptr{$elty}, - Ref{BLAS.BlasInt}, - Ptr{BLAS.BlasInt}, - Ptr{BLAS.BlasInt}, - Clong - ), - 'N', - M, - N, - Am, - M, - Sm, - C_NULL, - M, - C_NULL, - K, - work, - lwork, - iwork, - info, - 1 - ) - - # Check if the return result of the function. - LAPACK.chklapackerror(info.x) - - # Convert the vector to static arrays and return. - S = $vtype{K, $elty}(Sm) - - return S - end - end -end - -# For matrices with interger numbers, we should promote them to float and call `svdvals`. -@inline svdvals(A::StaticMatrix{<: Any, <: Any, <: Integer}) = svdvals(float(A)) - # `@inline` annotation is required to propagate `full` as constant to `_svd` @inline svd(A::StaticMatrix; full=Val(false)) = _svd(A, full) @@ -138,94 +56,6 @@ function _svd(A, full::Val{true}) SVD(U,S,Vt) end -# Implement direct call to BLAS functions that computes the SVD for `SMatrix` and `MMatrix` -# reducing allocations. In this case, we use `MMatrix` to call the library and convert the -# result back to the input type. Since the former does not exit this scope, we can reduce -# allocations. -# -# We are implementing here the following functions: -# -# _svd(A::SMatrix{M, N, Float64}, full::Val{false}) where {M, N} -# _svd(A::SMatrix{M, N, Float64}, full::Val{true}) where {M, N} -# _svd(A::SMatrix{M, N, Float32}, full::Val{false}) where {M, N} -# _svd(A::SMatrix{M, N, Float32}, full::Val{true}) where {M, N} -# _svd(A::MMatrix{M, N, Float64}, full::Val{false}) where {M, N} -# _svd(A::MMatrix{M, N, Float64}, full::Val{true}) where {M, N} -# _svd(A::MMatrix{M, N, Float32}, full::Val{false}) where {M, N} -# _svd(A::MMatrix{M, N, Float32}, full::Val{true}) where {M, N} -# -for (gesvd, elty) in ((:dgesvd_, :Float64), (:sgesvd_, :Float32)), - full in (false, true), - (mtype, vtype) in ((SMatrix, SVector), (MMatrix, MVector)) - - @eval begin - function _svd(A::$mtype{M, N, $elty}, full::Val{$full}) where {M, N} - K = min(M, N) - - # Convert the input to a `MMatrix` and allocate the required arrays. - Am = MMatrix{M, N, $elty}(A) - Um = MMatrix{M, $(full ? :M : :K), $elty}(undef) - Sm = MVector{K, $elty}(undef) - Vtm = MMatrix{$(full ? :N : :K), N, $elty}(undef) - lwork = max(3min(M, N) + max(M, N), 5min(M, N)) - work = MVector{lwork, $elty}(undef) - info = Ref(1) - - ccall( - (BLAS.@blasfunc($gesvd), libblastrampoline), - Cvoid, - ( - Ref{UInt8}, - Ref{UInt8}, - Ref{BLAS.BlasInt}, - Ref{BLAS.BlasInt}, - Ptr{$elty}, - Ref{BLAS.BlasInt}, - Ptr{$elty}, - Ptr{$elty}, - Ref{BLAS.BlasInt}, - Ptr{$elty}, - Ref{BLAS.BlasInt}, - Ptr{$elty}, - Ref{BLAS.BlasInt}, - Ptr{BLAS.BlasInt}, - Clong, - Clong - ), - $(full ? 'A' : 'S'), - $(full ? 'A' : 'S'), - M, - N, - Am, - M, - Sm, - Um, - M, - Vtm, - $(full ? :N : :K), - work, - lwork, - info, - 1, - 1 - ) - - # Check if the return result of the function. - LAPACK.chklapackerror(info.x) - - # Convert the matrices to the correct type and return. 
- U = $mtype{M, $(full ? :M : :K), $elty}(Um) - S = $vtype{K, $elty}(Sm) - Vt = $mtype{$(full ? :N : :K), N, $elty}(Vtm) - - return SVD(U, S, Vt) - end - end -end - -# For matrices with interger numbers, we should promote them to float and call `svd`. -@inline svd(A::StaticMatrix{<: Any, <: Any, <: Integer}) = svd(float(A)) - function \(F::SVD, B::StaticVecOrMat) sthresh = eps(F.S[1]) Sinv = map(s->s < sthresh ? zero(1/sthresh) : 1/s, F.S) From 97563978a48cf0f0c331fefc92db953c1bc95368 Mon Sep 17 00:00:00 2001 From: Ronan Arraes Jardim Chagas Date: Wed, 29 May 2024 19:49:08 -0300 Subject: [PATCH 3/5] Use @ccall --- src/blas.jl | 113 ++++++++++++++++++---------------------------------- 1 file changed, 39 insertions(+), 74 deletions(-) diff --git a/src/blas.jl b/src/blas.jl index 36c31389..5272f946 100644 --- a/src/blas.jl +++ b/src/blas.jl @@ -18,6 +18,8 @@ for (gesdd, elty) in ((:dgesdd_, :Float64), (:sgesdd_, :Float32)), (mtype, vtype) in ((SMatrix, SVector), (MMatrix, MVector)) + blas_func = @eval BLAS.@blasfunc($gesdd) + @eval begin function svdvals(A::$mtype{M, N, $elty}) where {M, N} K = min(M, N) @@ -34,42 +36,23 @@ for (gesdd, elty) in ((:dgesdd_, :Float64), (:sgesdd_, :Float32)), iwork = MVector{8min(M, N), BLAS.BlasInt}(undef) info = Ref(1) - ccall( - (BLAS.@blasfunc($gesdd), libblastrampoline), - Cvoid, - ( - Ref{UInt8}, - Ref{BLAS.BlasInt}, - Ref{BLAS.BlasInt}, - Ptr{$elty}, - Ref{BLAS.BlasInt}, - Ptr{$elty}, - Ptr{C_NULL}, - Ref{BLAS.BlasInt}, - Ptr{C_NULL}, - Ref{BLAS.BlasInt}, - Ptr{$elty}, - Ref{BLAS.BlasInt}, - Ptr{BLAS.BlasInt}, - Ptr{BLAS.BlasInt}, - Clong - ), - 'N', - M, - N, - Am, - M, - Sm, - C_NULL, - M, - C_NULL, - K, - work, - lwork, - iwork, - info, - 1 - ) + @ccall libblastrampoline.$blas_func( + 'N'::Ref{UInt8}, + M::Ref{BLAS.BlasInt}, + N::Ref{BLAS.BlasInt}, + Am::Ptr{$elty}, + M::Ref{BLAS.BlasInt}, + Sm::Ptr{$elty}, + C_NULL::Ptr{C_NULL}, + M::Ref{BLAS.BlasInt}, + C_NULL::Ptr{C_NULL}, + K::Ref{BLAS.BlasInt}, + work::Ptr{$elty}, + lwork::Ref{BLAS.BlasInt}, + iwork::Ptr{BLAS.BlasInt}, + info::Ptr{BLAS.BlasInt}, + 1::Clong + )::Cvoid # Check if the return result of the function. LAPACK.chklapackerror(info.x) @@ -105,6 +88,8 @@ for (gesvd, elty) in ((:dgesvd_, :Float64), (:sgesvd_, :Float32)), full in (false, true), (mtype, vtype) in ((SMatrix, SVector), (MMatrix, MVector)) + blas_func = @eval BLAS.@blasfunc($gesvd) + @eval begin function _svd(A::$mtype{M, N, $elty}, full::Val{$full}) where {M, N} K = min(M, N) @@ -118,44 +103,24 @@ for (gesvd, elty) in ((:dgesvd_, :Float64), (:sgesvd_, :Float32)), work = MVector{lwork, $elty}(undef) info = Ref(1) - ccall( - (BLAS.@blasfunc($gesvd), libblastrampoline), - Cvoid, - ( - Ref{UInt8}, - Ref{UInt8}, - Ref{BLAS.BlasInt}, - Ref{BLAS.BlasInt}, - Ptr{$elty}, - Ref{BLAS.BlasInt}, - Ptr{$elty}, - Ptr{$elty}, - Ref{BLAS.BlasInt}, - Ptr{$elty}, - Ref{BLAS.BlasInt}, - Ptr{$elty}, - Ref{BLAS.BlasInt}, - Ptr{BLAS.BlasInt}, - Clong, - Clong - ), - $(full ? 'A' : 'S'), - $(full ? 'A' : 'S'), - M, - N, - Am, - M, - Sm, - Um, - M, - Vtm, - $(full ? :N : :K), - work, - lwork, - info, - 1, - 1 - ) + @ccall libblastrampoline.$blas_func( + $(full ? 'A' : 'S')::Ref{UInt8}, + $(full ? 'A' : 'S')::Ref{UInt8}, + M::Ref{BLAS.BlasInt}, + N::Ref{BLAS.BlasInt}, + Am::Ptr{$elty}, + M::Ref{BLAS.BlasInt}, + Sm::Ptr{$elty}, + Um::Ptr{$elty}, + M::Ref{BLAS.BlasInt}, + Vtm::Ptr{$elty}, + $(full ? 
:N : :K)::Ref{BLAS.BlasInt}, + work::Ptr{$elty}, + lwork::Ref{BLAS.BlasInt}, + info::Ptr{BLAS.BlasInt}, + 1::Clong, + 1::Clong + )::Cvoid # Check if the return result of the function. LAPACK.chklapackerror(info.x) From 3f324e64e797628fc11ce2ae1e3fc6aa2de14876 Mon Sep 17 00:00:00 2001 From: Ronan Arraes Jardim Chagas Date: Wed, 5 Jun 2024 10:55:12 -0300 Subject: [PATCH 4/5] Restrict BLAS interface to Julia v1.7 or higher --- src/StaticArrays.jl | 7 +++++-- src/blas.jl | 2 ++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index 8f2d80d4..c4a0437e 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -17,7 +17,7 @@ import LinearAlgebra: transpose, adjoint, dot, eigvals, eigen, lyap, tr, kron, diag, norm, dot, diagm, lu, svd, svdvals, pinv, factorize, ishermitian, issymmetric, isposdef, issuccess, normalize, normalize!, Eigen, det, logdet, logabsdet, cross, diff, qr, \ -using LinearAlgebra: BLAS, checksquare, LAPACK, libblastrampoline +using LinearAlgebra: checksquare using PrecompileTools @@ -109,7 +109,6 @@ include("convert.jl") include("abstractarray.jl") include("indexing.jl") -include("blas.jl") include("broadcast.jl") include("mapreduce.jl") include("sort.jl") @@ -134,6 +133,10 @@ include("flatten.jl") include("io.jl") include("pinv.jl") +@static if VERSION >= v"1.7" + include("blas.jl") +end + @static if !isdefined(Base, :get_extension) # VERSION < v"1.9-" include("../ext/StaticArraysStatisticsExt.jl") end diff --git a/src/blas.jl b/src/blas.jl index 5272f946..e019c2a9 100644 --- a/src/blas.jl +++ b/src/blas.jl @@ -1,6 +1,8 @@ # This file contains funtions that uses a direct interface to BLAS library. We use this # approach to reduce allocations. +import LinearAlgebra: BLAS, LAPACK, libblastrampoline + # == Singular Value Decomposition ========================================================== # Implement direct call to BLAS functions that computes the SVD values for `SMatrix` and From 0d7ad04ab5d190b65bc372bada6b0b29d8c8e085 Mon Sep 17 00:00:00 2001 From: Ronan Arraes Jardim Chagas Date: Thu, 6 Jun 2024 09:56:31 -0300 Subject: [PATCH 5/5] Bump version to 1.9.5 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ff6661eb..ae05f972 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "StaticArrays" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.4" +version = "1.9.5" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
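
Illustrative usage sketch (not part of the patch series): the snippet below exercises the Float64/Float32 `svdvals`/`svd` specializations added above and checks the allocation behaviour they are meant to improve. The example matrices mirror `m3` and `m23_f32` from the test diff; the BenchmarkTools import and the `@ballocated` check are assumptions added for illustration, not something the patches provide.

    # Illustrative only: exercise the BLAS-backed SVD methods added by this patch series.
    using StaticArrays, LinearAlgebra, BenchmarkTools

    A64 = @SMatrix Float64[3 9 4; 6 6 2; 3 7 9]    # 3x3, hits the dgesdd_/dgesvd_ paths
    A32 = @SMatrix Float32[3 9 4; 6 6 2]           # 2x3, hits the sgesdd_/sgesvd_ paths

    s = svdvals(A64)                     # SVector{3, Float64}, via gesdd with JOBZ = 'N'

    F_thin = svd(A32)                    # U is 2x2, Vt is 2x3 (JOBU = JOBVT = 'S')
    F_full = svd(A32; full = Val(true))  # U is 2x2, Vt is 3x3 (JOBU = JOBVT = 'A')

    # The factorization should reconstruct the input.
    @assert F_thin.U * Diagonal(F_thin.S) * F_thin.Vt ≈ A32

    # Bytes allocated per call; the point of the patch is to keep this small.
    @ballocated svdvals($A64)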
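
A second illustrative sketch, also outside the patches: the workspace sizes that the new methods allocate, written out as plain functions for a concrete case. The formulas are copied from the patch comments (the gesdd sizes follow the reference linked there); the helper names are hypothetical.

    # Workspace sizes used by the patched methods, evaluated for a concrete M x N input.
    gesdd_lwork(M, N)  = max(8N, 3N + max(M, 7N), 8M, 3M + max(N, 7M))  # svdvals path (gesdd, JOBZ = 'N')
    gesdd_liwork(M, N) = 8min(M, N)                                     # integer workspace for gesdd
    gesvd_lwork(M, N)  = max(3min(M, N) + max(M, N), 5min(M, N))        # svd path (gesvd)

    # For a 3 x 4 input: gesdd uses a 40-element work array and 24 BlasInt of iwork,
    # while gesvd needs a 15-element work array.
    gesdd_lwork(3, 4), gesdd_liwork(3, 4), gesvd_lwork(3, 4)   # (40, 24, 15)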