Skip to content

Commit

Permalink
Improve readability
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Oct 12, 2024
1 parent 44c3ab0 commit 9ef7cdb
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 109 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.2.5"
[deps]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

Expand Down
2 changes: 1 addition & 1 deletion src/TaylorDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ can_taylorize(::Type) = false
" If the type behaves as a scalar, define TaylorDiff.can_taylorize(::Type{$V}) = true."))
end

include("utils.jl")
include("scalar.jl")
include("array.jl")
include("primitive.jl")
include("utils.jl")
include("codegen.jl")
include("derivative.jl")
include("chainrules.jl")
Expand Down
172 changes: 64 additions & 108 deletions src/primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@ import Base: abs, abs2
import Base: exp, exp2, exp10, expm1, log, log2, log10, log1p, inv, sqrt, cbrt
import Base: sin, cos, tan, cot, sec, csc, sinh, cosh, tanh, coth, sech, csch, sinpi, cospi
import Base: asin, acos, atan, acot, asec, acsc, asinh, acosh, atanh, acoth, asech, acsch
import Base: sinc, cosc
import Base: +, -, *, /, \, ^, >, <, >=, <=, ==
import Base: hypot, max, min
import Base: tail
import Base: convert, promote_rule
import Base: sinc, cosc, hypot, max, min, literal_pow

Taylor = Union{TaylorScalar, TaylorArray}

Expand All @@ -22,7 +19,7 @@ Taylor = Union{TaylorScalar, TaylorArray}

@inline flatten(t::Taylor) = (value(t), partials(t)...)

function promote_rule(::Type{TaylorScalar{T, P}},
function Base.promote_rule(::Type{TaylorScalar{T, P}},
::Type{S}) where {T, S, P}
TaylorScalar{promote_type(T, S), P}
end
Expand All @@ -35,78 +32,49 @@ end

## Delegated

@inline +(t::TaylorScalar) = t
@inline -(t::TaylorScalar) = TaylorScalar(-value(t), .-partials(t))
@inline sqrt(t::TaylorScalar) = t^0.5
@inline cbrt(t::TaylorScalar) = ^(t, 1 / 3)
@inline inv(t::TaylorScalar) = one(t) / t
@inline sinpi(t::TaylorScalar) = sin* t)
@inline cospi(t::TaylorScalar) = cos* t)
@inline exp10(t::TaylorScalar) = exp(t * log(10))
@inline exp2(t::TaylorScalar) = exp(t * log(2))
@inline expm1(t::TaylorScalar) = TaylorScalar(expm1(value(t)), partials(exp(t)))

for func in (:exp, :expm1, :exp2, :exp10)
@eval @generated function $func(t::TaylorScalar{T, P}) where {T, P}
v = [Symbol("v$i") for i in 0:P]
ex = quote
$(Expr(:meta, :inline))
p = value(t)
f = flatten(t)
v0 = $($(QuoteNode(func)) == :expm1 ? :(exp(p)) : :($$func(p)))
end
for i in 1:P
push!(ex.args,
:(
$(v[begin + i]) = +($([:($(i - j) * $(v[begin + j]) *
f[begin + $(i - j)])
for j in 0:(i - 1)]...)) / $i
))
if $(QuoteNode(func)) == :exp2
push!(ex.args, :($(v[begin + i]) *= log(2)))
elseif $(QuoteNode(func)) == :exp10
push!(ex.args, :($(v[begin + i]) *= log(10)))
end
end
if $(QuoteNode(func)) == :expm1
push!(ex.args, :(v0 = expm1(f[1])))
## Hand-written exp, sin, cos

@to_static function exp(t::TaylorScalar{T, P}) where {P, T}
f = flatten(t)
v[0] = exp(f[0])
for i in 1:P
v[i] = zero(T)
for j in 0:(i - 1)
v[i] += (i - j) * v[j] * f[i - j]
end
push!(ex.args, :(TaylorScalar(tuple($(v...)))))
return :(@inbounds $ex)
v[i] /= i
end
return TaylorScalar(v)
end

for func in (:sin, :cos)
@eval @generated function $func(t::TaylorScalar{T, P}) where {T, P}
s = [Symbol("s$i") for i in 0:P]
c = [Symbol("c$i") for i in 0:P]
ex = quote
$(Expr(:meta, :inline))
f = flatten(t)
s0 = sin(f[1])
c0 = cos(f[1])
end
@eval @to_static function $func(t::TaylorScalar{T, P}) where {T, P}
f = flatten(t)
s[0], c[0] = sincos(f[0])
for i in 1:P
push!(ex.args,
:($(s[begin + i]) = +($([:(
$(i - j) * $(c[begin + j]) *
f[begin + $(i - j)]) for j in 0:(i - 1)]...)) /
$i)
)
push!(ex.args,
:($(c[begin + i]) = +($([:(
$(i - j) * $(s[begin + j]) *
f[begin + $(i - j)]) for j in 0:(i - 1)]...)) /
-$i)
)
end
if $(QuoteNode(func)) == :sin
push!(ex.args, :(TaylorScalar(tuple($(s...)))))
else
push!(ex.args, :(TaylorScalar(tuple($(c...)))))
s[i] = zero(T)
c[i] = zero(T)
for j in 0:(i - 1)
s[i] += (i - j) * c[j] * f[i - j]
c[i] -= (i - j) * s[j] * f[i - j]
end
s[i] /= i
c[i] /= i
end
return :(@inbounds $ex)
return $(func == :sin ? :(TaylorScalar(s)) : :(TaylorScalar(c)))
end
end

@inline sinpi(t::TaylorScalar) = sin* t)
@inline cospi(t::TaylorScalar) = cos* t)

