Skip to content

Commit

Permalink
using @layer instead of @functor from Flux
Browse files Browse the repository at this point in the history
  • Loading branch information
simonmandlik committed Mar 10, 2024
1 parent a21b9b2 commit 5913cac
Show file tree
Hide file tree
Showing 23 changed files with 23 additions and 23 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Combinatorics = "1.0"
DataFrames = "1.6.1"
DataStructures = "0.18.15"
FiniteDifferences = "0.12.31"
Flux = "0.14"
Flux = "0.14.13"
HierarchicalUtils = "2.1.5"
MLUtils = "0.4.4"
MacroTools = "0.5.13"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/examples/dag.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct DagModel{M}
od::Int
end
Flux.@functor DagModel
Flux.@layer :ignore DagModel
nothing # hide
```
Expand Down
2 changes: 1 addition & 1 deletion docs/src/examples/gnn.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ struct GNN{L, M, R}
m::R
end
Flux.@functor GNN
Flux.@layer :ignore GNN
function mpstep(m::GNN, U, bags, n)
n == 0 && return(U)
Expand Down
2 changes: 1 addition & 1 deletion docs/src/manual/custom.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ struct PathModel{T, F} <: AbstractMillModel
path2mill::F
end
Flux.@functor PathModel
Flux.@layer :ignore PathModel
show(io::IO, n::PathModel) = print(io, "PathModel")
NodeType(::Type{<:PathModel}) = LeafNode()
Expand Down
2 changes: 1 addition & 1 deletion src/aggregations/aggregation_stack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ _flatten_agg(a::AbstractAggregation) = [a]

AggregationStack(fs::AbstractAggregation...) = AggregationStack(fs)

Flux.@functor AggregationStack
Flux.@layer :ignore AggregationStack

function (a::AggregationStack)(x::Maybe{AbstractArray}, bags::AbstractBags, args...)
reduce(vcat, (f(x, bags, args...) for f in a.fs))
Expand Down
2 changes: 1 addition & 1 deletion src/aggregations/bagcount.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ struct BagCount{T<:AbstractAggregation}
a::T
end

Flux.@functor BagCount
Flux.@layer :ignore BagCount

_bagcount(T, bags) = permutedims(log.(one(T) .+ length.(bags)))
ChainRulesCore.@non_differentiable _bagcount(T, bags)
Expand Down
2 changes: 1 addition & 1 deletion src/aggregations/segmented_lse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ struct SegmentedLSE{V <: AbstractVector{<:AbstractFloat}} <: AbstractAggregation
ρ::V
end

Flux.@functor SegmentedLSE
Flux.@layer :ignore SegmentedLSE

SegmentedLSE(T::Type, d::Int) = SegmentedLSE(zeros(T, d), randn(T, d))
SegmentedLSE(d::Int) = SegmentedLSE(Float32, d)
Expand Down
2 changes: 1 addition & 1 deletion src/aggregations/segmented_max.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct SegmentedMax{V <: AbstractVector{<:Number}} <: AbstractAggregation
ψ::V
end

Flux.@functor SegmentedMax
Flux.@layer :ignore SegmentedMax

SegmentedMax(T::Type, d::Int) = SegmentedMax(zeros(T, d))
SegmentedMax(d::Int) = SegmentedMax(Float32, d)
Expand Down
2 changes: 1 addition & 1 deletion src/aggregations/segmented_mean.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct SegmentedMean{V <: AbstractVector{<:Number}} <: AbstractAggregation
ψ::V
end

Flux.@functor SegmentedMean
Flux.@layer :ignore SegmentedMean

SegmentedMean(T::Type, d::Int) = SegmentedMean(zeros(T, d))
SegmentedMean(d::Int) = SegmentedMean(Float32, d)
Expand Down
2 changes: 1 addition & 1 deletion src/aggregations/segmented_pnorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ struct SegmentedPNorm{V <: AbstractVector{<:AbstractFloat}} <: AbstractAggregati
c::V
end

Flux.@functor SegmentedPNorm
Flux.@layer :ignore SegmentedPNorm

SegmentedPNorm(T::Type, d::Int) = SegmentedPNorm(zeros(T, d), randn(T, d), zeros(T, d))
SegmentedPNorm(d::Int) = SegmentedPNorm(Float32, d)
Expand Down
2 changes: 1 addition & 1 deletion src/aggregations/segmented_sum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct SegmentedSum{V <: AbstractVector{<:Number}} <: AbstractAggregation
ψ::V
end

Flux.@functor SegmentedSum
Flux.@layer :ignore SegmentedSum

SegmentedSum(T::Type, d::Int) = SegmentedSum(zeros(T, d))
SegmentedSum(d::Int) = SegmentedSum(Float32, d)
Expand Down
2 changes: 1 addition & 1 deletion src/bagchain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ struct BagChain{T <: Tuple} <: AbstractAggregation
BagChain(xs...) = new{typeof(xs)}(xs)
end

Flux.@functor BagChain
Flux.@layer :ignore BagChain

Flux.@forward BagChain.layers Base.getindex, Base.first, Base.last, Base.lastindex
Flux.@forward BagChain.layers Base.iterate
Expand Down
2 changes: 1 addition & 1 deletion src/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ struct BagConv{T, F}
σ::F
end

Flux.@functor BagConv
Flux.@layer :ignore BagConv

function BagConv(d::Int, o::Int, n::Int, σ = identity)
W = (n > 1) ? tuple([randn(o, d) .* sqrt(2.0/(o + d)) for _ in 1:n]...) : randn(o, d) .* sqrt(2.0/(o + d))
Expand Down
2 changes: 1 addition & 1 deletion src/datanodes/arraynode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ See also: [`AbstractMillNode`](@ref), [`ArrayModel`](@ref).
"""
ArrayNode(d::AbstractArray) = ArrayNode(d, nothing)

Flux.@functor ArrayNode
Flux.@layer :ignore ArrayNode

mapdata(f, x::ArrayNode) = ArrayNode(mapdata(f, x.data), x.metadata)

Expand Down
2 changes: 1 addition & 1 deletion src/datanodes/bagnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ See also: [`WeightedBagNode`](@ref), [`AbstractBagNode`](@ref),
BagNode(d::Maybe{AbstractMillNode}, b::AbstractVector, m=nothing) = BagNode(d, bags(b), m)
BagNode(d, b, m=nothing) = BagNode(_arraynode(d), b, m)

Flux.@functor BagNode
Flux.@layer :ignore BagNode

mapdata(f, x::BagNode) = BagNode(mapdata(f, x.data), x.bags, x.metadata)

Expand Down
2 changes: 1 addition & 1 deletion src/datanodes/productnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ See also: [`AbstractProductNode`](@ref), [`AbstractMillNode`](@ref), [`ProductMo
ProductNode(ds, args...) = ProductNode(tuple(ds), args...)
ProductNode(args...; ns...) = ProductNode(NamedTuple(ns), args...)

