diff --git a/Project.toml b/Project.toml index aa99534c..d67959e5 100644 --- a/Project.toml +++ b/Project.toml @@ -29,7 +29,7 @@ Combinatorics = "1.0" DataFrames = "1.6.1" DataStructures = "0.18.15" FiniteDifferences = "0.12.31" -Flux = "0.14" +Flux = "0.14.13" HierarchicalUtils = "2.1.5" MLUtils = "0.4.4" MacroTools = "0.5.13" diff --git a/docs/src/examples/dag.md b/docs/src/examples/dag.md index ec18e915..a62b5e1b 100644 --- a/docs/src/examples/dag.md +++ b/docs/src/examples/dag.md @@ -28,7 +28,7 @@ struct DagModel{M} od::Int end -Flux.@functor DagModel +Flux.@layer :ignore DagModel nothing # hide ``` diff --git a/docs/src/examples/gnn.md b/docs/src/examples/gnn.md index e25c8c4e..c59c99d6 100644 --- a/docs/src/examples/gnn.md +++ b/docs/src/examples/gnn.md @@ -66,7 +66,7 @@ struct GNN{L, M, R} m::R end -Flux.@functor GNN +Flux.@layer :ignore GNN function mpstep(m::GNN, U, bags, n) n == 0 && return(U) diff --git a/docs/src/manual/custom.md b/docs/src/manual/custom.md index 01dfc6b2..9a0e1989 100644 --- a/docs/src/manual/custom.md +++ b/docs/src/manual/custom.md @@ -111,7 +111,7 @@ struct PathModel{T, F} <: AbstractMillModel path2mill::F end -Flux.@functor PathModel +Flux.@layer :ignore PathModel show(io::IO, n::PathModel) = print(io, "PathModel") NodeType(::Type{<:PathModel}) = LeafNode() diff --git a/src/aggregations/aggregation_stack.jl b/src/aggregations/aggregation_stack.jl index 5fd7babf..f08bd572 100644 --- a/src/aggregations/aggregation_stack.jl +++ b/src/aggregations/aggregation_stack.jl @@ -65,7 +65,7 @@ _flatten_agg(a::AbstractAggregation) = [a] AggregationStack(fs::AbstractAggregation...) = AggregationStack(fs) -Flux.@functor AggregationStack +Flux.@layer :ignore AggregationStack function (a::AggregationStack)(x::Maybe{AbstractArray}, bags::AbstractBags, args...) reduce(vcat, (f(x, bags, args...) for f in a.fs)) diff --git a/src/aggregations/bagcount.jl b/src/aggregations/bagcount.jl index c372eaf6..b8a7470a 100644 --- a/src/aggregations/bagcount.jl +++ b/src/aggregations/bagcount.jl @@ -49,7 +49,7 @@ struct BagCount{T<:AbstractAggregation} a::T end -Flux.@functor BagCount +Flux.@layer :ignore BagCount _bagcount(T, bags) = permutedims(log.(one(T) .+ length.(bags))) ChainRulesCore.@non_differentiable _bagcount(T, bags) diff --git a/src/aggregations/segmented_lse.jl b/src/aggregations/segmented_lse.jl index bf7459d5..9578c379 100644 --- a/src/aggregations/segmented_lse.jl +++ b/src/aggregations/segmented_lse.jl @@ -18,7 +18,7 @@ struct SegmentedLSE{V <: AbstractVector{<:AbstractFloat}} <: AbstractAggregation ρ::V end -Flux.@functor SegmentedLSE +Flux.@layer :ignore SegmentedLSE SegmentedLSE(T::Type, d::Int) = SegmentedLSE(zeros(T, d), randn(T, d)) SegmentedLSE(d::Int) = SegmentedLSE(Float32, d) diff --git a/src/aggregations/segmented_max.jl b/src/aggregations/segmented_max.jl index 046a374c..1cf63b65 100644 --- a/src/aggregations/segmented_max.jl +++ b/src/aggregations/segmented_max.jl @@ -16,7 +16,7 @@ struct SegmentedMax{V <: AbstractVector{<:Number}} <: AbstractAggregation ψ::V end -Flux.@functor SegmentedMax +Flux.@layer :ignore SegmentedMax SegmentedMax(T::Type, d::Int) = SegmentedMax(zeros(T, d)) SegmentedMax(d::Int) = SegmentedMax(Float32, d) diff --git a/src/aggregations/segmented_mean.jl b/src/aggregations/segmented_mean.jl index 613e1009..71298807 100644 --- a/src/aggregations/segmented_mean.jl +++ b/src/aggregations/segmented_mean.jl @@ -16,7 +16,7 @@ struct SegmentedMean{V <: AbstractVector{<:Number}} <: AbstractAggregation ψ::V end -Flux.@functor SegmentedMean +Flux.@layer :ignore SegmentedMean SegmentedMean(T::Type, d::Int) = SegmentedMean(zeros(T, d)) SegmentedMean(d::Int) = SegmentedMean(Float32, d) diff --git a/src/aggregations/segmented_pnorm.jl b/src/aggregations/segmented_pnorm.jl index a04a270a..8af9da91 100644 --- a/src/aggregations/segmented_pnorm.jl +++ b/src/aggregations/segmented_pnorm.jl @@ -19,7 +19,7 @@ struct SegmentedPNorm{V <: AbstractVector{<:AbstractFloat}} <: AbstractAggregati c::V end -Flux.@functor SegmentedPNorm +Flux.@layer :ignore SegmentedPNorm SegmentedPNorm(T::Type, d::Int) = SegmentedPNorm(zeros(T, d), randn(T, d), zeros(T, d)) SegmentedPNorm(d::Int) = SegmentedPNorm(Float32, d) diff --git a/src/aggregations/segmented_sum.jl b/src/aggregations/segmented_sum.jl index fe171522..d1738f58 100644 --- a/src/aggregations/segmented_sum.jl +++ b/src/aggregations/segmented_sum.jl @@ -16,7 +16,7 @@ struct SegmentedSum{V <: AbstractVector{<:Number}} <: AbstractAggregation ψ::V end -Flux.@functor SegmentedSum +Flux.@layer :ignore SegmentedSum SegmentedSum(T::Type, d::Int) = SegmentedSum(zeros(T, d)) SegmentedSum(d::Int) = SegmentedSum(Float32, d) diff --git a/src/bagchain.jl b/src/bagchain.jl index 140af8df..df668c07 100644 --- a/src/bagchain.jl +++ b/src/bagchain.jl @@ -3,7 +3,7 @@ struct BagChain{T <: Tuple} <: AbstractAggregation BagChain(xs...) = new{typeof(xs)}(xs) end -Flux.@functor BagChain +Flux.@layer :ignore BagChain Flux.@forward BagChain.layers Base.getindex, Base.first, Base.last, Base.lastindex Flux.@forward BagChain.layers Base.iterate diff --git a/src/conv.jl b/src/conv.jl index 107f2d13..b985317d 100644 --- a/src/conv.jl +++ b/src/conv.jl @@ -210,7 +210,7 @@ struct BagConv{T, F} σ::F end -Flux.@functor BagConv +Flux.@layer :ignore BagConv function BagConv(d::Int, o::Int, n::Int, σ = identity) W = (n > 1) ? tuple([randn(o, d) .* sqrt(2.0/(o + d)) for _ in 1:n]...) : randn(o, d) .* sqrt(2.0/(o + d)) diff --git a/src/datanodes/arraynode.jl b/src/datanodes/arraynode.jl index 13d25920..dd779089 100644 --- a/src/datanodes/arraynode.jl +++ b/src/datanodes/arraynode.jl @@ -28,7 +28,7 @@ See also: [`AbstractMillNode`](@ref), [`ArrayModel`](@ref). """ ArrayNode(d::AbstractArray) = ArrayNode(d, nothing) -Flux.@functor ArrayNode +Flux.@layer :ignore ArrayNode mapdata(f, x::ArrayNode) = ArrayNode(mapdata(f, x.data), x.metadata) diff --git a/src/datanodes/bagnode.jl b/src/datanodes/bagnode.jl index 4691acae..596d669b 100644 --- a/src/datanodes/bagnode.jl +++ b/src/datanodes/bagnode.jl @@ -48,7 +48,7 @@ See also: [`WeightedBagNode`](@ref), [`AbstractBagNode`](@ref), BagNode(d::Maybe{AbstractMillNode}, b::AbstractVector, m=nothing) = BagNode(d, bags(b), m) BagNode(d, b, m=nothing) = BagNode(_arraynode(d), b, m) -Flux.@functor BagNode +Flux.@layer :ignore BagNode mapdata(f, x::BagNode) = BagNode(mapdata(f, x.data), x.bags, x.metadata) diff --git a/src/datanodes/productnode.jl b/src/datanodes/productnode.jl index 1cac9044..1f9a037c 100644 --- a/src/datanodes/productnode.jl +++ b/src/datanodes/productnode.jl @@ -59,7 +59,7 @@ See also: [`AbstractProductNode`](@ref), [`AbstractMillNode`](@ref), [`ProductMo ProductNode(ds, args...) = ProductNode(tuple(ds), args...) ProductNode(args...; ns...) = ProductNode(NamedTuple(ns), args...) -Flux.@functor ProductNode +Flux.@layer :ignore ProductNode mapdata(f, x::ProductNode) = ProductNode(map(i -> mapdata(f, i), x.data), x.metadata) diff --git a/src/datanodes/weighted_bagnode.jl b/src/datanodes/weighted_bagnode.jl index 31dbb69e..6fea405a 100644 --- a/src/datanodes/weighted_bagnode.jl +++ b/src/datanodes/weighted_bagnode.jl @@ -46,7 +46,7 @@ WeightedBagNode(d::Maybe{AbstractMillNode}, b::AbstractVector, weights::Vector, WeightedBagNode(d, bags(b), weights, metadata) WeightedBagNode(d, b, w, m=nothing) = WeightedBagNode(_arraynode(d), b, w, m) -Flux.@functor WeightedBagNode +Flux.@layer :ignore WeightedBagNode mapdata(f, x::WeightedBagNode) = WeightedBagNode(mapdata(f, x.data), x.bags, x.weights, x.metadata) diff --git a/src/modelnodes/arraymodel.jl b/src/modelnodes/arraymodel.jl index 067cbf0b..5568446c 100644 --- a/src/modelnodes/arraymodel.jl +++ b/src/modelnodes/arraymodel.jl @@ -34,7 +34,7 @@ struct ArrayModel{T} <: AbstractMillModel m::T end -Flux.@functor ArrayModel +Flux.@layer :ignore ArrayModel (m::ArrayModel)(x::ArrayNode) = m.m(x.data) diff --git a/src/modelnodes/bagmodel.jl b/src/modelnodes/bagmodel.jl index 7eb1b908..abe70233 100644 --- a/src/modelnodes/bagmodel.jl +++ b/src/modelnodes/bagmodel.jl @@ -37,7 +37,7 @@ struct BagModel{T <: AbstractMillModel, A <: Union{AbstractAggregation, BagCount bm::U end -Flux.@functor BagModel +Flux.@layer :ignore BagModel """ BagModel(im, a, bm=identity) diff --git a/src/modelnodes/lazymodel.jl b/src/modelnodes/lazymodel.jl index 1692657f..014b0136 100644 --- a/src/modelnodes/lazymodel.jl +++ b/src/modelnodes/lazymodel.jl @@ -71,6 +71,6 @@ See also: [`AbstractMillModel`](@ref), [`LazyNode`](@ref), [`Mill.unpack2mill`]( """ LazyModel(Name::Symbol, m) = LazyModel{Name}(m) -Flux.@functor LazyModel +Flux.@layer :ignore LazyModel (m::LazyModel{Name})(x::LazyNode{Name}) where {Name} = m.m(unpack2mill(x)) diff --git a/src/modelnodes/productmodel.jl b/src/modelnodes/productmodel.jl index 0dc97ab7..3dfa3aa9 100644 --- a/src/modelnodes/productmodel.jl +++ b/src/modelnodes/productmodel.jl @@ -54,7 +54,7 @@ struct ProductModel{T<:VecOrTupOrNTup{AbstractMillModel},U} <: AbstractMillModel end end -Flux.@functor ProductModel +Flux.@layer :ignore ProductModel """ ProductModel(ms, m=identity) diff --git a/src/special_arrays/postimputing_matrix.jl b/src/special_arrays/postimputing_matrix.jl index 10ee0172..274429ba 100644 --- a/src/special_arrays/postimputing_matrix.jl +++ b/src/special_arrays/postimputing_matrix.jl @@ -56,7 +56,7 @@ PostImputingMatrix(W::AbstractMatrix{T}) where T = PostImputingMatrix(W, zeros(T Flux.@forward PostImputingMatrix.W Base.size, Base.getindex, Base.setindex!, Base.firstindex, Base.lastindex -Flux.@functor PostImputingMatrix +Flux.@layer :ignore PostImputingMatrix Base.vcat(As::PostImputingMatrix...) = PostImputingMatrix(vcat((A.W for A in As)...), vcat((A.ψ for A in As)...)) function Base.hcat(As::PostImputingMatrix...) diff --git a/src/special_arrays/preimputing_matrix.jl b/src/special_arrays/preimputing_matrix.jl index 827b4470..83362063 100644 --- a/src/special_arrays/preimputing_matrix.jl +++ b/src/special_arrays/preimputing_matrix.jl @@ -51,7 +51,7 @@ PreImputingMatrix(W::AbstractMatrix{T}) where T = PreImputingMatrix(W, zeros(T, Flux.@forward PreImputingMatrix.W Base.size, Base.getindex, Base.setindex!, Base.firstindex, Base.lastindex -Flux.@functor PreImputingMatrix +Flux.@layer :ignore PreImputingMatrix Base.hcat(As::PreImputingMatrix...) = PreImputingMatrix(hcat((A.W for A in As)...), vcat((A.ψ for A in As)...)) function Base.vcat(As::PreImputingMatrix...)