Skip to content

Commit

Permalink
coeffs instead of derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Oct 11, 2024
1 parent 2b0b1ce commit 44c3ab0
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 172 deletions.
9 changes: 5 additions & 4 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@ function rrule(::typeof(partials), t::TaylorArray{T, N, A, P}) where {N, T, A, P
return partials(t), value_pullback
end

function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
i::Integer) where {N, T}
function rrule(::typeof(extract_derivative), t::TaylorScalar{T, P},
q::Val{Q}) where {T, P, Q}
function extract_derivative_pullback(d̄)
NoTangent(), TaylorScalar(zero(T), ntuple(j -> j === i ?: zero(T), Val(N))),
NoTangent(),
TaylorScalar(zero(T), ntuple(j -> j === Q ?* factorial(Q) : zero(T), Val(P))),
NoTangent()
end
return extract_derivative(t, i), extract_derivative_pullback
return extract_derivative(t, q), extract_derivative_pullback
end

function rrule(::typeof(Base.getindex), a::TaylorArray, i::Int...)
Expand Down
2 changes: 1 addition & 1 deletion src/codegen.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
for unary_func in (
+, -, deg2rad, rad2deg,
deg2rad, rad2deg,
sinh, cosh, tanh,
asin, acos, atan, asec, acsc, acot,
log, log10, log1p, log2,
Expand Down
11 changes: 6 additions & 5 deletions src/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,18 @@ function derivatives end

# Added to help Zygote infer types
@inline make_seed(x::T, l::T, ::Val{P}) where {T <: Real, P} = TaylorScalar{P}(x, l)
@inline make_seed(x::A, l::A, ::Val{P}) where {A <: AbstractArray, P} = broadcast(make_seed, x, l, Val{P}())
@inline make_seed(x::A, l::A, ::Val{P}) where {A <: AbstractArray, P} = broadcast(
make_seed, x, l, Val{P}())

# `derivative` API: computes the `P - 1`-th derivative of `f` at `x`
@inline derivative(f, x, l, p::Val{P}) where {P} = extract_derivative(
derivatives(f, x, l, p), P)
derivatives(f, x, l, p), p)
@inline derivative(f!, y, x, l, p::Val{P}) where {P} = extract_derivative(
derivatives(f!, y, x, l, p), P)
derivatives(f!, y, x, l, p), p)
@inline derivative!(result, f, x, l, p::Val{P}) where {P} = extract_derivative!(
result, derivatives(f, x, l, p), P)
result, derivatives(f, x, l, p), p)
@inline derivative!(result, f!, y, x, l, p::Val{P}) where {P} = extract_derivative!(
result, derivatives(f!, y, x, l, p), P)
result, derivatives(f!, y, x, l, p), p)

# `derivatives` API: computes all derivatives of `f` at `x` up to p `P - 1`

Expand Down
242 changes: 105 additions & 137 deletions src/primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@ Taylor = Union{TaylorScalar, TaylorArray}

@inline value(t::Taylor) = t.value
@inline partials(t::Taylor) = t.partials
@inline extract_derivative(t::Taylor, i::Integer) = t.partials[i]
@inline extract_derivative(v::AbstractArray{<:TaylorScalar}, i::Integer) = map(
t -> extract_derivative(t, i), v)
@inline extract_derivative(r, i::Integer) = false
@inline function extract_derivative!(result::AbstractArray, v::AbstractArray{T},
i::Integer) where {T <: TaylorScalar}
map!(t -> extract_derivative(t, i), result, v)
end
@inline @generated extract_derivative(t::Taylor, ::Val{P}) where {P} = :(t.partials[P] *
$(factorial(P)))
@inline extract_derivative(a::AbstractArray{<:TaylorScalar}, p) = map(
t -> extract_derivative(t, p), a)
@inline extract_derivative(_, p) = false
@inline extract_derivative!(result, a::AbstractArray{<:TaylorScalar}, p) = map!(
t -> extract_derivative(t, p), result, a)

@inline flatten(t::Taylor) = (value(t), partials(t)...)