Flux.@functor ProductNode
Flux.@layer :ignore ProductNode

mapdata(f, x::ProductNode) = ProductNode(map(i -> mapdata(f, i), x.data), x.metadata)

Expand Down
2 changes: 1 addition & 1 deletion src/datanodes/weighted_bagnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ WeightedBagNode(d::Maybe{AbstractMillNode}, b::AbstractVector, weights::Vector,
WeightedBagNode(d, bags(b), weights, metadata)
WeightedBagNode(d, b, w, m=nothing) = WeightedBagNode(_arraynode(d), b, w, m)

Flux.@functor WeightedBagNode
Flux.@layer :ignore WeightedBagNode

mapdata(f, x::WeightedBagNode) = WeightedBagNode(mapdata(f, x.data), x.bags, x.weights, x.metadata)

Expand Down
2 changes: 1 addition & 1 deletion src/modelnodes/arraymodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ struct ArrayModel{T} <: AbstractMillModel
m::T
end

Flux.@functor ArrayModel
Flux.@layer :ignore ArrayModel

(m::ArrayModel)(x::ArrayNode) = m.m(x.data)

Expand Down
2 changes: 1 addition & 1 deletion src/modelnodes/bagmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ struct BagModel{T <: AbstractMillModel, A <: Union{AbstractAggregation, BagCount
bm::U
end

Flux.@functor BagModel
Flux.@layer :ignore BagModel

"""
BagModel(im, a, bm=identity)
Expand Down
2 changes: 1 addition & 1 deletion src/modelnodes/lazymodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,6 @@ See also: [`AbstractMillModel`](@ref), [`LazyNode`](@ref), [`Mill.unpack2mill`](
"""
LazyModel(Name::Symbol, m) = LazyModel{Name}(m)

Flux.@functor LazyModel
Flux.@layer :ignore LazyModel

(m::LazyModel{Name})(x::LazyNode{Name}) where {Name} = m.m(unpack2mill(x))
2 changes: 1 addition & 1 deletion src/modelnodes/productmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ struct ProductModel{T<:VecOrTupOrNTup{AbstractMillModel},U} <: AbstractMillModel
end
end

Flux.@functor ProductModel
Flux.@layer :ignore ProductModel

"""
ProductModel(ms, m=identity)
Expand Down
2 changes: 1 addition & 1 deletion src/special_arrays/postimputing_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ PostImputingMatrix(W::AbstractMatrix{T}) where T = PostImputingMatrix(W, zeros(T

Flux.@forward PostImputingMatrix.W Base.size, Base.getindex, Base.setindex!, Base.firstindex, Base.lastindex

Flux.@functor PostImputingMatrix
Flux.@layer :ignore PostImputingMatrix

Base.vcat(As::PostImputingMatrix...) = PostImputingMatrix(vcat((A.W for A in As)...), vcat((A.ψ for A in As)...))
function Base.hcat(As::PostImputingMatrix...)
Expand Down
2 changes: 1 addition & 1 deletion src/special_arrays/preimputing_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ PreImputingMatrix(W::AbstractMatrix{T}) where T = PreImputingMatrix(W, zeros(T,

Flux.@forward PreImputingMatrix.W Base.size, Base.getindex, Base.setindex!, Base.firstindex, Base.lastindex

Flux.@functor PreImputingMatrix
Flux.@layer :ignore PreImputingMatrix

Base.hcat(As::PreImputingMatrix...) = PreImputingMatrix(hcat((A.W for A in As)...), vcat((A.ψ for A in As)...))
function Base.vcat(As::PreImputingMatrix...)
Expand Down

0 comments on commit 5913cac

Please sign in to comment.