Skip to content

Commit

Permalink
Use a mutable copy of input if inplace scaling is required (#243)
Browse files Browse the repository at this point in the history
* Use a mutable copy of input if inplace scaling is required

* Convert to Array instead of using similar

* Add test

* Fix type-signature of _plan_mul
  • Loading branch information
jishnub authored May 7, 2024
1 parent 2827377 commit da3e865
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/chebyshevtransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ end


# convert x if necessary
@inline _plan_mul!(y::AbstractArray{T}, P::Plan{T}, x::StridedArray{T}) where T = mul!(y, P, x)
@inline _plan_mul!(y::AbstractArray{T}, P::Plan{T}, x::AbstractArray) where T = mul!(y, P, convert(Array{T}, x))
_maybemutablecopy(x::StridedArray{T}, ::Type{T}) where {T} = x
_maybemutablecopy(x, T) = Array{T}(x)
@inline _plan_mul!(y::AbstractArray{T}, P::Plan{T}, x::AbstractArray) where T = mul!(y, P, _maybemutablecopy(x, T))


for op in (:ldiv, :lmul)
Expand Down Expand Up @@ -309,7 +310,8 @@ function mul!(y::AbstractArray{T,N}, P::IChebyshevTransformPlan{T,2,K,false,N},
_icheb2_rescale!(P.plan.region, y)
end

*(P::IChebyshevTransformPlan{T,kind,K,false,N}, x::AbstractArray{T,N}) where {T,kind,K,N} = mul!(similar(x), P, x)
*(P::IChebyshevTransformPlan{T,kind,K,false,N}, x::AbstractArray{T,N}) where {T,kind,K,N} =
mul!(similar(x), P, _maybemutablecopy(x, T))
ichebyshevtransform!(x::AbstractArray, dims...; kwds...) = plan_ichebyshevtransform!(x, dims...; kwds...)*x
ichebyshevtransform(x, dims...; kwds...) = plan_ichebyshevtransform(x, dims...; kwds...)*x

Expand Down
1 change: 1 addition & 0 deletions test/chebyshevtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ using FastTransforms, Test
@testset "immutable vectors" begin
F = plan_chebyshevtransform([1.,2,3])
@test chebyshevtransform(1.0:3) == F * (1:3)
@test ichebyshevtransform(1.0:3) == ichebyshevtransform([1.0:3;])
end

@testset "inv" begin
Expand Down

0 comments on commit da3e865

Please sign in to comment.