Expand All @@ -33,74 +32,75 @@ function (::Type{F})(x::TaylorScalar{T, P}) where {T, P, F <: AbstractFloat}
end

# Unary
@inline +(a::Number, b::TaylorScalar) = TaylorScalar(a + value(b), partials(b))
@inline -(a::Number, b::TaylorScalar) = TaylorScalar(a - value(b), map(-, partials(b)))
@inline *(a::Number, b::TaylorScalar) = TaylorScalar(a * value(b), a .* partials(b))
@inline /(a::Number, b::TaylorScalar) = /(promote(a, b)...)

@inline +(a::TaylorScalar, b::Number) = TaylorScalar(value(a) + b, partials(a))
@inline -(a::TaylorScalar, b::Number) = TaylorScalar(value(a) - b, partials(a))
@inline *(a::TaylorScalar, b::Number) = TaylorScalar(value(a) * b, partials(a) .* b)
@inline /(a::TaylorScalar, b::Number) = TaylorScalar(value(a) / b, partials(a) ./ b)

## Delegated

@inline +(t::TaylorScalar) = t
@inline -(t::TaylorScalar) = TaylorScalar(-value(t), .-partials(t))
@inline sqrt(t::TaylorScalar) = t^0.5
@inline cbrt(t::TaylorScalar) = ^(t, 1 / 3)
@inline inv(t::TaylorScalar) = one(t) / t

for func in (:exp, :expm1, :exp2, :exp10)
@eval @generated function $func(t::TaylorScalar{T, N}) where {T, N}
@eval @generated function $func(t::TaylorScalar{T, P}) where {T, P}
v = [Symbol("v$i") for i in 0:P]
ex = quote
v = flatten(t)
v1 = $($(QuoteNode(func)) == :expm1 ? :(exp(v[1])) : :($$func(v[1])))
$(Expr(:meta, :inline))
p = value(t)
f = flatten(t)
v0 = $($(QuoteNode(func)) == :expm1 ? :(exp(p)) : :($$func(p)))
end
for i in 2:(N + 1)
ex = quote
$ex
$(Symbol('v', i)) = +($([:($(binomial(i - 2, j - 1)) * $(Symbol('v', j)) *
v[$(i + 1 - j)])
for j in 1:(i - 1)]...))
end
for i in 1:P
push!(ex.args,
:(
$(v[begin + i]) = +($([:($(i - j) * $(v[begin + j]) *
f[begin + $(i - j)])
for j in 0:(i - 1)]...)) / $i
))
if $(QuoteNode(func)) == :exp2
ex = :($ex; $(Symbol('v', i)) *= $(log(2)))
push!(ex.args, :($(v[begin + i]) *= log(2)))
elseif $(QuoteNode(func)) == :exp10
ex = :($ex; $(Symbol('v', i)) *= $(log(10)))
push!(ex.args, :($(v[begin + i]) *= log(10)))
end
end
if $(QuoteNode(func)) == :expm1
ex = :($ex; v1 = expm1(v[1]))
push!(ex.args, :(v0 = expm1(f[1])))
end
ex = :($ex; TaylorScalar(tuple($([Symbol('v', i) for i in 1:(N + 1)]...))))
push!(ex.args, :(TaylorScalar(tuple($(v...)))))
return :(@inbounds $ex)
end
end

