Skip to content

Commit

Permalink
Support Enzyme
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Oct 1, 2024
1 parent 7ff795b commit 0699100
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 48 deletions.
6 changes: 3 additions & 3 deletions benchmark/groups/pinn.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
using Lux, Random, Zygote
using Lux, Zygote, Enzyme

const input = 2
const hidden = 16

model = Chain(Dense(input => hidden, exp),
Dense(hidden => hidden, exp),
model = Chain(Dense(input => hidden, Lux.relu),
Dense(hidden => hidden, Lux.relu),
Dense(hidden => 1),
first)

Expand Down
8 changes: 8 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,11 @@ for f in (
@eval @opt_out rrule(::typeof($f), x::$tlhs, y::$trhs)
end
end

# Multi-argument functions

@opt_out frule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar)
@opt_out rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar)

@opt_out frule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar, more::TaylorScalar...)
@opt_out rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar, more::TaylorScalar...)
9 changes: 8 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
[deps]
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[compat]
Enzyme = "0.13"
49 changes: 49 additions & 0 deletions test/downstream.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
using LinearAlgebra
import DifferentiationInterface
using DifferentiationInterface: AutoZygote, AutoEnzyme
import Zygote, Enzyme
using FiniteDiff: finite_difference_derivative

DI = DifferentiationInterface
backend = AutoZygote()
# backend = AutoEnzyme(; mode = Enzyme.Reverse, function_annotation = Enzyme.Const)

@testset "Zygote-over-TaylorDiff on same variable" begin
# Scalar functions
some_number = 0.7
some_numbers = [0.3, 0.4, 0.1]
for f in (exp, log, sqrt, sin, asin, sinh, asinh, x -> x^3)
@test DI.derivative(x -> derivative(f, x, 2), backend, some_number)
derivative(f, some_number, 3)
@test DI.jacobian(x -> derivative.(f, x, 2), backend, some_numbers)
diagm(derivative.(f, some_numbers, 3))
end

# Vector functions
g(x) = x[1] * x[1] + x[2] * x[2]
@test DI.gradient(x -> derivative(g, x, [1.0, 0.0], 1), backend, [1.0, 2.0])
[2.0, 0.0]

# Matrix functions
some_matrix = [0.7 0.1; 0.4 0.2]
f(x) = sum(exp.(x), dims = 1)
dfdx1(x) = derivative(f, x, [1.0, 0.0], 1)
dfdx2(x) = derivative(f, x, [0.0, 1.0], 1)
res(x) = sum(dfdx1(x) .+ 2 * dfdx2(x))
grad = DI.gradient(res, backend, some_matrix)
@test grad [1 0; 0 2] * exp.(some_matrix)
end

@testset "Zygote-over-TaylorDiff on different variable" begin
linear_model(x, p, b) = exp.(b + p * x + b)[1]
loss_taylor(x, p, b, v) = derivative(x -> linear_model(x, p, b), x, v, 1)
ε = cbrt(eps(Float64))
loss_finite(x, p, b, v) = (linear_model(x + ε * v, p, b) -
linear_model(x - ε * v, p, b)) / (2 * ε)
let some_x = [0.58, 0.36], some_v = [0.23, 0.11], some_p = [0.49 0.96], some_b = [0.88]
@test DI.gradient(
p -> loss_taylor(some_x, p, some_b, some_v), backend, some_p)
DI.gradient(
p -> loss_finite(some_x, p, some_b, some_v), backend, some_p)
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ using Test

include("primitive.jl")
include("derivative.jl")
include("zygote.jl")
include("downstream.jl")
# include("lux.jl")
43 changes: 0 additions & 43 deletions test/zygote.jl

This file was deleted.

0 comments on commit 0699100

Please sign in to comment.