Skip to content

Commit

Permalink
Support matrix-valued forward recurrence (#14)
Browse files Browse the repository at this point in the history
* Support dims keywords for matrix coefficients, Mutate first argument

* Add tests

* Support matrix-valued forward recurrence

* Make polynomialtype take just coeffs and var types

* Update forward.jl

* Update Project.toml
  • Loading branch information
dlfivefifty authored Dec 21, 2024
1 parent 26cda4b commit 8bdd81c
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "RecurrenceRelationships"
uuid = "807425ed-42ea-44d6-a357-6771516d7b2c"
authors = ["Sheehan Olver <[email protected]>"]
version = "0.2.0-dev"
version = "0.2.0"

[weakdeps]
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand Down
11 changes: 10 additions & 1 deletion ext/RecurrenceRelationshipsLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
module RecurrenceRelationshipsLinearAlgebraExt
using RecurrenceRelationships, LinearAlgebra

import RecurrenceRelationships: olver
import RecurrenceRelationships: olver, forwardrecurrence!

olver(T::Tridiagonal, f, n...; kwds...) = olver(T.dl, T.d, T.du, f, n...; kwds...)
olver(T::SymTridiagonal, f, n...; kwds...) = olver(T.ev, T.dv, T.ev, f, n...; kwds...)

function forwardrecurrence!(v::AbstractVector{T}, A::AbstractVector, B::AbstractVector, C::AbstractVector, x::AbstractMatrix, p0=one(x)) where T
N = length(v)
N == 0 && return v
length(A)+1 N && length(B)+1 N && length(C)+1 N || throw(ArgumentError("A, B, C must contain at least $(N-1) entries"))
p1 = convert(T, N == 1 ? p0 : muladd(A[1],x,B[1]*I)*p0) # avoid accessing A[1]/B[1] if empty
forwardrecurrence!(v, A, B, C, x, convert(T, p0), p1)
end


end
6 changes: 3 additions & 3 deletions src/RecurrenceRelationships.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ module RecurrenceRelationships
export forwardrecurrence, forwardrecurrence!, clenshaw, clenshaw!, olver, olver!

# choose the type correctly for polynomials in a variable
polynomialtype(::Type{T}) where T = typeof(zero(T)^2+1)
polynomialtype(::Type{N}) where N<:Number = N
polynomialtype(a::Type...) = promote_type(map(polynomialtype, a)...)
polynomialtype(::Type{C}, ::Type{X}) where {C,X} = typeof(zero(C)*zero(X)^2+one(C)*one(X))
polynomialtype(a::Type{<:Number}, b::Type{N}) where N<:AbstractMatrix{<:Number} = N # TODO: use EltypeExtensions.jl
polynomialtype(a::Type) = polynomialtype(a, a)


include("forward.jl")
Expand Down
5 changes: 2 additions & 3 deletions src/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ function forwardrecurrence!(v::AbstractVector{T}, A::AbstractVector, B::Abstract
forwardrecurrence!(v, A, B, C, x, convert(T, p0), p1)
end


Base.@propagate_inbounds forwardrecurrence_next(n, A, B, C, x, p0, p1) = muladd(muladd(A[n],x,B[n]), p1, -C[n]*p0)


Expand Down Expand Up @@ -49,10 +48,10 @@ where `A`, `B`, and `C` are `AbstractVector`s containing the first form recurren
i.e. it returns
"""
forwardrecurrence(N::Integer, A::AbstractVector, B::AbstractVector, C::AbstractVector, x) =
forwardrecurrence!(Vector{polynomialtype(eltype(A),eltype(B),eltype(C),typeof(x))}(undef, N), A, B, C, x)
forwardrecurrence!(Vector{polynomialtype(promote_type(eltype(A),eltype(B),eltype(C)),typeof(x))}(undef, N), A, B, C, x)

forwardrecurrence(N::Integer, A::AbstractVector, B::AbstractVector, C::AbstractVector) =
forwardrecurrence!(Vector{polynomialtype(eltype(A),eltype(B),eltype(C))}(undef, N), A, B, C)
forwardrecurrence!(Vector{polynomialtype(promote_type(eltype(A),eltype(B),eltype(C)))}(undef, N), A, B, C)


forwardrecurrence(A::AbstractVector, B::AbstractVector, C::AbstractVector, x...) = forwardrecurrence(min(length(A), length(B), length(C)), A, B, C, x...)
Expand Down
7 changes: 7 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,4 +300,11 @@ end
N = 5
A, B, C = Fill(2,N-1), Zeros{Int}(N-1), Ones{Int}(N)
@test @inferred(forwardrecurrence(N, A, B, C, x)) == [1,2x,4x^2-1, 8x^3-4x, 16x^4 - 12x^2 + 1]
end

@testset "Matrix" begin
N = 5
A, B, C = Fill(2,N-1), Zeros{Int}(N-1), Ones{Int}(N)
X = randn(6,6)
@test forwardrecurrence(N, A, B, C, X) [I(6), 2X, 4X^2-I, 8X^3-4X, 16X^4 - 12X^2 + I]
end

0 comments on commit 8bdd81c

Please sign in to comment.