for func in (:sin, :cos)
@eval @generated function $func(t::TaylorScalar{T, N}) where {T, N}
@eval @generated function $func(t::TaylorScalar{T, P}) where {T, P}
s = [Symbol("s$i") for i in 0:P]
c = [Symbol("c$i") for i in 0:P]
ex = quote
v = flatten(t)
s1 = sin(v[1])
c1 = cos(v[1])
$(Expr(:meta, :inline))
f = flatten(t)
s0 = sin(f[1])
c0 = cos(f[1])
end
for i in 2:(N + 1)
ex = :($ex;
$(Symbol('s', i)) = +($([:($(binomial(i - 2, j - 1)) *
$(Symbol('c', j)) *
v[$(i + 1 - j)]) for j in 1:(i - 1)]...)))
ex = :($ex;
$(Symbol('c', i)) = +($([:($(-binomial(i - 2, j - 1)) *
$(Symbol('s', j)) *
v[$(i + 1 - j)]) for j in 1:(i - 1)]...)))
for i in 1:P
push!(ex.args,
:($(s[begin + i]) = +($([:(
$(i - j) * $(c[begin + j]) *
f[begin + $(i - j)]) for j in 0:(i - 1)]...)) /
$i)
)
push!(ex.args,
:($(c[begin + i]) = +($([:(
$(i - j) * $(s[begin + j]) *
f[begin + $(i - j)]) for j in 0:(i - 1)]...)) /
-$i)
)
end
if $(QuoteNode(func)) == :sin
ex = :($ex; TaylorScalar(tuple($([Symbol('s', i) for i in 1:(N + 1)]...))))
push!(ex.args, :(TaylorScalar(tuple($(s...)))))
else
ex = :($ex; TaylorScalar(tuple($([Symbol('c', i) for i in 1:(N + 1)]...))))
end
return quote
@inbounds $ex
push!(ex.args, :(TaylorScalar(tuple($(c...)))))
end
return :(@inbounds $ex)
end
end

Expand All @@ -109,6 +109,18 @@ end

# Binary

## Easy case

@inline +(a::Number, b::TaylorScalar) = TaylorScalar(a + value(b), partials(b))
@inline -(a::Number, b::TaylorScalar) = TaylorScalar(a - value(b), .-partials(b))
@inline *(a::Number, b::TaylorScalar) = TaylorScalar(a * value(b), a .* partials(b))
@inline /(a::Number, b::TaylorScalar) = /(promote(a, b)...)

@inline +(a::TaylorScalar, b::Number) = TaylorScalar(value(a) + b, partials(a))
@inline -(a::TaylorScalar, b::Number) = TaylorScalar(value(a) - b, partials(a))
@inline *(a::TaylorScalar, b::Number) = TaylorScalar(value(a) * b, partials(a) .* b)
@inline /(a::TaylorScalar, b::Number) = TaylorScalar(value(a) / b, partials(a) ./ b)

const AMBIGUOUS_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real, RoundingMode)

for op in [:>, :<, :(==), :(>=), :(<=)]
Expand All @@ -126,75 +138,56 @@ end

@generated function *(a::TaylorScalar{T, N}, b::TaylorScalar{T, N}) where {T, N}
return quote
$(Expr(:meta, :inline))
va, vb = flatten(a), flatten(b)
r = tuple($([:(+($([:($(binomial(i - 1, j - 1)) * va[$j] *
vb[$(i + 1 - j)]) for j in 1:i]...)))
for i in 1:(N + 1)]...))
@inbounds TaylorScalar(r[1], r[2:end])
v = tuple($([:(
+($([:(va[begin + $j] * vb[begin + $(i - j)]) for j in 0:i]...))
) for i in 0:N]...))
@inbounds TaylorScalar(v)
end
end

@generated function /(a::TaylorScalar{T, N}, b::TaylorScalar{T, N}) where {T, N}
@generated function /(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P}
v = [Symbol("v$i") for i in 0:P]
ex = quote
$(Expr(:meta, :inline))
va, vb = flatten(a), flatten(b)
v1 = va[1] / vb[1]
v0 = va[1] / vb[1]
b0 = vb[1]
end
for i in 2:(N + 1)
ex = quote
$ex
$(Symbol('v', i)) = (va[$i] -
+($([:($(binomial(i - 1, j - 1)) * $(Symbol('v', j)) *
vb[$(i + 1 - j)])
for j in 1:(i - 1)]...))) / vb[1]
end
end
ex = quote
$ex
v = tuple($([Symbol('v', i) for i in 1:(N + 1)]...))
TaylorScalar(v)
for i in 1:P
push!(ex.args,
:(
$(v[begin + i]) = (va[begin + $i] -
+($([:($(v[begin + j]) *
vb[begin + $(i - j)])
for j in 0:(i - 1)]...))) / b0
)
)
end
push!(ex.args, :(TaylorScalar(tuple($(v...)))))
return :(@inbounds $ex)
end

