diff --git a/Project.toml b/Project.toml index a590874b..7d54c1f0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "Mill" uuid = "1d0525e4-8992-11e8-313c-e310e1f6ddea" authors = ["Tomas Pevny ", "Simon Mandlik "] -version = "2.10" +version = "2.10.0" [deps] +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" @@ -17,12 +18,12 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" Preferences = "21216c6a-2e73-6563-6e65-726566657250" -Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] +Accessors = "0.1" ChainRulesCore = "1" Combinatorics = "1.0" DataFrames = "1" @@ -35,7 +36,6 @@ MacroTools = "0.5" OneHotArrays = "0.1, 0.2" PooledArrays = "1" Preferences = "1" -Setfield = "1" julia = "1.9" [extras] @@ -45,4 +45,4 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [targets] -test = [ "BenchmarkTools", "Documenter", "InteractiveUtils", "Random" ] +test = ["BenchmarkTools", "Documenter", "InteractiveUtils", "Random"] diff --git a/docs/Project.toml b/docs/Project.toml index 0950f886..081f0335 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" @@ -10,6 +11,5 @@ HierarchicalUtils = "f9ccea15-0695-44b9-8113-df7c26ae4fa9" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Mill = "1d0525e4-8992-11e8-313c-e310e1f6ddea" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/docs/make.jl b/docs/make.jl index bac04214..e769c111 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,6 +1,6 @@ using Pkg using Documenter, DocumenterCitations, Literate -using Mill, Flux, Random, SparseArrays, Setfield, HierarchicalUtils +using Mill, Flux, Random, SparseArrays, Accessors, HierarchicalUtils #= Useful resources for writing docs: @@ -53,7 +53,7 @@ function Mill.unpack2mill(ds::LazyNode{:Sentence}) end DocMeta.setdocmeta!(Mill, :DocTestSetup, quote - using Mill, Flux, Random, SparseArrays, Setfield, HierarchicalUtils + using Mill, Flux, Random, SparseArrays, Accessors, HierarchicalUtils ENV["LINES"] = ENV["COLUMNS"] = typemax(Int) end; recursive=true) diff --git a/src/Mill.jl b/src/Mill.jl index ca84c6c7..2518a67e 100644 --- a/src/Mill.jl +++ b/src/Mill.jl @@ -1,5 +1,6 @@ module Mill +using Accessors using ChainRulesCore using Combinatorics using DataFrames @@ -12,14 +13,13 @@ using MacroTools using OneHotArrays using PooledArrays using Preferences -using Setfield using SparseArrays using Statistics using Base: CodeUnits, nameof using ChainRulesCore: NotImplemented, NotImplementedException using HierarchicalUtils: encode, stringify -using Setfield: IdentityLens, PropertyLens, IndexLens, ComposedLens +using Accessors: PropertyLens, IndexLens, ComposedOptic import Base: *, == diff --git a/src/util.jl b/src/util.jl index c5c66bbd..2b695529 100644 --- a/src/util.jl +++ b/src/util.jl @@ -9,7 +9,8 @@ end """ pred_lens(p, n) -Return a `Vector` of `Setfield.Lens`es for accessing all nodes/fields in `n` conforming to predicate `p`. +Return a `Vector` of `Accessors.jl` lenses for accessing all nodes/fields in `n` conforming to +predicate `p`. # Examples ```jldoctest @@ -20,18 +21,35 @@ ProductNode # 2 obs, 16 bytes ╰── ArrayNode(2×2 Array with Int64 elements) # 2 obs, 80 bytes julia> pred_lens(x -> x isa ArrayNode, n) -1-element Vector{Setfield.ComposedLens{Setfield.PropertyLens{:data}, Setfield.IndexLens{Tuple{Int64}}}}: - (@lens _.data[2]) +1-element Vector{Any}: + (@optic _.data[2]) ``` See also: [`list_lens`](@ref), [`find_lens`](@ref), [`findnonempty_lens`](@ref). """ -pred_lens(p::Function, n) = _pred_lens(p, n) +function pred_lens(p::Function, n) + result = Any[] + _pred_lens!(p, n, (), result) + return result +end + +_pred_lens!(p::Function, x, l, result) = p(x) && push!(result, Accessors.opticcompose(l...)) +function _pred_lens!(p::Function, n::T, l, result) where T <: AbstractMillStruct + p(n) && push!(result, Accessors.opticcompose(l...)) + for k in fieldnames(T) + _pred_lens!(p, getproperty(n, k), (l..., PropertyLens{k}()), result) + end +end +function _pred_lens!(p::Function, n::Union{Tuple, NamedTuple}, l, result) + for i in eachindex(n) + _pred_lens!(p, n[i], (l..., IndexLens((i,))), result) + end +end """ list_lens(n) -Return a `Vector` of `Setfield.Lens`es for accessing all nodes/fields in `n`. +Return a `Vector` of `Accessors.jl` lenses for accessing all nodes/fields in `n`. # Examples ```jldoctest @@ -42,16 +60,16 @@ ProductNode # 2 obs, 16 bytes ╰── ArrayNode(2×2 Array with Int64 elements) # 2 obs, 80 bytes julia> list_lens(n) -9-element Vector{Lens}: - (@lens _) - (@lens _.data[1]) - (@lens _.data[1].data) - (@lens _.data[1].bags) - (@lens _.data[1].metadata) - (@lens _.data[2]) - (@lens _.data[2].data) - (@lens _.data[2].metadata) - (@lens _.metadata) +9-element Vector{Any}: + identity (generic function with 1 method) + (@optic _.data[1]) + (@optic _.data[1].data) + (@optic _.data[1].bags) + (@optic _.data[1].metadata) + (@optic _.data[2]) + (@optic _.data[2].data) + (@optic _.data[2].metadata) + (@optic _.metadata) ``` See also: [`pred_lens`](@ref), [`find_lens`](@ref), [`findnonempty_lens`](@ref). @@ -61,7 +79,8 @@ list_lens(n) = pred_lens(t -> true, n) """ findnonempty_lens(n) -Return a `Vector` of `Setfield.Lens`es for accessing all nodes/fields in `n` that have at least one observation. +Return a `Vector` of `Accessors.jl` lenses for accessing all nodes/fields in `n` that contain at +least one observation. # Examples ```jldoctest @@ -72,10 +91,10 @@ ProductNode # 2 obs, 16 bytes ╰── ArrayNode(2×2 Array with Int64 elements) # 2 obs, 80 bytes julia> findnonempty_lens(n) -3-element Vector{Lens}: - (@lens _) - (@lens _.data[1]) - (@lens _.data[2]) +3-element Vector{Any}: + identity (generic function with 1 method) + (@optic _.data[1]) + (@optic _.data[2]) ``` See also: [`pred_lens`](@ref), [`list_lens`](@ref), [`find_lens`](@ref). @@ -85,8 +104,8 @@ findnonempty_lens(n) = pred_lens(t -> t isa AbstractMillNode && numobs(t) > 0, n """ find_lens(n, x) -Return a `Vector` of `Setfield.Lens`es for accessing all nodes/fields in `n` that return `true` when -compared to `x` using `Base.===`. +Return a `Vector` of `Accessors.jl` lenses for accessing all nodes/fields in `n` that return `true` +when compared to `x` using `Base.===`. # Examples ```jldoctest @@ -97,30 +116,19 @@ ProductNode # 2 obs, 16 bytes ╰── ArrayNode(2×2 Array with Int64 elements) # 2 obs, 80 bytes julia> find_lens(n, n.data[1]) -1-element Vector{Setfield.ComposedLens{Setfield.PropertyLens{:data}, Setfield.IndexLens{Tuple{Int64}}}}: - (@lens _.data[1]) +1-element Vector{Any}: + (@optic _.data[1]) ``` See also: [`pred_lens`](@ref), [`list_lens`](@ref), [`findnonempty_lens`](@ref). """ find_lens(n, x) = pred_lens(t -> t ≡ x, n) -_pred_lens(p::Function, n) = p(n) ? [IdentityLens()] : Lens[] -function _pred_lens(p::Function, n::T) where T <: AbstractMillStruct - res = [map(l -> PropertyLens{k}() ∘ l, _pred_lens(p, getproperty(n, k))) for k in fieldnames(T)] - res = vcat(filter(!isempty, res)...) - p(n) ? [IdentityLens(); res] : res -end -function _pred_lens(p::Function, n::Union{Tuple, NamedTuple}) - res = [map(l -> IndexLens(tuple(i)) ∘ l, _pred_lens(p, n[i])) for i in eachindex(n)] - vcat(filter(!isempty, res)...) -end - """ code2lens(n, c) -Convert code `c` from [HierarchicalUtils.jl](@ref) traversal to a `Vector` of `Setfield.Lens` such -that they access each node in tree egal to `n`. +Convert code `c` from [HierarchicalUtils.jl](@ref) traversal to a `Vector` of `Accessors.jl` +lenses such that they access each node in tree `n` egal to node under code `c` in the tree. # Examples ```jldoctest @@ -133,8 +141,8 @@ ProductNode [""] # 2 obs, 16 bytes ╰── ArrayNode(2×2 Array with Int64 elements) ["U"] # 2 obs, 80 bytes julia> code2lens(n, "U") -1-element Vector{Setfield.ComposedLens{Setfield.PropertyLens{:data}, Setfield.IndexLens{Tuple{Int64}}}}: - (@lens _.data[2]) +1-element Vector{Any}: + (@optic _.data[2]) ``` See also: [`lens2code`](@ref). @@ -144,8 +152,8 @@ code2lens(n::AbstractMillStruct, c::AbstractString) = find_lens(n, n[c]) """ lens2code(n, l) -Convert `Setfield.Lens` l to a `Vector` of codes from [HierarchicalUtils.jl](@ref) traversal such -that they access each node in tree egal to `n`. +Convert `Accessors.jl` lens `l` to a `Vector` of codes from [HierarchicalUtils.jl](@ref) traversal +such that they access each node in tree `n` egal to node accessible by lens `l`. # Examples ```jldoctest @@ -157,20 +165,27 @@ ProductNode [""] # 2 obs, 16 bytes │ ╰── ∅ ["M"] ╰── ArrayNode(2×2 Array with Int64 elements) ["U"] # 2 obs, 80 bytes -julia> lens2code(n, (@lens _.data[2])) +julia> lens2code(n, (@optic _.data[2])) 1-element Vector{String}: "U" +julia> lens2code(n, (@optic _.data[∗])) +2-element Vector{String}: + "E" + "U" + ``` See also: [`code2lens`](@ref). """ -lens2code(n::AbstractMillStruct, l::Lens) = HierarchicalUtils.find_traversal(n, get(n, l)) +lens2code(n::AbstractMillStruct, l) = mapreduce(vcat, Accessors.getall(n, l)) do x + HierarchicalUtils.find_traversal(n, x) +end """ model_lens(m, l) -Convert `Setfield.Lens` `l` for a data node to a new lens for accessing the same location in model `m`. +Convert `Accessors.jl` lens `l` for a data node to a new lens for accessing the same location in model `m`. # Examples ```jldoctest @@ -187,26 +202,26 @@ ProductModel ↦ Dense(20 => 10) # 2 arrays, 210 params, 920 bytes │ ╰── ArrayModel(Dense(2 => 10)) # 2 arrays, 30 params, 200 bytes ╰── ArrayModel(Dense(2 => 10)) # 2 arrays, 30 params, 200 bytes -julia> model_lens(m, (@lens _.data[2])) -(@lens _.ms[2]) +julia> model_lens(m, (@optic _.data[2])) +(@optic _.ms[2]) ``` See also: [`data_lens`](@ref). """ -function model_lens(model, lens::ComposedLens) - outerlens = model_lens(model, lens.outer) - outerlens ∘ model_lens(get(model, outerlens), lens.inner) +function model_lens(model, lens::ComposedOptic) + innerlens = model_lens(model, lens.inner) + innerlens ⨟ model_lens(only(getall(model, innerlens)), lens.outer) end -model_lens(::ArrayModel, ::PropertyLens{:data}) = @lens _.m -model_lens(::BagModel, ::PropertyLens{:data}) = @lens _.im -model_lens(::ProductModel, ::PropertyLens{:data}) = @lens _.ms +model_lens(::ArrayModel, ::PropertyLens{:data}) = @optic _.m +model_lens(::BagModel, ::PropertyLens{:data}) = @optic _.im +model_lens(::ProductModel, ::PropertyLens{:data}) = @optic _.ms model_lens(::Union{NamedTuple, Tuple}, lens::IndexLens) = lens -model_lens(::Union{AbstractMillModel, NamedTuple, Tuple}, lens::IdentityLens) = lens +model_lens(::Union{AbstractMillModel, NamedTuple, Tuple}, lens::typeof(identity)) = lens """ data_lens(n, l) -Convert `Setfield.Lens` `l` for a model node to a new lens for accessing the same location in data node `n`. +Convert `Accessors.jl` lens `l` for a model node to a new lens for accessing the same location in data node `n`. # Examples ```jldoctest @@ -222,21 +237,21 @@ ProductModel ↦ Dense(20 => 10) # 2 arrays, 210 params, 920 bytes │ ╰── ArrayModel(Dense(2 => 10)) # 2 arrays, 30 params, 200 bytes ╰── ArrayModel(Dense(2 => 10)) # 2 arrays, 30 params, 200 bytes -julia> data_lens(n, (@lens _.ms[2])) -(@lens _.data[2]) +julia> data_lens(n, (@optic _.ms[2])) +(@optic _.data[2]) ``` See also: [`data_lens`](@ref). """ -function data_lens(ds, lens::ComposedLens) - outerlens = data_lens(ds, lens.outer) - outerlens ∘ data_lens(get(ds, outerlens), lens.inner) +function data_lens(ds, lens::ComposedOptic) + innerlens = data_lens(ds, lens.inner) + innerlens ⨟ data_lens(only(getall(ds, innerlens)), lens.outer) end -data_lens(::ArrayNode, ::PropertyLens{:m}) = @lens _.data -data_lens(::AbstractBagNode, ::PropertyLens{:im}) = @lens _.data -data_lens(::AbstractProductNode, ::PropertyLens{:ms}) = @lens _.data +data_lens(::ArrayNode, ::PropertyLens{:m}) = @optic _.data +data_lens(::AbstractBagNode, ::PropertyLens{:im}) = @optic _.data +data_lens(::AbstractProductNode, ::PropertyLens{:ms}) = @optic _.data data_lens(::Union{NamedTuple, Tuple}, lens::IndexLens) = lens -data_lens(::Union{AbstractMillNode, NamedTuple, Tuple}, lens::IdentityLens) = lens +data_lens(::Union{AbstractMillNode, NamedTuple, Tuple}, lens::typeof(identity)) = lens """ replacein(n, old, new) diff --git a/test/runtests.jl b/test/runtests.jl index 09e20318..76b75a49 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,7 @@ using Mill: p_map, inv_p_map, r_map, inv_r_map, _bagnorm using Mill: Maybe using Mill: @gradtest, @pgradtest, gradf +using Accessors using Base.Iterators: partition, product using Base: CodeUnits using ChainRulesCore @@ -91,7 +92,7 @@ end @testset "Doctests" begin DocMeta.setdocmeta!(Mill, :DocTestSetup, quote - using Mill, Flux, Random, SparseArrays, Setfield, HierarchicalUtils + using Mill, Flux, Random, SparseArrays, Accessors, HierarchicalUtils # do not shorten prints in doctests ENV["LINES"] = ENV["COLUMNS"] = typemax(Int) end; recursive=true) diff --git a/test/util.jl b/test/util.jl index f9cb68c1..594b22aa 100644 --- a/test/util.jl +++ b/test/util.jl @@ -14,7 +14,7 @@ m = reflectinmodel(x) all_fields = vcat(all_nodes, [md, an1.data, b.bags, an2.data, wb.bags, wb.weights, an3.data]) all_fields = vcat(all_fields, Mill.metadata.(all_nodes)) - @test all(l -> get(x, l) in all_fields, ls) + @test all(l -> only(getall(x, l)) in all_fields, ls) @test all(n -> n in all_fields, [walk(x, t) for t in list_traversal(x)]) ls = list_lens(m) @@ -22,23 +22,23 @@ m = reflectinmodel(x) all_fields = vcat(all_nodes, [m.m, m.ms[1].m, m.ms[1].ms.b.im.m, m.ms[1].ms.b.a, m.ms[1].ms.b.bm, m.ms[1].ms.wb.im.m, m.ms[1].ms.wb.a, m.ms[1].ms.wb.bm, m.ms[2].m]) - @test all(l -> get(m, l) in all_fields, ls) + @test all(l -> only(getall(m, l)) in all_fields, ls) @test all(n -> n in all_fields, [walk(m, t) for t in list_traversal(m)]) end @testset "findnonempty_lens" begin - @test all(numobs.([get(x, l) for l in findnonempty_lens(x)]) .> 0) + @test all(numobs.([only(getall(x, l)) for l in findnonempty_lens(x)]) .> 0) end @testset "find_lens" begin for t in list_traversal(x) ls = find_lens(x, x[t]) - @test all(l -> get(x, l) ≡ x[t], ls) + @test all(l -> only(getall(x, l)) ≡ x[t], ls) end for t in list_traversal(m) ls = find_lens(m, m[t]) - @test all(l -> get(m, l) ≡ m[t], ls) + @test all(l -> only(getall(m, l)) ≡ m[t], ls) end end @@ -76,4 +76,3 @@ end @test m3[t] ≡ m[t] end end -