From 0699100c6f2c891e8e5dc04891892b84024a813b Mon Sep 17 00:00:00 2001
From: Songchen Tan
Date: Tue, 1 Oct 2024 14:57:57 -0400
Subject: [PATCH] Support Enzyme

---
 benchmark/groups/pinn.jl |  6 ++---
 src/chainrules.jl        |  8 +++++++
 test/Project.toml        |  9 +++++++-
 test/downstream.jl       | 49 ++++++++++++++++++++++++++++++++++++++++
 test/runtests.jl         |  2 +-
 test/zygote.jl           | 43 -----------------------------------
 6 files changed, 69 insertions(+), 48 deletions(-)
 create mode 100644 test/downstream.jl
 delete mode 100644 test/zygote.jl

diff --git a/benchmark/groups/pinn.jl b/benchmark/groups/pinn.jl
index 4e8833b..058b58b 100644
--- a/benchmark/groups/pinn.jl
+++ b/benchmark/groups/pinn.jl
@@ -1,10 +1,10 @@
-using Lux, Random, Zygote
+using Lux, Zygote, Enzyme
 
 const input = 2
 const hidden = 16
 
-model = Chain(Dense(input => hidden, exp),
-    Dense(hidden => hidden, exp),
+model = Chain(Dense(input => hidden, Lux.relu),
+    Dense(hidden => hidden, Lux.relu),
     Dense(hidden => 1),
     first)
 
diff --git a/src/chainrules.jl b/src/chainrules.jl
index 8edbf7a..a6f6245 100644
--- a/src/chainrules.jl
+++ b/src/chainrules.jl
@@ -86,3 +86,11 @@ for f in (
         @eval @opt_out rrule(::typeof($f), x::$tlhs, y::$trhs)
     end
 end
+
+# Multi-argument functions
+
+@opt_out frule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar)
+@opt_out rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar)
+
+@opt_out frule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar, more::TaylorScalar...)
+@opt_out rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar, more::TaylorScalar...)
diff --git a/test/Project.toml b/test/Project.toml
index 72bbeb9..6610f61 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,5 +1,12 @@
 [deps]
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
-LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+
+[compat]
+Enzyme = "0.13"
diff --git a/test/downstream.jl b/test/downstream.jl
new file mode 100644
index 0000000..17af58d
--- /dev/null
+++ b/test/downstream.jl
@@ -0,0 +1,49 @@
+using LinearAlgebra
+import DifferentiationInterface
+using DifferentiationInterface: AutoZygote, AutoEnzyme
+import Zygote, Enzyme
+using FiniteDiff: finite_difference_derivative
+
+DI = DifferentiationInterface
+backend = AutoZygote()
+# backend = AutoEnzyme(; mode = Enzyme.Reverse, function_annotation = Enzyme.Const)
+
+@testset "Zygote-over-TaylorDiff on same variable" begin
+    # Scalar functions
+    some_number = 0.7
+    some_numbers = [0.3, 0.4, 0.1]
+    for f in (exp, log, sqrt, sin, asin, sinh, asinh, x -> x^3)
+        @test DI.derivative(x -> derivative(f, x, 2), backend, some_number) ≈
+              derivative(f, some_number, 3)
+        @test DI.jacobian(x -> derivative.(f, x, 2), backend, some_numbers) ≈
+              diagm(derivative.(f, some_numbers, 3))
+    end
+
+    # Vector functions
+    g(x) = x[1] * x[1] + x[2] * x[2]
+    @test DI.gradient(x -> derivative(g, x, [1.0, 0.0], 1), backend, [1.0, 2.0]) ≈
+          [2.0, 0.0]
+
+    # Matrix functions
+    some_matrix = [0.7 0.1; 0.4 0.2]
+    f(x) = sum(exp.(x), dims = 1)
+    dfdx1(x) = derivative(f, x, [1.0, 0.0], 1)
+    dfdx2(x) = derivative(f, x, [0.0, 1.0], 1)
+    res(x) = sum(dfdx1(x) .+ 2 * dfdx2(x))
+    grad = DI.gradient(res, backend, some_matrix)
+    @test grad ≈ [1 0; 0 2] * exp.(some_matrix)
+end
+
+@testset "Zygote-over-TaylorDiff on different variable" begin
+    linear_model(x, p, b) = exp.(b + p * x + b)[1]
+    loss_taylor(x, p, b, v) = derivative(x -> linear_model(x, p, b), x, v, 1)
+    ε = cbrt(eps(Float64))
+    loss_finite(x, p, b, v) = (linear_model(x + ε * v, p, b) -
+                               linear_model(x - ε * v, p, b)) / (2 * ε)
+    let some_x = [0.58, 0.36], some_v = [0.23, 0.11], some_p = [0.49 0.96], some_b = [0.88]
+        @test DI.gradient(
+            p -> loss_taylor(some_x, p, some_b, some_v), backend, some_p) ≈
+              DI.gradient(
+            p -> loss_finite(some_x, p, some_b, some_v), backend, some_p)
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index a28c383..2088824 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -3,5 +3,5 @@ using Test
 
 include("primitive.jl")
 include("derivative.jl")
-include("zygote.jl")
+include("downstream.jl")
 # include("lux.jl")
diff --git a/test/zygote.jl b/test/zygote.jl
deleted file mode 100644
index 3dccedd..0000000
--- a/test/zygote.jl
+++ /dev/null
@@ -1,43 +0,0 @@
-using LinearAlgebra
-import Zygote # use qualified import to avoid conflict with TaylorDiff
-
-@testset "Zygote-over-TaylorDiff on same variable" begin
-    # Scalar functions
-    some_number = 0.7
-    some_numbers = [0.3 0.4 0.1;]
-    for f in (exp, log, sqrt, sin, asin, sinh, asinh, x -> x^3)
-        @test Zygote.gradient(derivative, f, some_number, 2)[2] ≈
-              derivative(f, some_number, 3)
-        @test Zygote.jacobian(broadcast, derivative, f, some_numbers, 2)[3] ≈
-              diagm(vec(derivative.(f, some_numbers, 3)))
-    end
-
-    # Vector functions
-    g(x) = x[1] * x[1] + x[2] * x[2]
-    @test Zygote.gradient(derivative, g, [1.0, 2.0], [1.0, 0.0], 1)[2] ≈ [2.0, 0.0]
-
-    # Matrix functions
-    some_matrix = [0.7 0.1; 0.4 0.2]
-    f(x) = sum(exp.(x), dims = 1)
-    dfdx1(x) = derivative(f, x, [1.0, 0.0], 1)
-    dfdx2(x) = derivative(f, x, [0.0, 1.0], 1)
-    res(x) = sum(dfdx1(x) .+ 2 * dfdx2(x))
-    grads = Zygote.gradient(res, some_matrix)
-    @test grads[1] ≈ [1 0; 0 2] * exp.(some_matrix)
-end
-
-@testset "Zygote-over-TaylorDiff on different variable" begin
-    Zygote.gradient(
-        p -> derivative(x -> sum(exp.(x + p)), [1.0, 1.0], [1.0, 0.0], 1), [0.5, 0.7])
-    Zygote.gradient(
-        p -> derivative(x -> sum(exp.(p + x)), [1.0, 1.0], [1.0, 0.0], 1), [0.5, 0.7])
-    linear_model(x, p, b) = exp.(b + p * x + b)[1]
-    some_x, some_v, some_p, some_b = [0.58, 0.36], [0.23, 0.11], [0.49 0.96], [0.88]
-    loss_taylor(p) = derivative(x -> linear_model(x, p, some_b), some_x, some_v, 1)
-    ε = cbrt(eps(Float64))
-    loss_finite(p) =
-        let f = x -> linear_model(x, p, some_b)
-            (f(some_x + ε * some_v) - f(some_x - ε * some_v)) / 2ε
-        end
-    @test Zygote.gradient(loss_taylor, some_p)[1] ≈ Zygote.gradient(loss_finite, some_p)[1]
-end
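
Note on the src/chainrules.jl hunk (a reviewer's reading; the patch itself does not state the rationale): Julia lowers a chained product like `x * y * z` to the variadic method `*(x, y, z, ...)`, and ChainRules ships generic rules whose signature matches the one opted out here. The new `@opt_out` lines mark those generic rules as opted out whenever every argument is a `TaylorScalar`, so ChainRules consumers such as Zygote fall back to the pairwise `*` overloads that TaylorDiff handles directly. A one-line check of the lowering, using only Base:

    julia> @which 1.0 * 2.0 * 3.0   # resolves to the variadic *(a, b, c, xs...) method in Base, not two binary calls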
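
The new test/downstream.jl drives the outer reverse-mode derivative through DifferentiationInterface, so switching the outer backend from Zygote to Enzyme only means rebinding `backend`, as the commented-out line in that file suggests. Below is a minimal sketch of that Enzyme configuration, assuming the test environment above (Enzyme 0.13 per the [compat] entry, DifferentiationInterface, and TaylorDiff's exported `derivative`). It mirrors the line the patch leaves commented out, so treat it as a sketch rather than a configuration the test suite currently guarantees:

    using TaylorDiff: derivative
    import DifferentiationInterface as DI
    using DifferentiationInterface: AutoEnzyme
    import Enzyme

    # Outer backend: Enzyme in reverse mode, annotating the closure as a constant,
    # as in the commented-out line of test/downstream.jl.
    backend = AutoEnzyme(; mode = Enzyme.Reverse, function_annotation = Enzyme.Const)

    # Enzyme-over-TaylorDiff: reverse-mode gradient of a Taylor-mode directional derivative.
    g(x) = x[1] * x[1] + x[2] * x[2]
    DI.gradient(x -> derivative(g, x, [1.0, 0.0], 1), backend, [1.0, 2.0])  # expected ≈ [2.0, 0.0]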