diff --git a/Project.toml b/Project.toml
index ff6661eb..ae05f972 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "StaticArrays"
 uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "1.9.4"
+version = "1.9.5"
 
 [deps]
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl
index 0e4b1afe..c4a0437e 100644
--- a/src/StaticArrays.jl
+++ b/src/StaticArrays.jl
@@ -133,6 +133,10 @@ include("flatten.jl")
 include("io.jl")
 include("pinv.jl")
 
+@static if VERSION >= v"1.7"
+    include("blas.jl")
+end
+
 @static if !isdefined(Base, :get_extension) # VERSION < v"1.9-"
     include("../ext/StaticArraysStatisticsExt.jl")
 end
diff --git a/src/blas.jl b/src/blas.jl
new file mode 100644
index 00000000..e019c2a9
--- /dev/null
+++ b/src/blas.jl
@@ -0,0 +1,141 @@
+# This file contains functions that use a direct interface to the BLAS library. We use
+# this approach to reduce allocations.
+
+import LinearAlgebra: BLAS, LAPACK, libblastrampoline
+
+# == Singular Value Decomposition ==========================================================
+
+# Implement direct calls to the BLAS functions that compute the SVD values for `SMatrix`
+# and `MMatrix`, reducing allocations. In this case, we use an `MMatrix` to call the
+# library and convert the result back to the input type. Since the former does not escape
+# this scope, we can reduce allocations.
+#
+# We are implementing here the following functions:
+#
+#   svdvals(A::SMatrix{M, N, Float64}) where {M, N}
+#   svdvals(A::SMatrix{M, N, Float32}) where {M, N}
+#   svdvals(A::MMatrix{M, N, Float64}) where {M, N}
+#   svdvals(A::MMatrix{M, N, Float32}) where {M, N}
+#
+for (gesdd, elty) in ((:dgesdd_, :Float64), (:sgesdd_, :Float32)),
+    (mtype, vtype) in ((SMatrix, SVector), (MMatrix, MVector))
+
+    blas_func = @eval BLAS.@blasfunc($gesdd)
+
+    @eval begin
+        function svdvals(A::$mtype{M, N, $elty}) where {M, N}
+            K = min(M, N)
+
+            # Convert the input to an `MMatrix` and allocate the required arrays.
+            Am = MMatrix{M, N, $elty}(A)
+            Sm = MVector{K, $elty}(undef)
+
+            # We compute `lwork` (the size of the work array) by taking the maximum value
+            # among the possibilities shown in:
+            #   https://docs.oracle.com/cd/E19422-01/819-3691/dgesdd.html
+            lwork = max(8N, 3N + max(M, 7N), 8M, 3M + max(N, 7M))
+            work  = MVector{lwork, $elty}(undef)
+            iwork = MVector{8min(M, N), BLAS.BlasInt}(undef)
+            info  = Ref(1)
+
+            @ccall libblastrampoline.$blas_func(
+                'N'::Ref{UInt8},
+                M::Ref{BLAS.BlasInt},
+                N::Ref{BLAS.BlasInt},
+                Am::Ptr{$elty},
+                M::Ref{BLAS.BlasInt},
+                Sm::Ptr{$elty},
+                C_NULL::Ptr{C_NULL},
+                M::Ref{BLAS.BlasInt},
+                C_NULL::Ptr{C_NULL},
+                K::Ref{BLAS.BlasInt},
+                work::Ptr{$elty},
+                lwork::Ref{BLAS.BlasInt},
+                iwork::Ptr{BLAS.BlasInt},
+                info::Ptr{BLAS.BlasInt},
+                1::Clong
+            )::Cvoid
+
+            # Check the return code of the LAPACK call.
+            LAPACK.chklapackerror(info.x)
+
+            # Convert the vector to a static array and return.
+            S = $vtype{K, $elty}(Sm)
+
+            return S
+        end
+    end
+end
+
+# For matrices with integer elements, we promote them to float and call `svdvals`.
+@inline svdvals(A::StaticMatrix{<: Any, <: Any, <: Integer}) = svdvals(float(A))
+
+# Implement direct calls to the BLAS functions that compute the SVD for `SMatrix` and
+# `MMatrix`, reducing allocations. In this case, we use an `MMatrix` to call the library
+# and convert the result back to the input type. Since the former does not escape this
+# scope, we can reduce allocations.
+#
+# We are implementing here the following functions:
+#
+#   _svd(A::SMatrix{M, N, Float64}, full::Val{false}) where {M, N}
+#   _svd(A::SMatrix{M, N, Float64}, full::Val{true}) where {M, N}
+#   _svd(A::SMatrix{M, N, Float32}, full::Val{false}) where {M, N}
+#   _svd(A::SMatrix{M, N, Float32}, full::Val{true}) where {M, N}
+#   _svd(A::MMatrix{M, N, Float64}, full::Val{false}) where {M, N}
+#   _svd(A::MMatrix{M, N, Float64}, full::Val{true}) where {M, N}
+#   _svd(A::MMatrix{M, N, Float32}, full::Val{false}) where {M, N}
+#   _svd(A::MMatrix{M, N, Float32}, full::Val{true}) where {M, N}
+#
+for (gesvd, elty) in ((:dgesvd_, :Float64), (:sgesvd_, :Float32)),
+    full in (false, true),
+    (mtype, vtype) in ((SMatrix, SVector), (MMatrix, MVector))
+
+    blas_func = @eval BLAS.@blasfunc($gesvd)
+
+    @eval begin
+        function _svd(A::$mtype{M, N, $elty}, full::Val{$full}) where {M, N}
+            K = min(M, N)
+
+            # Convert the input to an `MMatrix` and allocate the required arrays.
+            Am  = MMatrix{M, N, $elty}(A)
+            Um  = MMatrix{M, $(full ? :M : :K), $elty}(undef)
+            Sm  = MVector{K, $elty}(undef)
+            Vtm = MMatrix{$(full ? :N : :K), N, $elty}(undef)
+            lwork = max(3min(M, N) + max(M, N), 5min(M, N))
+            work  = MVector{lwork, $elty}(undef)
+            info  = Ref(1)
+
+            @ccall libblastrampoline.$blas_func(
+                $(full ? 'A' : 'S')::Ref{UInt8},
+                $(full ? 'A' : 'S')::Ref{UInt8},
+                M::Ref{BLAS.BlasInt},
+                N::Ref{BLAS.BlasInt},
+                Am::Ptr{$elty},
+                M::Ref{BLAS.BlasInt},
+                Sm::Ptr{$elty},
+                Um::Ptr{$elty},
+                M::Ref{BLAS.BlasInt},
+                Vtm::Ptr{$elty},
+                $(full ? :N : :K)::Ref{BLAS.BlasInt},
+                work::Ptr{$elty},
+                lwork::Ref{BLAS.BlasInt},
+                info::Ptr{BLAS.BlasInt},
+                1::Clong,
+                1::Clong
+            )::Cvoid
+
+            # Check the return code of the LAPACK call.
+            LAPACK.chklapackerror(info.x)
+
+            # Convert the matrices to the correct types and return.
+            U  = $mtype{M, $(full ? :M : :K), $elty}(Um)
+            S  = $vtype{K, $elty}(Sm)
+            Vt = $mtype{$(full ? :N : :K), N, $elty}(Vtm)
+
+            return SVD(U, S, Vt)
+        end
+    end
+end
+
+# For matrices with integer elements, we promote them to float and call `svd`.
+@inline svd(A::StaticMatrix{<: Any, <: Any, <: Integer}) = svd(float(A))
diff --git a/src/svd.jl b/src/svd.jl
index d2af6e90..b8d1e0e2 100644
--- a/src/svd.jl
+++ b/src/svd.jl
@@ -73,3 +73,4 @@ function diagmult(sd, sB, d, B)
     ind = SOneTo(sd[1])
     return isa(B, AbstractVector) ? Diagonal(d)*B[ind] : Diagonal(d)*B[ind,:]
 end
+
diff --git a/test/svd.jl b/test/svd.jl
index 0df776f0..d377dc3e 100644
--- a/test/svd.jl
+++ b/test/svd.jl
@@ -2,8 +2,10 @@ using StaticArrays, Test, LinearAlgebra
 
 @testset "SVD factorization" begin
     m3 = @SMatrix Float64[3 9 4; 6 6 2; 3 7 9]
+    m3_f32 = @SMatrix Float32[3 9 4; 6 6 2; 3 7 9]
     m3c = ComplexF64.(m3)
     m23 = @SMatrix Float64[3 9 4; 6 6 2]
+    m23_f32 = @SMatrix Float32[3 9 4; 6 6 2]
     m_sing = @SMatrix [2.0 3.0 5.0; 4.0 6.0 10.0; 1.0 1.0 1.0]
     m_sing2 = @SMatrix [1 1; 1 0; 0 1]
     v = @SVector [1, 2, 3]
@@ -18,6 +20,7 @@ using StaticArrays, Test, LinearAlgebra
 
     @testinf svdvals((@SMatrix [2 -2; 1 1]) / sqrt(2)) ≊ [2, 1]
     @testinf svdvals(m3) ≊ svdvals(Matrix(m3))
+    @testinf svdvals(m3_f32) ≊ svdvals(Matrix(m3_f32))
    @testinf svdvals(m3c) isa SVector{3,Float64}
 
     @testinf svd(m3).U::StaticMatrix ≊ svd(Matrix(m3)).U
@@ -25,9 +28,14 @@ using StaticArrays, Test, LinearAlgebra
     @testinf svd(m3).V::StaticMatrix ≊ svd(Matrix(m3)).V
     @testinf svd(m3).Vt::StaticMatrix ≊ svd(Matrix(m3)).Vt
 
-    @testinf svd(@SMatrix [2 0; 0 0]).U === one(SMatrix{2,2})
-    @testinf svd(@SMatrix [2 0; 0 0]).S === SVector(2.0, 0.0)
-    @testinf svd(@SMatrix [2 0; 0 0]).Vt === one(SMatrix{2,2})
+    @test svd(m3_f32).U::StaticMatrix ≈ svd(Matrix(m3_f32)).U atol = 5e-7
+    @test svd(m3_f32).S::StaticVector ≈ svd(Matrix(m3_f32)).S atol = 5e-7
+    @test svd(m3_f32).V::StaticMatrix ≈ svd(Matrix(m3_f32)).V atol = 5e-7
+    @test svd(m3_f32).Vt::StaticMatrix ≈ svd(Matrix(m3_f32)).Vt atol = 5e-7
+
+    @testinf svd(@SMatrix [2 0; 0 0]).U ≊ one(SMatrix{2,2})
+    @testinf svd(@SMatrix [2 0; 0 0]).S ≊ SVector(2.0, 0.0)
+    @testinf svd(@SMatrix [2 0; 0 0]).Vt ≊ one(SMatrix{2,2})
 
     @testinf svd((@SMatrix [2 -2; 1 1]) / sqrt(2)).U ≊ [-1 0; 0 1]
     @testinf svd((@SMatrix [2 -2; 1 1]) / sqrt(2)).S ≊ [2, 1]
@@ -41,6 +49,16 @@ using StaticArrays, Test, LinearAlgebra
     @testinf svd(m23').S ≊ svd(Matrix(m23')).S
     @testinf svd(m23').Vt ≊ svd(Matrix(m23')).Vt
 
+    @test svd(m23_f32).U::StaticMatrix ≈ svd(Matrix(m23_f32)).U atol = 5e-7
+    @test svd(m23_f32).S::StaticVector ≈ svd(Matrix(m23_f32)).S atol = 5e-7
+    @test svd(m23_f32).V::StaticMatrix ≈ svd(Matrix(m23_f32)).V atol = 5e-7
+    @test svd(m23_f32).Vt::StaticMatrix ≈ svd(Matrix(m23_f32)).Vt atol = 5e-7
+
+    @test svd(m23_f32').U::StaticMatrix ≈ svd(Matrix(m23_f32')).U atol = 5e-7
+    @test svd(m23_f32').S::StaticVector ≈ svd(Matrix(m23_f32')).S atol = 5e-7
+    @test svd(m23_f32').V::StaticMatrix ≈ svd(Matrix(m23_f32')).V atol = 5e-7
+    @test svd(m23_f32').Vt::StaticMatrix ≈ svd(Matrix(m23_f32')).Vt atol = 5e-7
+
     @testinf svd(m23, full=true).U::StaticMatrix ≊ svd(Matrix(m23), full=true).U
     @testinf svd(m23, full=true).S::StaticVector ≊ svd(Matrix(m23), full=true).S
     @testinf svd(m23, full=true).Vt::StaticMatrix ≊ svd(Matrix(m23), full=true).Vt
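
Reviewer note (not part of the patch): a minimal sketch of how the new direct-BLAS paths added in src/blas.jl can be exercised and how the reduced-allocation claim can be spot-checked. It assumes Julia >= 1.7, a build of StaticArrays from this branch, and that BenchmarkTools is installed; the variable names below are illustrative only.

using StaticArrays, LinearAlgebra
using BenchmarkTools  # assumed to be installed; only used for the allocation check

A = @SMatrix Float64[3 9 4; 6 6 2; 3 7 9]
B = @SMatrix Float32[3 9 4; 6 6 2]

# `svdvals` and `svd` keep their usual interface; on Julia >= 1.7 these calls should now
# reach the gesdd_/gesvd_ wrappers defined in src/blas.jl.
svdvals(A)                      # SVector{3, Float64} of singular values
F = svd(B)                      # thin SVD with static U, S, and Vt
F.U * Diagonal(F.S) * F.Vt ≈ B  # reconstruction check

# Spot-check the allocation behavior against the generic dense path.
@btime svdvals($A)
@btime svdvals(Matrix($A))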