Skip to content

Commit

Permalink
migrated from Setfield to Accessors
Browse files Browse the repository at this point in the history
  • Loading branch information
simonmandlik committed Feb 7, 2024
1 parent a747257 commit cf187a7
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 79 deletions.
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "Mill"
uuid = "1d0525e4-8992-11e8-313c-e310e1f6ddea"
authors = ["Tomas Pevny <[email protected]>", "Simon Mandlik <[email protected]>"]
version = "2.10"
version = "2.10.0"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Expand All @@ -17,12 +18,12 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Accessors = "0.1"
ChainRulesCore = "1"
Combinatorics = "1.0"
DataFrames = "1"
Expand All @@ -35,7 +36,6 @@ MacroTools = "0.5"
OneHotArrays = "0.1, 0.2"
PooledArrays = "1"
Preferences = "1"
Setfield = "1"
julia = "1.9"

[extras]
Expand All @@ -45,4 +45,4 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[targets]
test = [ "BenchmarkTools", "Documenter", "InteractiveUtils", "Random" ]
test = ["BenchmarkTools", "Documenter", "InteractiveUtils", "Random"]
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
Expand All @@ -10,6 +11,5 @@ HierarchicalUtils = "f9ccea15-0695-44b9-8113-df7c26ae4fa9"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Mill = "1d0525e4-8992-11e8-313c-e310e1f6ddea"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4 changes: 2 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Pkg
using Documenter, DocumenterCitations, Literate
using Mill, Flux, Random, SparseArrays, Setfield, HierarchicalUtils
using Mill, Flux, Random, SparseArrays, Accessors, HierarchicalUtils

#=
Useful resources for writing docs:
Expand Down Expand Up @@ -53,7 +53,7 @@ function Mill.unpack2mill(ds::LazyNode{:Sentence})
end

DocMeta.setdocmeta!(Mill, :DocTestSetup, quote
using Mill, Flux, Random, SparseArrays, Setfield, HierarchicalUtils
using Mill, Flux, Random, SparseArrays, Accessors, HierarchicalUtils
ENV["LINES"] = ENV["COLUMNS"] = typemax(Int)
end; recursive=true)

Expand Down
4 changes: 2 additions & 2 deletions src/Mill.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module Mill

using Accessors
using ChainRulesCore
using Combinatorics
using DataFrames
Expand All @@ -12,14 +13,13 @@ using MacroTools
using OneHotArrays
using PooledArrays
using Preferences
using Setfield
using SparseArrays
using Statistics

using Base: CodeUnits, nameof
using ChainRulesCore: NotImplemented, NotImplementedException
using HierarchicalUtils: encode, stringify
using Setfield: IdentityLens, PropertyLens, IndexLens, ComposedLens
using Accessors: PropertyLens, IndexLens, ComposedOptic

import Base: *, ==

Expand Down
141 changes: 78 additions & 63 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ end
"""
pred_lens(p, n)
Return a `Vector` of `Setfield.Lens`es for accessing all nodes/fields in `n` conforming to predicate `p`.
Return a `Vector` of `Accessors.jl` lenses for accessing all nodes/fields in `n` conforming to
predicate `p`.
# Examples
```jldoctest
Expand All @@ -20,18 +21,35 @@ ProductNode # 2 obs, 16 bytes
╰── ArrayNode(2×2 Array with Int64 elements) # 2 obs, 80 bytes
julia> pred_lens(x -> x isa ArrayNode, n)
1-element Vector{Setfield.ComposedLens{Setfield.PropertyLens{:data}, Setfield.IndexLens{Tuple{Int64}}}}:
(@lens _.data[2])
1-element Vector{Any}:
(@optic _.data[2])
```
See also: [`list_lens`](@ref), [`find_lens`](@ref), [`findnonempty_lens`](@ref).
"""
pred_lens(p::Function, n) = _pred_lens(p, n)
function pred_lens(p::Function, n)
result = Any[]
_pred_lens!(p, n, (), result)
return result
end

