Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ChainRules with Tagent #13

Open
roflmaostc opened this issue Jan 14, 2024 · 0 comments
Open

ChainRules with Tagent #13

roflmaostc opened this issue Jan 14, 2024 · 0 comments

Comments

@roflmaostc
Copy link
Member

The following fails:

function ChainRulesCore.rrule(as::AngularSpectrum3, field)
    field_and_tuple = as(field) 
    function as_pullback(ȳ)
        f̄ = NoTangent()
        y2 =fill!(as.buffer2, 0)

        # THIS LINE FAILS
        field_new = as.padding ? ∇set_center!(y2, as.buffer2, field, broadcast=true) : y2  
        field_imd = as.p * ifftshift!(as.buffer, field_new, (1, 2)) 
        field_imd .*= conj.(as.HW)
        field_out = fftshift!(as.buffer2, inv(as.p) * field_imd, (1, 2)) 
        field_out_cropped = as.padding ? crop_center(field_out, size(field), return_view=true) : field_out
        return f̄, field_out_cropped 
    end 
    return field_and_tuple, as_pullback
end

function ∇set_center!(dy, arr_large::AbstractArray{T, N}, arr_small::AbstractArray{T1, M};
                     broadcast=false) where {T, T1, M, N}
    @assert N  M "Can't put a higher dimensional array in a lower dimensional one."

    if broadcast == false
        inds = ntuple(i -> begin
                        a, b = get_indices_around_center(size(arr_large, i), size(arr_small, i))
                        a:b
                      end,
                      Val(N))
        arr_large[inds..., ..] .= dy
    else
        inds = ntuple(i -> begin
                        a, b = get_indices_around_center(size(arr_large, i), size(arr_small, i))
                        a:b
                      end,
                      Val(M))
        # THIS LINE fails with broadcasting
        arr_large[inds..., ..] .= dy
    end


    return arr_large
end

with

julia> include("test/angular_spectrum.jl")
typeof(dy) = Tangent{Any, Tuple{Matrix{ComplexF64}, ZeroTangent}}
Test gradient with Finite Differences: Error During Test at /home/fxw/.julia/dev/WaveOpticsPropagation.jl/test/angular_spectrum.jl:3
  Got exception outside of a @test
  DimensionMismatch: array could not be broadcast to match destination
  Stacktrace:
    [1] check_broadcast_shape
      @ ./broadcast.jl:579 [inlined]
    [2] check_broadcast_axes
      @ ./broadcast.jl:582 [inlined]
    [3] instantiate
      @ ./broadcast.jl:309 [inlined]
    [4] materialize!
      @ ./broadcast.jl:914 [inlined]
    [5] materialize!
      @ ./broadcast.jl:911 [inlined]
    [6] ∇set_center!(dy::Tangent{Any, Tuple{Matrix{ComplexF64}, ZeroTangent}}, arr_large::Matrix{ComplexF64}, arr_small::Matrix{ComplexF64}; broadcast::Bool)
      @ WaveOpticsPropagation ~/.julia/dev/WaveOpticsPropagation.jl/src/utils.jl:248
    [7] ∇set_center!
      @ ~/.julia/dev/WaveOpticsPropagation.jl/src/utils.jl:230 [inlined]
    [8] (::WaveOpticsPropagation.var"#as_pullback#166"{WaveOpticsPropagation.AngularSpectrum3{Matrix{ComplexF64}, Float64, FFTW.cFFTWPlan{ComplexF64, -1, true, 2, Tuple{Int64, Int64}}}, Matrix{ComplexF64}})(ȳ::Tangent{Any, Tuple{Matrix{ComplexF64}, ZeroTangent}})
      @ WaveOpticsPropagation ~/.julia/dev/WaveOpticsPropagation.jl/src/angular_spectrum.jl:200
    [9] (::Zygote.ZBack{WaveOpticsPropagation.var"#as_pullback#166"{WaveOpticsPropagation.AngularSpectrum3{Matrix{ComplexF64}, Float64, FFTW.cFFTWPlan{ComplexF64, -1, true, 2, Tuple{Int64, Int64}}}, Matrix{ComplexF64}}})(dy::Tuple{Matrix{ComplexF64}, Nothing})
      @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/chainrules.jl:211
   [10] f_AS
      @ ~/.julia/dev/WaveOpticsPropagation.jl/test/angular_spectrum.jl:15 [inlined]
   [11] (::Zygote.Pullback{Tuple{var"#f_AS#132", Matrix{ComplexF64}}, Any})(Δ::Float64)
      @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
   [12] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#f_AS#132", Matrix{ComplexF64}}, Any}})(Δ::Float64)
      @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:45
   [13] gradient(f::Function, args::Matrix{ComplexF64})
      @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:97
   [14] macro expansion
      @ ~/.julia/dev/WaveOpticsPropagation.jl/test/angular_spectrum.jl:17 [inlined]
   [15] macro expansion
      @ ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [16] macro expansion
      @ ~/.julia/dev/WaveOpticsPropagation.jl/test/angular_spectrum.jl:4 [inlined]
   [17] macro expansion
      @ ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [18] top-level scope
      @ ~/.julia/dev/WaveOpticsPropagation.jl/test/angular_spectrum.jl:3
   [19] include(fname::String)
      @ Base.MainInclude ./client.jl:489
   [20] top-level scope
      @ REPL[21]:1
   [21] top-level scope
      @ ~/.julia/packages/CUDA/rXson/src/initialization.jl:208
   [22] eval
      @ Core ./boot.jl:385 [inlined]
   [23] eval_user_input(ast::Any, backend::REPL.REPLBackend, mod::Module)
      @ REPL ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:150
   [24] repl_backend_loop(backend::REPL.REPLBackend, get_module::Function)
      @ REPL ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:246
   [25] start_repl_backend(backend::REPL.REPLBackend, consumer::Any; get_module::Function)
      @ REPL ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:231
   [26] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool, backend::Any)
      @ REPL ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:389
   [27] run_repl(repl::REPL.AbstractREPL, consumer::Any)
      @ REPL ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:375
   [28] (::Base.var"#1013#1015"{Bool, Bool, Bool})(REPL::Module)
      @ Base ./client.jl:432
   [29] #invokelatest#2
      @ Base ./essentials.jl:887 [inlined]
   [30] invokelatest
      @ Base ./essentials.jl:884 [inlined]
   [31] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool)
      @ Base ./client.jl:416
   [32] exec_options(opts::Base.JLOptions)
      @ Base ./client.jl:333
   [33] _start()
      @ Base ./client.jl:552
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant