Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No method _mul! #439

Open
wsmoses opened this issue Dec 31, 2024 · 1 comment
Open

No method _mul! #439

wsmoses opened this issue Dec 31, 2024 · 1 comment

Comments

@wsmoses
Copy link
Member

wsmoses commented Dec 31, 2024

@avik-pal mind taking a look at this?

using LinearAlgebra
using Enzyme
using GaussianDistributions
using Random

## MODEL DEFINITION ###########################################################

"""
    LinearGaussianProcess(ϕ, Σ)

A linear-Gaussian map `x ↦ N(ϕ*x, Σ)`: a coefficient matrix `ϕ` paired with a
noise covariance `Σ`. Both matrices share element type `T`; `Σ` must be square
with as many rows as `ϕ`.
"""
struct LinearGaussianProcess{
        T,
        ΦT<:AbstractMatrix{T},
        ΣT<:AbstractMatrix{T}
    }
    ϕ::ΦT  # coefficient (transition/observation) matrix
    Σ::ΣT  # noise covariance matrix
    # Inner constructor restores the `(ϕ` lost in the original paste and
    # validates dimensions with a proper exception instead of `@assert`
    # (asserts may be disabled at higher optimization levels).
    function LinearGaussianProcess(ϕ::ΦT, Σ::ΣT) where {
            T,
            ΦT<:AbstractMatrix{T},
            ΣT<:AbstractMatrix{T}
        }
        size(ϕ, 1) == size(Σ, 1) == size(Σ, 2) || throw(DimensionMismatch(
            "Σ must be square with size(ϕ, 1) rows; got size(ϕ) = $(size(ϕ)), size(Σ) = $(size(Σ))"
        ))
        return new{T, ΦT, ΣT}(ϕ, Σ)
    end
end

# A linear-Gaussian state-space model: a transition process and an observation
# process that share the same parameter element type ΘT.
struct LinearGaussianModel{
        ΘT,
        TT<:LinearGaussianProcess{ΘT},
        OT<:LinearGaussianProcess{ΘT}
    }
    transition::TT   # state transition: ϕ = A, Σ = Q in the Kalman filter below
    observation::OT  # observation map:  ϕ = B, Σ = R in the Kalman filter below
    dims::Tuple{Int, Int}  # presumably (state dim, observation dim) — unused below; TODO confirm
end

## KALMAN FILTER ##############################################################

