Type instability in Flux.setup #162

Open · Vilin97 opened this issue Oct 10, 2023 · 7 comments

@Vilin97 commented Oct 10, 2023

using Flux

function test_setup(opt, s)
    state = Flux.setup(opt, s)
    return state
end
s = Chain(
    Dense(2 => 100, softsign),
    Dense(100 => 2)
)
opt = Adam(0.1)
@code_warntype test_setup(opt, s) # type unstable

Output:

MethodInstance for GradientFlows.test_setup(::Adam, ::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
  from test_setup(opt, s) @ GradientFlows c:\Users\Math User\.julia\dev\GradientFlows\src\solvers\sbtm.jl:106
Arguments
  #self#::Core.Const(GradientFlows.test_setup)
  opt::Adam
  s::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Locals
  state::Any
Body::Any
1 ─ %1 = Flux.setup::Core.Const(Flux.Train.setup)
│        (state = (%1)(opt, s))
└──      return state

Julia version 1.9.3 and Flux version 0.14.6:

(@v1.9) pkg> st Flux
Status `C:\Users\Math User\.julia\environments\v1.9\Project.toml`
  [587475ba] Flux v0.14.6
@ToucheSir (Member)

setup is defined in Optimisers.jl, and it's inherently type-unstable because it uses a cache to detect and handle shared parameters. Usually I would mark this as a WONTFIX, but there might be some fancy method and/or newer version of Julia which would let us make setup more type-stable.
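
For context, a minimal sketch of why sharing forces that cache (the tied-weights model below is illustrative, not from this issue): when two fields alias the same array, setup must hand back one shared Leaf so their optimiser state stays tied, and detecting that requires a lookup keyed by object identity (an IdDict), whose values have no single concrete type.

using Optimisers

w = rand(Float32, 3, 3)
model = (layer1 = (; weight = w), layer2 = (; weight = w))  # tied parameters

state = Optimisers.setup(Adam(0.1), model)
# Because model.layer1.weight === model.layer2.weight, the identity cache
# hands back the same Leaf for both, keeping the shared array's state tied:
@assert state.layer1.weight === state.layer2.weight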

ToucheSir transferred this issue from FluxML/Flux.jl Oct 10, 2023
@mcabbott (Member)

Values from the cache are used when an object x is === some previously seen x, so they should always have the same type as what init(rule, x) returns. If that type can be inferred, we could probably tell the compiler what to expect, and that might make the whole setup call type-stable? Haven't tried though.

@ToucheSir (Member)

We could use _return_type or friends to do that, yes. One thing I'd like to try, to make that easier, is to delegate what Functors.CachedWalk currently does to the callback passed into the map functions. Then it should be easier to swap different implementations of caching and memoization in and out by simply switching the callback.
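
A rough sketch of that idea (hypothetical names, not the actual Functors.jl API): caching becomes a wrapper around the callback rather than a property of the walk itself, so a plain walk plus a memoising wrapper recovers today's behaviour, and a different wrapper swaps in a different strategy.

# Hypothetical sketch only; MemoisedCallback is not part of Functors.jl.
struct MemoisedCallback{F}
    f::F
    cache::IdDict{Any,Any}
end
MemoisedCallback(f) = MemoisedCallback(f, IdDict{Any,Any}())

function (m::MemoisedCallback)(x)
    # Return the cached result for an identical (===) input, else compute
    # and remember it:
    haskey(m.cache, x) && return m.cache[x]
    m.cache[x] = m.f(x)
end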

@mcabbott (Member) commented Oct 11, 2023

function _setup(rule, x; cache)
  if haskey(cache, x)
    # x is === something already seen, so the cached value must be the Leaf
    # built from init(rule, x). Infer that type and assert it, so inference
    # isn't lost at the IdDict{Any,Any} lookup:
    T1 = Base._return_type(init, Tuple{typeof(rule), typeof(x)})
    T2 = Base._return_type(Leaf, Tuple{typeof(rule), T1})
    return cache[x]::T2
  end
  if isnumeric(x)
    ℓ = Leaf(rule, init(rule, x))
    # as before...

gives

julia> @code_warntype test_setup(opt, s)
MethodInstance for test_setup(::Optimisers.Adam, ::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
  from test_setup(opt, s) @ Main REPL[5]:1
Arguments
  #self#::Core.Const(test_setup)
  opt::Optimisers.Adam
  s::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Locals
  state::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple, NamedTuple}}}
Body::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple, NamedTuple}}}
1 ─ %1 = Flux.setup::Core.Const(Flux.Train.setup)
│        (state = (%1)(opt, s))
└──      return state

julia> @code_warntype Optimisers.setup(opt, s)
MethodInstance for Optimisers.setup(::Optimisers.Adam, ::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
  from setup(rule::AbstractRule, model) @ Optimisers ~/.julia/dev/Optimisers/src/interface.jl:29
Arguments
  #self#::Core.Const(Optimisers.setup)
  rule::Optimisers.Adam
  model::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Locals
  tree::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple, NamedTuple}}}
  cache::IdDict{Any, Any}
  msg::String
  kwargs::@NamedTuple{}
  line::Int64
  file::String
  id::Symbol
  logger::Union{Nothing, Base.CoreLogging.AbstractLogger}
  _module::Module
  group::Symbol
  std_level::Base.CoreLogging.LogLevel
  level::Base.CoreLogging.LogLevel
Body::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple, NamedTuple}}}
1 ──       (cache = Optimisers.IdDict())
│    %2  = (:cache,)::Core.Const((:cache,))
│    %3  = Core.apply_type(Core.NamedTuple, %2)::Core.Const(NamedTuple{(:cache,)})
...

@ToucheSir (Member)

Looks like the inference path _return_type uses might not be able to work through the recursion? I wonder if we could use a trick like FluxML/Functors.jl#61 to prevent it from bailing.

@Vilin97 (Author) commented Oct 18, 2023

In the meantime, would it make sense to add a sentence like "This function is type-unstable." to the docstring of setup? If I had seen such a sentence in the docstring, it would have saved me the trouble of discovering it for myself.
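
One possible wording, as a docstring admonition (illustrative only, not the actual Optimisers.jl text):

!!! warning
    setup is not type-stable: it uses an IdDict cache internally to detect
    shared (tied) parameters, so its return type cannot be inferred. Call it
    once, outside any hot loop, and pass the result through a function barrier.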

@mcabbott (Member)

> would it make sense to add a sentence like "This function is type-unstable." to the docstring of setup?

Yes, probably.

Also, to emphasise that the way to deal with this is a function barrier: you run setup exactly once and pass its result to something else. If you are running it in a tight loop, you are probably doing it wrong.
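
A minimal sketch of that pattern (the model, loss, and data below are stand-ins, not from this issue):

using Flux, Optimisers

model = Chain(Dense(2 => 100, softsign), Dense(100 => 2))
state = Flux.setup(Adam(0.1), model)   # type-unstable, but called exactly once

# Function barrier: state and model arrive here with concrete types, so
# everything inside the loop infers and compiles to stable code.
function train_steps!(state, model, data)
    for (x, y) in data
        grads = Flux.gradient(m -> Flux.mse(m(x), y), model)
        state, model = Optimisers.update!(state, model, grads[1])
    end
    return state, model
end

data = [(randn(Float32, 2, 32), randn(Float32, 2, 32)) for _ in 1:10]
state, model = train_steps!(state, model, data)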