# Binary

## Easy case
Expand Down Expand Up @@ -136,63 +104,51 @@ end
@inline -(a::TaylorScalar, b::TaylorScalar) = TaylorScalar(
value(a) - value(b), map(-, partials(a), partials(b)))

@generated function *(a::TaylorScalar{T, N}, b::TaylorScalar{T, N}) where {T, N}
return quote
$(Expr(:meta, :inline))
va, vb = flatten(a), flatten(b)
v = tuple($([:(
+($([:(va[begin + $j] * vb[begin + $(i - j)]) for j in 0:i]...))
) for i in 0:N]...))
@inbounds TaylorScalar(v)
@to_static function *(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P}
va, vb = flatten(a), flatten(b)
for i in 0:P
v[i] = zero(T)
for j in 0:i
v[i] += va[j] * vb[i - j]
end
end
TaylorScalar(v)
end

@generated function /(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P}
v = [Symbol("v$i") for i in 0:P]
ex = quote
$(Expr(:meta, :inline))
va, vb = flatten(a), flatten(b)
v0 = va[1] / vb[1]
b0 = vb[1]
end
@to_static function /(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P}
va, vb = flatten(a), flatten(b)
v[0] = va[0] / vb[0]
for i in 1:P
push!(ex.args,
:(
$(v[begin + i]) = (va[begin + $i] -
+($([:($(v[begin + j]) *
vb[begin + $(i - j)])
for j in 0:(i - 1)]...))) / b0
)
)
v[i] = va[i]
for j in 0:(i - 1)
v[i] -= vb[i - j] * v[j]
end
v[i] /= vb[0]
end
push!(ex.args, :(TaylorScalar(tuple($(v...)))))
return :(@inbounds $ex)
TaylorScalar(v)
end

@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{0}) = one(x)
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{1}) = x
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{2}) = x*x
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{3}) = x*x*x
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{-1}) = inv(x)
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{-2}) = (i=inv(x); i*i)

for R in (Integer, Real)
@eval @generated function ^(t::TaylorScalar{T, P}, n::S) where {S <: $R, T, P}
v = [Symbol("v$i") for i in 0:P]
ex = quote
$(Expr(:meta, :inline))
f = flatten(t)
f0 = f[1]
v0 = ^(f0, n)
end
@eval @to_static function ^(t::TaylorScalar{T, P}, n::S) where {S <: $R, T, P}
f = flatten(t)
v[0] = f[0]^n
for i in 1:P
push!(ex.args,
:(
$(v[begin + i]) = +($([:(
(n * $(i - j) - $j) * $(v[begin + j]) *
f[begin + $(i - j)]
) for j in 0:(i - 1)]...)) / ($i * f0)
))
v[i] = zero(T)
for j in 0:(i - 1)
v[i] += (n * (i - j) - j) * v[j] * f[i - j]
end
v[i] /= (i * f[0])
end
push!(ex.args, :(TaylorScalar(tuple($(v...)))))
return :(@inbounds $ex)
end
@eval function ^(a::S, t::TaylorScalar{T, N}) where {S <: $R, T, N}
exp(t * log(a))
return TaylorScalar(v)
end
@eval ^(a::S, t::TaylorScalar) where {S <: $R} = exp(t * log(a))
end

^(t::TaylorScalar, s::TaylorScalar) = exp(s * log(t))
Expand Down
63 changes: 63 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ using ChainRules
using ChainRulesCore
using Symbolics: @variables, @rule, unwrap, isdiv
using SymbolicUtils.Code: toexpr
using MacroTools
using MacroTools: prewalk, postwalk

"""
Pick a strategy for raising the derivative of a function. If the derivative is like 1 over something, raise with the division rule; otherwise, raise with the multiplication rule.
Expand Down Expand Up @@ -45,3 +47,64 @@ function define_unary_function(func, m)
end
end
end

tuplen(::Type{NTuple{N, T}}) where {N, T} = N
function interpolate(ex::Expr, dict)
func = ex.args[1]
args = map(x -> interpolate(x, dict), ex.args[2:end])
getproperty(Base, func)(args...)
end
interpolate(ex::Symbol, dict) = get(dict, ex, ex)
interpolate(ex::Any, _) = ex

function unroll_loop(start, stop, var, body, d)
ex = Expr(:block)
start = interpolate(start, d)
stop = interpolate(stop, d)
for i in start:stop
iter = prewalk(x -> x === var ? i : x, body)
args = filter(x -> !(x isa LineNumberNode), iter.args)
append!(ex.args, args)
end
ex
end

function process(d, expr)
# Unroll loops
expr = prewalk(expr) do x
@match x begin
for var_ in start_:stop_
body_
end => unroll_loop(start, stop, var, body, d)
_ => x
end
end
# Modify indices
magic_names = (:v, :s, :c)
expr = postwalk(expr) do x
@match x begin
a_[idx_] => a in magic_names ? Symbol(a, idx) : :($a[begin + $idx])
TaylorScalar(v_) => :(TaylorScalar(tuple($([Symbol(v, idx) for idx in 0:d[:P]]...))))
_ => x
end
end
# Add inline meta
return quote
$(Expr(:meta, :inline))
$expr
end
end

macro to_static(def)
dict = splitdef(def)
pairs = Any[]
for symbol in dict[:whereparams]
push!(pairs, :($(QuoteNode(symbol)) => $symbol))
end
esc(quote
@generated function $(dict[:name])($(dict[:args]...)) where {$(dict[:whereparams]...)}
d = Dict($(pairs...))
process(d, $(QuoteNode(dict[:body])))
end
end)
end

0 comments on commit 9ef7cdb

Please sign in to comment.