Skip to content

Commit

Permalink
Support matrix-valued forward recurrence
Browse files Browse the repository at this point in the history
  • Loading branch information
dlfivefifty committed Dec 19, 2024
1 parent 82125b0 commit c180a93
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 3 deletions.
11 changes: 10 additions & 1 deletion ext/RecurrenceRelationshipsLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
module RecurrenceRelationshipsLinearAlgebraExt
using RecurrenceRelationships, LinearAlgebra

import RecurrenceRelationships: olver
import RecurrenceRelationships: olver, forwardrecurrence!

olver(T::Tridiagonal, f, n...; kwds...) = olver(T.dl, T.d, T.du, f, n...; kwds...)
olver(T::SymTridiagonal, f, n...; kwds...) = olver(T.ev, T.dv, T.ev, f, n...; kwds...)

function forwardrecurrence!(v::AbstractVector{T}, A::AbstractVector, B::AbstractVector, C::AbstractVector, x::AbstractMatrix, p0=one(x)) where T
N = length(v)
N == 0 && return v
length(A)+1 N && length(B)+1 N && length(C)+1 N || throw(ArgumentError("A, B, C must contain at least $(N-1) entries"))
p1 = convert(T, N == 1 ? p0 : muladd(A[1],x,B[1]*I)*p0) # avoid accessing A[1]/B[1] if empty
forwardrecurrence!(v, A, B, C, x, convert(T, p0), p1)
end


end
5 changes: 4 additions & 1 deletion src/RecurrenceRelationships.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ export forwardrecurrence, forwardrecurrence!, clenshaw, clenshaw!, olver, olver!
# choose the type correctly for polynomials in a variable
polynomialtype(::Type{T}) where T = typeof(zero(T)^2+1)
polynomialtype(::Type{N}) where N<:Number = N
polynomialtype(a::Type...) = promote_type(map(polynomialtype, a)...)
polynomialtype(::Type{N}) where N<:AbstractMatrix{<:Number} = N

Check warning on line 7 in src/RecurrenceRelationships.jl

View check run for this annotation

Codecov / codecov/patch

src/RecurrenceRelationships.jl#L7

Added line #L7 was not covered by tests
polynomialtype(a::Type, b::Type) = promote_type(polynomialtype(a), polynomialtype(b))
polynomialtype(a::Type, b::Type, c::Type...) = polynomialtype(polynomialtype(a, b), c...)
polynomialtype(a::Type{<:Number}, b::Type{N}) where N<:AbstractMatrix{<:Number} = N


include("forward.jl")
Expand Down
1 change: 0 additions & 1 deletion src/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ function forwardrecurrence!(v::AbstractVector{T}, A::AbstractVector, B::Abstract
forwardrecurrence!(v, A, B, C, x, convert(T, p0), p1)
end


Base.@propagate_inbounds forwardrecurrence_next(n, A, B, C, x, p0, p1) = muladd(muladd(A[n],x,B[n]), p1, -C[n]*p0)


Expand Down
7 changes: 7 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,4 +281,11 @@ end
N = 5
A, B, C = Fill(2,N-1), Zeros{Int}(N-1), Ones{Int}(N)
@test @inferred(forwardrecurrence(N, A, B, C, x)) == [1,2x,4x^2-1, 8x^3-4x, 16x^4 - 12x^2 + 1]
end

@testset "Matrix" begin
N = 5
A, B, C = Fill(2,N-1), Zeros{Int}(N-1), Ones{Int}(N)
X = randn(6,6)
@test forwardrecurrence(N, A, B, C, X) [I(6), 2X, 4X^2-I, 8X^3-4X, 16X^4 - 12X^2 + I]
end

0 comments on commit c180a93

Please sign in to comment.