_pred_lens!(p::Function, x, l, result) = p(x) && push!(result, Accessors.opticcompose(l...))
function _pred_lens!(p::Function, n::T, l, result) where T <: AbstractMillStruct
p(n) && push!(result, Accessors.opticcompose(l...))
for k in fieldnames(T)
_pred_lens!(p, getproperty(n, k), (l..., PropertyLens{k}()), result)
end
end
function _pred_lens!(p::Function, n::Union{Tuple, NamedTuple}, l, result)
for i in eachindex(n)
_pred_lens!(p, n[i], (l..., IndexLens((i,))), result)
end
end

"""
list_lens(n)
Return a `Vector` of `Setfield.Lens`es for accessing all nodes/fields in `n`.
Return a `Vector` of `Accessors.jl` lenses for accessing all nodes/fields in `n`.
# Examples
```jldoctest
Expand All @@ -42,16 +60,16 @@ ProductNode # 2 obs, 16 bytes
╰── ArrayNode(2×2 Array with Int64 elements) # 2 obs, 80 bytes
julia> list_lens(n)
9-element Vector{Lens}:
(@lens _)
(@lens _.data[1])
(@lens _.data[1].data)
(@lens _.data[1].bags)
(@lens _.data[1].metadata)
(@lens _.data[2])
(@lens _.data[2].data)
(@lens _.data[2].metadata)
(@lens _.metadata)
9-element Vector{Any}:
identity (generic function with 1 method)
(@optic _.data[1])
(@optic _.data[1].data)
(@optic _.data[1].bags)
(@optic _.data[1].metadata)
(@optic _.data[2])
(@optic _.data[2].data)
(@optic _.data[2].metadata)
(@optic _.metadata)
```
See also: [`pred_lens`](@ref), [`find_lens`](@ref), [`findnonempty_lens`](@ref).
Expand All @@ -61,7 +79,8 @@ list_lens(n) = pred_lens(t -> true, n)
"""
findnonempty_lens(n)
Return a `Vector` of `Setfield.Lens`es for accessing all nodes/fields in `n` that have at least one observation.
Return a `Vector` of `Accessors.jl` lenses for accessing all nodes/fields in `n` that contain at
least one observation.
# Examples
```jldoctest
Expand All @@ -72,10 +91,10 @@ ProductNode # 2 obs, 16 bytes
╰── ArrayNode(2×2 Array with Int64 elements) # 2 obs, 80 bytes
julia> findnonempty_lens(n)
3-element Vector{Lens}:
(@lens _)
(@lens _.data[1])
(@lens _.data[2])
3-element Vector{Any}:
identity (generic function with 1 method)
(@optic _.data[1])
(@optic _.data[2])
```
See also: [`pred_lens`](@ref), [`list_lens`](@ref), [`find_lens`](@ref).
Expand All @@ -85,8 +104,8 @@ findnonempty_lens(n) = pred_lens(t -> t isa AbstractMillNode && numobs(t) > 0, n
"""
find_lens(n, x)
Return a `Vector` of `Setfield.Lens`es for accessing all nodes/fields in `n` that return `true` when
compared to `x` using `Base.===`.
Return a `Vector` of `Accessors.jl` lenses for accessing all nodes/fields in `n` that return `true`
when compared to `x` using `Base.===`.
# Examples
```jldoctest
Expand All @@ -97,30 +116,19 @@ ProductNode # 2 obs, 16 bytes
╰── ArrayNode(2×2 Array with Int64 elements) # 2 obs, 80 bytes
julia> find_lens(n, n.data[1])
1-element Vector{Setfield.ComposedLens{Setfield.PropertyLens{:data}, Setfield.IndexLens{Tuple{Int64}}}}:
(@lens _.data[1])
1-element Vector{Any}:
(@optic _.data[1])
```
See also: [`pred_lens`](@ref), [`list_lens`](@ref), [`findnonempty_lens`](@ref).
"""
find_lens(n, x) = pred_lens(t -> t x, n)

_pred_lens(p::Function, n) = p(n) ? [IdentityLens()] : Lens[]
function _pred_lens(p::Function, n::T) where T <: AbstractMillStruct
res = [map(l -> PropertyLens{k}() l, _pred_lens(p, getproperty(n, k))) for k in fieldnames(T)]
res = vcat(filter(!isempty, res)...)
p(n) ? [IdentityLens(); res] : res
end
function _pred_lens(p::Function, n::Union{Tuple, NamedTuple})
res = [map(l -> IndexLens(tuple(i)) l, _pred_lens(p, n[i])) for i in eachindex(n)]
vcat(filter(!isempty, res)...)
end

"""
code2lens(n, c)
Convert code `c` from [HierarchicalUtils.jl](@ref) traversal to a `Vector` of `Setfield.Lens` such
that they access each node in tree egal to `n`.
Convert code `c` from [HierarchicalUtils.jl](@ref) traversal to a `Vector` of `Accessors.jl`
lenses such that they access each node in tree `n` egal to node under code `c` in the tree.
# Examples
```jldoctest
Expand All @@ -133,8 +141,8 @@ ProductNode [""] # 2 obs, 16 bytes
╰── ArrayNode(2×2 Array with Int64 elements) ["U"] # 2 obs, 80 bytes
julia> code2lens(n, "U")
1-element Vector{Setfield.ComposedLens{Setfield.PropertyLens{:data}, Setfield.IndexLens{Tuple{Int64}}}}:
(@lens _.data[2])
1-element Vector{Any}:
(@optic _.data[2])
```
See also: [`lens2code`](@ref).
Expand All @@ -144,8 +152,8 @@ code2lens(n::AbstractMillStruct, c::AbstractString) = find_lens(n, n[c])
"""
lens2code(n, l)
Convert `Setfield.Lens` l to a `Vector` of codes from [HierarchicalUtils.jl](@ref) traversal such
that they access each node in tree egal to `n`.
Convert `Accessors.jl` lens `l` to a `Vector` of codes from [HierarchicalUtils.jl](@ref) traversal
such that they access each node in tree `n` egal to node accessible by lens `l`.
# Examples
```jldoctest
Expand All @@ -157,20 +165,27 @@ ProductNode [""] # 2 obs, 16 bytes
│ ╰── ∅ ["M"]
╰── ArrayNode(2×2 Array with Int64 elements) ["U"] # 2 obs, 80 bytes
julia> lens2code(n, (@lens _.data[2]))
julia> lens2code(n, (@optic _.data[2]))
1-element Vector{String}:
"U"
julia> lens2code(n, (@optic _.data[∗]))
2-element Vector{String}:
"E"
"U"
```
See also: [`code2lens`](@ref).
"""
lens2code(n::AbstractMillStruct, l::Lens) = HierarchicalUtils.find_traversal(n, get(n, l))
lens2code(n::AbstractMillStruct, l) = mapreduce(vcat, Accessors.getall(n, l)) do x
HierarchicalUtils.find_traversal(n, x)
end

"""
model_lens(m, l)
Convert `Setfield.Lens` `l` for a data node to a new lens for accessing the same location in model `m`.
Convert `Accessors.jl` lens `l` for a data node to a new lens for accessing the same location in model `m`.
# Examples
```jldoctest
Expand All @@ -187,26 +202,26 @@ ProductModel ↦ Dense(20 => 10) # 2 arrays, 210 params, 920 bytes
│ ╰── ArrayModel(Dense(2 => 10)) # 2 arrays, 30 params, 200 bytes
╰── ArrayModel(Dense(2 => 10)) # 2 arrays, 30 params, 200 bytes
julia> model_lens(m, (@lens _.data[2]))
(@lens _.ms[2])
julia> model_lens(m, (@optic _.data[2]))
(@optic _.ms[2])
```
See also: [`data_lens`](@ref).
"""
function model_lens(model, lens::ComposedLens)
outerlens = model_lens(model, lens.outer)
outerlens model_lens(get(model, outerlens), lens.inner)
function model_lens(model, lens::ComposedOptic)
innerlens = model_lens(model, lens.inner)
innerlens model_lens(only(getall(model, innerlens)), lens.outer)
end
model_lens(::ArrayModel, ::PropertyLens{:data}) = @lens _.m
model_lens(::BagModel, ::PropertyLens{:data}) = @lens _.im
model_lens(::ProductModel, ::PropertyLens{:data}) = @lens _.ms
model_lens(::ArrayModel, ::PropertyLens{:data}) = @optic _.m
model_lens(::BagModel, ::PropertyLens{:data}) = @optic _.im
model_lens(::ProductModel, ::PropertyLens{:data}) = @optic _.ms
model_lens(::Union{NamedTuple, Tuple}, lens::IndexLens) = lens
model_lens(::Union{AbstractMillModel, NamedTuple, Tuple}, lens::IdentityLens) = lens
model_lens(::Union{AbstractMillModel, NamedTuple, Tuple}, lens::typeof(identity)) = lens

"""
data_lens(n, l)
Convert `Setfield.Lens` `l` for a model node to a new lens for accessing the same location in data node `n`.
Convert `Accessors.jl` lens `l` for a model node to a new lens for accessing the same location in data node `n`.
# Examples
```jldoctest
Expand All @@ -222,21 +237,21 @@ ProductModel ↦ Dense(20 => 10) # 2 arrays, 210 params, 920 bytes
│ ╰── ArrayModel(Dense(2 => 10)) # 2 arrays, 30 params, 200 bytes
╰── ArrayModel(Dense(2 => 10)) # 2 arrays, 30 params, 200 bytes
julia> data_lens(n, (@lens _.ms[2]))
(@lens _.data[2])
julia> data_lens(n, (@optic _.ms[2]))
(@optic _.data[2])
```
See also: [`data_lens`](@ref).
"""
function data_lens(ds, lens::ComposedLens)
outerlens = data_lens(ds, lens.outer)
outerlens data_lens(get(ds, outerlens), lens.inner)
function data_lens(ds, lens::ComposedOptic)
innerlens = data_lens(ds, lens.inner)
innerlens data_lens(only(getall(ds, innerlens)), lens.outer)
end
data_lens(::ArrayNode, ::PropertyLens{:m}) = @lens _.data
data_lens(::AbstractBagNode, ::PropertyLens{:im}) = @lens _.data
data_lens(::AbstractProductNode, ::PropertyLens{:ms}) = @lens _.data
data_lens(::ArrayNode, ::PropertyLens{:m}) = @optic _.data
data_lens(::AbstractBagNode, ::PropertyLens{:im}) = @optic _.data
data_lens(::AbstractProductNode, ::PropertyLens{:ms}) = @optic _.data
data_lens(::Union{NamedTuple, Tuple}, lens::IndexLens) = lens
data_lens(::Union{AbstractMillNode, NamedTuple, Tuple}, lens::IdentityLens) = lens
data_lens(::Union{AbstractMillNode, NamedTuple, Tuple}, lens::typeof(identity)) = lens

"""
replacein(n, old, new)
Expand Down
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using Mill: p_map, inv_p_map, r_map, inv_r_map, _bagnorm
using Mill: Maybe
using Mill: @gradtest, @pgradtest, gradf

using Accessors
using Base.Iterators: partition, product
using Base: CodeUnits
using ChainRulesCore
Expand Down Expand Up @@ -91,7 +92,7 @@ end

@testset "Doctests" begin
DocMeta.setdocmeta!(Mill, :DocTestSetup, quote
using Mill, Flux, Random, SparseArrays, Setfield, HierarchicalUtils
using Mill, Flux, Random, SparseArrays, Accessors, HierarchicalUtils
# do not shorten prints in doctests
ENV["LINES"] = ENV["COLUMNS"] = typemax(Int)
end; recursive=true)
Expand Down
Loading

0 comments on commit cf187a7

Please sign in to comment.