Set AD rules #93
base: master
Conversation
---

Codecov Report

```diff
@@            Coverage Diff             @@
##           master      #93      +/-   ##
==========================================
- Coverage   90.96%   89.71%   -1.26%
==========================================
  Files          11       11
  Lines         620      632      +12
==========================================
+ Hits          564      567       +3
- Misses         56       65       +9
```

Continue to review full report at Codecov.
---

Your code does not seem to work for some examples and gives the wrong result for others. For example:

```julia
gradient((A, B) -> sum(A ⊗ B), A, B)
gradient((A, B) -> sum(kron(A, B)), A, B)
```

Most of our …

Do you have a reference for your gradients?
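For reference, a self-contained version of that comparison might look like the following (a sketch, assuming Zygote and Kronecker are loaded; `g_dense` and `g_lazy` are names I made up):

```julia
using Zygote, Kronecker

A, B = rand(3, 3), rand(2, 2)

# Gradient through the dense Kronecker product (reference result).
g_dense = gradient((A, B) -> sum(kron(A, B)), A, B)

# Gradient through the lazy Kronecker product; per this comment,
# this call either errors or disagrees with the dense result.
g_lazy = gradient((A, B) -> sum(A ⊗ B), A, B)

# If both calls succeed, the two gradients should agree.
all(map(≈, g_dense, g_lazy))
```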
---

Adds a test function for different output dimensions of each factor in the Kronecker product. It defines linear regression models with the sum of squared residuals as the loss function. Currently it only works for residuals of scalar outputs; the tests are broken for outputs of higher dimensions.
---

You're right. I started with the following loss function:

```julia
function loss(A, B, X)
    Z = kron(A, B) * X - y   # `y` is the target matrix from the enclosing scope
    L = 0.5 * tr(Z' * Z)
    return L
end
```

where `y` holds the regression targets. I decided to keep similar tests for the higher-dimensional outputs, but leave them marked as broken.
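As a sanity check on this loss (my own sketch, not part of the PR): with `Z = W * X - y` for `W = kron(A, B)`, the differential `dL = tr(Z' * dW * X)` gives `∇W = Z * X'`, which a finite-difference probe confirms:

```julia
using LinearAlgebra

nA, nB, n = 3, 2, 5
A, B = rand(nA, nA), rand(nB, nB)
X = rand(nA * nB, n)
y = rand(nA * nB, n)           # targets, one column per sample

W = kron(A, B)
Z = W * X - y
∇W = Z * X'                    # analytic gradient of 0.5 * tr(Z' * Z) w.r.t. W

# Finite-difference spot check of a single entry.
ϵ = 1e-6
Wp = copy(W); Wp[1, 2] += ϵ
Zp = Wp * X - y
fd = (0.5 * tr(Zp' * Zp) - 0.5 * tr(Z' * Z)) / ϵ
isapprox(fd, ∇W[1, 2]; rtol = 1e-4)   # true
```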
---

I've been experimenting with `frule`/`rrule` definitions for `KroneckerSum`, and I managed to get the correct values for the pullback:

```julia
using ChainRulesCore, Kronecker, LinearAlgebra, Zygote

function ChainRulesCore.frule((_, ΔA, ΔB), ::typeof(KroneckerSum), A::AbstractMatrix, B::AbstractMatrix)
    Ω = A ⊕ B
    ∂Ω = ΔA ⊕ ΔB
    return Ω, ∂Ω
end

function ChainRulesCore.rrule(::typeof(KroneckerSum), A::AbstractMatrix, B::AbstractMatrix)
    # `nA` and `nB` are the globals defined below; note that `ΔΩ` is ignored,
    # so this pullback is only correct when the cotangent equals the primal output.
    function kronecker_sum_pullback(ΔΩ)
        ∂A = nB .* A + Diagonal(fill(tr(B), nA))
        ∂B = nA .* B + Diagonal(fill(tr(A), nB))
        return (NO_FIELDS, ∂A, ∂B)
    end
    return (A ⊕ B), kronecker_sum_pullback
end

nA = 3
nB = 2
Ar = rand(nA, nA)
Br = rand(nB, nB)

# Compare the lazy pullback against an explicit A ⊗ I + I ⊗ B construction.
Y_lazy, back_lazy = Zygote._pullback(⊕, Ar, Br)
Y, back = Zygote._pullback((x, y) -> kron(x, Diagonal(ones(nB))) + kron(Diagonal(ones(nA)), y), Ar, Br)
```

```julia
julia> back(Y)[2:end] .≈ back_lazy(Y_lazy)[2:end]
(true, true)
```

Of course, this isn't useful for computing the gradient in more complicated expressions, since the pullback ignores `ΔΩ`.
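For what it's worth, a `ΔΩ`-dependent pullback is possible (my sketch, assuming `Ω = A ⊕ B = A ⊗ I + I ⊗ B`; `kronecker_sum_pullback_general` is a name I made up): the cotangents are block-wise reductions of `ΔΩ`, and with `ΔΩ == Ω` it reduces to the hard-coded formulas above.

```julia
# ∂A[i, j] sums the diagonal of the (i, j) block of ΔΩ;
# ∂B[k, l] sums entry (k, l) over the diagonal blocks of ΔΩ.
function kronecker_sum_pullback_general(ΔΩ, nA, nB)
    ∂A = [sum(ΔΩ[(i - 1) * nB + k, (j - 1) * nB + k] for k in 1:nB) for i in 1:nA, j in 1:nA]
    ∂B = [sum(ΔΩ[(i - 1) * nB + k, (i - 1) * nB + l] for i in 1:nA) for k in 1:nB, l in 1:nB]
    return ∂A, ∂B
end
```

With `ΔΩ = collect(Ar ⊕ Br)` this matches `back_lazy(Y_lazy)[2:end]` from the check above.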
---

Note that

```julia
ChainRulesCore.rrule(::typeof(KroneckerSum), A::AbstractMatrix, B::AbstractMatrix)
```

overwrites

```julia
ChainRulesCore.rrule(::typeof(KroneckerProduct), A::AbstractMatrix, B::AbstractMatrix)
```

Should I use something else instead of `typeof`?
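If I'm reading this right (my explanation, not confirmed in the thread), the two definitions collide because both `KroneckerProduct` and `KroneckerSum` are parametric structs, so `typeof` maps both to `UnionAll`:

```julia
using Kronecker

typeof(Kronecker.KroneckerProduct)  # UnionAll
typeof(Kronecker.KroneckerSum)      # UnionAll

# Both rules therefore declare the same method,
# `rrule(::UnionAll, ::AbstractMatrix, ::AbstractMatrix)`,
# and the later definition replaces the earlier one. Dispatching on
# `::Type{<:KroneckerSum}`, the usual ChainRules pattern for constructors,
# would keep the two methods distinct.
```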
---

Still stuck on this: why does computing gradients work for …
---

Technically, it only makes sense to define the adjoints for those functions where Kronecker provides shortcuts, based on this rule: https://en.wikipedia.org/wiki/Matrix_calculus#Identities_in_differential_form
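For the Kronecker product itself, the relevant entry in that table is the differential product rule; spelling out what it implies for the cotangent of `A` (my derivation, with `B` of size `p × q`):

```latex
d(A \otimes B) = dA \otimes B + A \otimes dB,
\qquad
\bar{A}_{ij} = \sum_{k=1}^{p} \sum_{l=1}^{q}
  \bar{\Omega}_{(i-1)p + k,\, (j-1)q + l} \, B_{kl}.
```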
---

Can you provide an MWE for …
---

Maybe I misunderstood, but doesn't this only provide the …
---

Resolves #92.