for R in (Integer, Real)
@eval @generated function ^(t::TaylorScalar{T, N}, n::S) where {S <: $R, T, N}
@eval @generated function ^(t::TaylorScalar{T, P}, n::S) where {S <: $R, T, P}
v = [Symbol("v$i") for i in 0:P]
ex = quote
v = flatten(t)
w11 = 1
u1 = ^(v[1], n)
end
for k in 1:(N + 1)
ex = quote
$ex
$(Symbol('p', k)) = ^(v[1], n - $(k - 1))
end
$(Expr(:meta, :inline))
f = flatten(t)
f0 = f[1]
v0 = ^(f0, n)
end
for i in 2:(N + 1)
subex = quote
$(Symbol('w', i, 1)) = 0
end
for k in 2:i
subex = quote
$subex
$(Symbol('w', i, k)) = +($([:((n * $(binomial(i - 2, j - 1)) -
$(binomial(i - 2, j - 2))) *
$(Symbol('w', j, k - 1)) *
v[$(i + 1 - j)])
for j in (k - 1):(i - 1)]...))
end
end
ex = quote
$ex
$subex
$(Symbol('u', i)) = +($([:($(Symbol('w', i, k)) * $(Symbol('p', k)))
for k in 2:i]...))
end
end
ex = quote
$ex
v = tuple($([Symbol('u', i) for i in 1:(N + 1)]...))
TaylorScalar(v)
for i in 1:P
push!(ex.args,
:(
$(v[begin + i]) = +($([:(
(n * $(i - j) - $j) * $(v[begin + j]) *
f[begin + $(i - j)]
) for j in 0:(i - 1)]...)) / ($i * f0)
))
end
push!(ex.args, :(TaylorScalar(tuple($(v...)))))
return :(@inbounds $ex)
end
@eval function ^(a::S, t::TaylorScalar{T, N}) where {S <: $R, T, N}
Expand All @@ -204,39 +197,14 @@ end

^(t::TaylorScalar, s::TaylorScalar) = exp(s * log(t))

@generated function raise(f::T, df::TaylorScalar{T, M},
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
return quote
$(Expr(:meta, :inline))
vdf, vt = flatten(df), flatten(t)
partials = tuple($([:(+($([:($(binomial(i - 1, j - 1)) * vdf[$j] *
vt[$(i + 2 - j)]) for j in 1:i]...)))
for i in 1:(M + 1)]...))
@inbounds TaylorScalar(f, partials)
end
@inline function lower(t::TaylorScalar{T, P}) where {T, P}
s = partials(t)
TaylorScalar(ntuple(i -> s[i] * i, Val(P)))
end

raise(::T, df::S, t::TaylorScalar{T, N}) where {S <: Number, T, N} = df * t

@generated function raiseinv(f::T, df::TaylorScalar{T, M},
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
ex = quote
vdf, vt = flatten(df), flatten(t)
v1 = vt[2] / vdf[1]
end
for i in 2:(M + 1)
ex = quote
$ex
$(Symbol('v', i)) = (vt[$(i + 1)] -
+($([:($(binomial(i - 1, j - 1)) * $(Symbol('v', j)) *
vdf[$(i + 1 - j)])
for j in 1:(i - 1)]...))) / vdf[1]
end
end
ex = quote
$ex
v = tuple($([Symbol('v', i) for i in 1:(M + 1)]...))
TaylorScalar(f, v)
end
return :(@inbounds $ex)
@inline function higher(t::TaylorScalar{T, P}) where {T, P}
s = flatten(t)
ntuple(i -> s[i] / i, Val(P + 1))
end
@inline raise(f, df::TaylorScalar, t) = TaylorScalar(f, higher(lower(t) * df))
@inline raise(f, df::Number, t) = df * t
@inline raiseinv(f, df, t) = TaylorScalar(f, higher(lower(t) / df))
Loading

0 comments on commit 44c3ab0

Please sign in to comment.