"""
    kalman_filter(model, init_state, observations)

Run a Kalman filter over `observations`, starting from the Gaussian
`init_state`, and return the accumulated log evidence (log marginal
likelihood of the data under `model`).
"""
function kalman_filter(
        model::LinearGaussianModel,
        init_state::Gaussian,
        observations::AbstractVector{T}
    ) where {T}
    # Unpack transition (A, Q) and observation (B, R) parameters once,
    # outside the loop.
    A, Q = model.transition.ϕ, model.transition.Σ
    B, R = model.observation.ϕ, model.observation.Σ

    state = init_state
    logZ = zero(T)

    for y in observations
        # Predict: push the current Gaussian through the linear dynamics.
        μ, Σ = state.μ, state.Σ
        state = Gaussian(A * μ, A * Σ * A' + Q)

        # Correct: condition on the observation, obtaining the updated state,
        # the innovation (residual), and its covariance S.
        state, innovation, S = GaussianDistributions.correct(
            state, Gaussian([y], R), B
        )

        # Accumulate this step's log evidence contribution.
        logZ += GaussianDistributions.logpdf(
            Gaussian(zero(innovation), Symmetric(S)), innovation
        )
    end

    return logZ
end

## DEMONSTRATION ##############################################################

# Build a 1-D linear-Gaussian model whose transition-noise variance is
# parameterized by θ (a single-element vector for this demonstration);
# the observation process has unit transition and unit noise variance.
# (Restores the `(θ` dropped from the signature in the original paste.)
function build_model(θ::AbstractVector{T}) where {T}
    trans = LinearGaussianProcess(T[1;;], Diagonal(θ))
    obs   = LinearGaussianProcess(T[1;;], Diagonal(T[1]))
    return LinearGaussianModel(trans, obs, (1, 1))
end

# Log likelihood of `data` under the model parameterized by θ, starting
# from a standard-normal initial state.
# (Restores the `(θ` dropped from the signature in the original paste.)
function logℓ(θ::AbstractArray{T}, data) where {T}
    model = build_model(θ)
    init_state = Gaussian(zeros(T, 1), T[1;;])
    return kalman_filter(model, init_state, data)
end

# data generation (with unit covariance): a random-walk signal plus
# unit-variance observation noise, seeded for reproducibility
rng  = MersenneTwister(1234)
data = cumsum(randn(rng, 100)) .+ randn(rng, 100)

# ensure that log likelihood looks stable
logℓ([1.0], data)



using Reactant
# Reactant will automatically upgrade code to use GPU/TPU where available
# For more apples to apples tests lets force CPU
Reactant.set_default_backend("cpu")


# Move the parameter vector into a Reactant traced array
ra = Reactant.to_rarray([1.0])

# Compile the reverse-mode gradient of logℓ w.r.t. the parameters;
# this is the call that triggers the `_mul!` MethodError reported below
gradfunc = Reactant.@compile Enzyme.gradient(Reverse, logℓ, ra, Const(data))

Compiling this hits the following error:

julia> gradfunc = Reactant.@compile Enzyme.gradient(Reverse, logℓ, ra, Const(data))
ERROR: MethodError: no method matching _mul!(::Vector{Reactant.TracedRNumber{Float64}}, ::Matrix{Reactant.TracedRNumber{Float64}}, ::Vector{Reactant.TracedRNumber{Float64}}, ::Bool, ::Bool)

Closest candidates are:
  _mul!(::AbstractVecOrMat, ::Diagonal, ::AbstractVector, ::Any)
   @ LinearAlgebra ~/git/Enzyme.jl/julia-1.10.5/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:395
  _mul!(::AbstractVecOrMat, ::Union{Bidiagonal, SymTridiagonal, Tridiagonal}, ::AbstractVecOrMat, ::LinearAlgebra.MulAddMul)
   @ LinearAlgebra ~/git/Enzyme.jl/julia-1.10.5/share/julia/stdlib/v1.10/LinearAlgebra/src/bidiag.jl:551
  _mul!(::AbstractMatrix, ::AbstractMatrix, ::Diagonal, ::Any)
   @ LinearAlgebra ~/git/Enzyme.jl/julia-1.10.5/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:399
  ...

Stacktrace:
  [1] #mul!
    @ ~/.julia/packages/Reactant/WudhJ/src/Overlay.jl:133 [inlined]
  [2] mul!(none::Vector{Reactant.TracedRNumber{Float64}}, none::Matrix{Reactant.TracedRNumber{Float64}}, none::Vector{Reactant.TracedRNumber{Float64}}, none::Bool, none::Bool)
    @ Reactant ./<missing>:0
  [3] #mul!
    @ ~/.julia/packages/Reactant/WudhJ/src/Overlay.jl:130 [inlined]
  [4] call_with_reactant(::typeof(mul!), ::Vector{Reactant.TracedRNumber{Float64}}, ::Matrix{Reactant.TracedRNumber{Float64}}, ::Vector{Reactant.TracedRNumber{Float64}}, ::Bool, ::Bool)
    @ Reactant ~/.julia/packages/Reactant/WudhJ/src/utils.jl:0
  [5] mul!(C::Vector{Reactant.TracedRNumber{Float64}}, A::Matrix{Reactant.TracedRNumber{Float64}}, B::Vector{Reactant.TracedRNumber{Float64}})
    @ Reactant ~/.julia/packages/Reactant/WudhJ/src/Overlay.jl:140
  [6] *
    @ ~/git/Enzyme.jl/julia-1.10.5/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:57 [inlined]
  [7] *(none::Matrix{Reactant.TracedRNumber{Float64}}, none::Vector{Reactant.TracedRNumber{Float64}})
    @ Reactant ./<missing>:0
  [8] _any
    @ ./reduce.jl:1219 [inlined]
  [9] any
    @ ./reduce.jl:1235 [inlined]
 [10] TupleOrBottom
    @ ./promotion.jl:482 [inlined]
 [11] promote_op
    @ ./promotion.jl:498 [inlined]
 [12] *
    @ ~/git/Enzyme.jl/julia-1.10.5/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:56 [inlined]
 [13] call_with_reactant(::typeof(*), ::Matrix{Reactant.TracedRNumber{Float64}}, ::Vector{Reactant.TracedRNumber{Float64}})
    @ Reactant ~/.julia/packages/Reactant/WudhJ/src/utils.jl:0
 [14] kalman_filter
    @ ./REPL[7]:17 [inlined]
 [15] logℓ
    @ ./REPL[9]:4 [inlined]
 [16] logℓ(none::Reactant.TracedRArray{Float64, 1}, none::Vector{Float64})
    @ Reactant ./<missing>:0
 [17] logℓ
    @ ./REPL[9]:2 [inlined]
 [18] call_with_reactant(::typeof(logℓ), ::Reactant.TracedRArray{Float64, 1}, ::Vector{Float64})
    @ Reactant ~/.julia/packages/Reactant/WudhJ/src/utils.jl:0
 [19] (::Reactant.TracedUtils.var"#8#18"{Bool, Bool, typeof(logℓ), Tuple{Reactant.TracedRArray{}, Vector{}}, Vector{Union{}}, Tuple{Reactant.TracedRArray{}, Vector{}}})()
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/WudhJ/src/TracedUtils.jl:182
 [20] block!(f::Reactant.TracedUtils.var"#8#18"{Bool, Bool, typeof(logℓ), Tuple{}, Vector{}, Tuple{}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/WudhJ/src/mlir/IR/Block.jl:201
 [21] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool, do_transpose::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/WudhJ/src/TracedUtils.jl:169
 [22] make_mlir_fn
    @ ~/.julia/packages/Reactant/WudhJ/src/TracedUtils.jl:86 [inlined]
 [23] overload_autodiff(::ReverseMode{false, false, FFIABI, false, true}, ::Const{typeof(logℓ)}, ::Type{Active}, ::Duplicated{Reactant.TracedRArray{Float64, 1}}, ::Const{Vector{Float64}})
    @ Reactant ~/.julia/packages/Reactant/WudhJ/src/Interpreter.jl:238
 [24] autodiff(::ReverseMode{false, false, FFIABI, false, true}, ::Const{typeof(logℓ)}, ::Type{Active}, ::Duplicated{Reactant.TracedRArray{Float64, 1}}, ::Const{Vector{Float64}})
    @ Reactant ~/.julia/packages/Reactant/WudhJ/src/Overlay.jl:32
 [25] autodiff
    @ ~/git/Enzyme.jl/src/Enzyme.jl:524 [inlined]
 [26] macro expansion
    @ ~/git/Enzyme.jl/src/sugar.jl:275 [inlined]
 [27] gradient
    @ ~/git/Enzyme.jl/src/sugar.jl:263 [inlined]
 [28] gradient(none::ReverseMode{false, false, FFIABI, false, false}, none::typeof(logℓ), none::Reactant.TracedRArray{Float64, 1}, none::Tuple{Const{Vector{Float64}}})
    @ Reactant ./<missing>:0
 [29] Array
    @ ./boot.jl:477 [inlined]
 [30] IdDict
    @ ./iddict.jl:30 [inlined]
 [31] IdDict
    @ ./iddict.jl:48 [inlined]
 [32] make_zero (repeats 2 times)
    @ ~/git/Enzyme.jl/lib/EnzymeCore/src/EnzymeCore.jl:529 [inlined]
 [33] macro expansion
    @ ~/git/Enzyme.jl/src/sugar.jl:321 [inlined]
 [34] gradient
    @ ~/git/Enzyme.jl/src/sugar.jl:263 [inlined]
 [35] call_with_reactant(::typeof(gradient), ::ReverseMode{false, false, FFIABI, false, false}, ::typeof(logℓ), ::Reactant.TracedRArray{Float64, 1}, ::Const{Vector{Float64}})
    @ Reactant ~/.julia/packages/Reactant/WudhJ/src/utils.jl:0
 [36] (::Reactant.TracedUtils.var"#8#18"{Bool, Bool, typeof(gradient), Tuple{}, Vector{}, Tuple{}})()
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/WudhJ/src/TracedUtils.jl:182
 [37] block!(f::Reactant.TracedUtils.var"#8#18"{Bool, Bool, typeof(gradient), Tuple{}, Vector{}, Tuple{}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/WudhJ/src/mlir/IR/Block.jl:201
 [38] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool, do_transpose::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/WudhJ/src/TracedUtils.jl:169
 [39] make_mlir_fn
    @ ~/.julia/packages/Reactant/WudhJ/src/TracedUtils.jl:86 [inlined]
 [40] #10
    @ ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:319 [inlined]
 [41] block!(f::Reactant.Compiler.var"#10#15"{typeof(gradient), Tuple{ReverseMode{}, typeof(logℓ), ConcreteRArray{}, Const{}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/WudhJ/src/mlir/IR/Block.jl:201
 [42] #9
    @ ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:318 [inlined]
 [43] mmodule!(f::Reactant.Compiler.var"#9#14"{Reactant.MLIR.IR.Module, typeof(gradient), Tuple{ReverseMode{}, typeof(logℓ), ConcreteRArray{}, Const{}}}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/WudhJ/src/mlir/IR/Module.jl:92
 [44] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ReverseMode{…}, typeof(logℓ), ConcreteRArray{…}, Const{…}}; optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:315
 [45] compile_mlir!
    @ ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:314 [inlined]
 [46] (::Reactant.Compiler.var"#32#34"{Bool, typeof(gradient), Tuple{ReverseMode{false, false, FFIABI, false, false}, typeof(logℓ), ConcreteRArray{Float64, 1}, Const{Vector{Float64}}}})()
    @ Reactant.Compiler ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:799
 [47] context!(f::Reactant.Compiler.var"#32#34"{Bool, typeof(gradient), Tuple{ReverseMode{}, typeof(logℓ), ConcreteRArray{}, Const{}}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/WudhJ/src/mlir/IR/Context.jl:76
 [48] compile_xla(f::Function, args::Tuple{ReverseMode{false, false, FFIABI, false, false}, typeof(logℓ), ConcreteRArray{Float64, 1}, Const{Vector{Float64}}}; client::Nothing, optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:796
 [49] compile_xla
    @ ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:791 [inlined]
 [50] compile(f::Function, args::Tuple{ReverseMode{false, false, FFIABI, false, false}, typeof(logℓ), ConcreteRArray{Float64, 1}, Const{Vector{…}}}; client::Nothing, optimize::Bool, sync::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:823
Some type information was truncated. Use `show(err)` to see complete types.
@wsmoses
Copy link
Member Author

wsmoses commented Dec 31, 2024

Ideally we should auto-upgrade an array of traced numbers to a traced array inside the `mul!`, but in any case this shouldn't fail outright.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant