From 7ae58213dd4e158bf885df81498b4db2af3a19f5 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 8 Apr 2024 12:55:23 +1200 Subject: [PATCH 1/2] annotate type for old_model field of Machine type oops --- src/machines.jl | 32 ++++++++++++++++++++------------ src/operations.jl | 10 ++++++---- src/resampling.jl | 2 +- test/machines.jl | 1 - 4 files changed, 27 insertions(+), 18 deletions(-) diff --git a/src/machines.jl b/src/machines.jl index 7a6e8bdc..c2a1de67 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -47,10 +47,10 @@ caches_data_by_default(m) = caches_data_by_default(typeof(m)) caches_data_by_default(::Type) = true caches_data_by_default(::Type{<:Symbol}) = false -mutable struct Machine{M,C} <: MLJType +mutable struct Machine{M,OM,C} <: MLJType model::M - old_model # for remembering the model used in last call to `fit!` + old_model::OM # for remembering the model used in last call to `fit!` fitresult cache @@ -77,8 +77,11 @@ mutable struct Machine{M,C} <: MLJType function Machine( model::M, args::AbstractNode...; cache=caches_data_by_default(model), - ) where M - mach = new{M,cache}(model) + ) where M + # In the case of symbolic model, machine cannot know the type of model to be fit + # at time of construction: + OM = M == Symbol ? Any : M + mach = new{M,OM,cache}(model) mach.frozen = false mach.state = 0 mach.args = args @@ -115,7 +118,7 @@ any upstream dependencies in a learning network): replace(mach, :args => (), :data => (), :data_resampled_data => (), :cache => nothing) """ -function Base.replace(mach::Machine{<:Any,C}, field_value_pairs::Pair...) where C +function Base.replace(mach::Machine{<:Any,<:Any,C}, field_value_pairs::Pair...) 
where C # determined new `model` and `args` and build replacement dictionary: newfield_given_old = Dict(field_value_pairs) # to be extended fields_to_be_replaced = keys(newfield_given_old) @@ -436,8 +439,8 @@ machines(::Source) = Machine[] ## DISPLAY -_cache_status(::Machine{<:Any,true}) = "caches model-specific representations of data" -_cache_status(::Machine{<:Any,false}) = "does not cache data" +_cache_status(::Machine{<:Any,<:Any,true}) = "caches model-specific representations of data" +_cache_status(::Machine{<:Any,<:Any,false}) = "does not cache data" function Base.show(io::IO, mach::Machine) model = mach.model @@ -502,8 +505,8 @@ end # for getting model specific representation of the row-restricted # training data from a machine, according to the value of the machine # type parameter `C` (`true` or `false`): -_resampled_data(mach::Machine{<:Any,true}, model, rows) = mach.resampled_data -function _resampled_data(mach::Machine{<:Any,false}, model, rows) +_resampled_data(mach::Machine{<:Any,<:Any,true}, model, rows) = mach.resampled_data +function _resampled_data(mach::Machine{<:Any,<:Any,false}, model, rows) raw_args = map(N -> N(), mach.args) data = MMI.reformat(model, raw_args...) return selectrows(model, rows, data...) @@ -518,6 +521,10 @@ err_no_real_model(mach) = ErrorException( """ ) +err_missing_model(model) = ErrorException( + "Specified `composite` model does not have `:$(model)` as a field." +) + """ last_model(mach::Machine) @@ -605,7 +612,7 @@ more on these lower-level training methods. 
""" function fit_only!( - mach::Machine{<:Any,cache_data}; + mach::Machine{<:Any,<:Any,cache_data}; rows=nothing, verbosity=1, force=false, @@ -628,7 +635,8 @@ function fit_only!( # `getproperty(composite, mach.model)`: model = if mach.model isa Symbol isnothing(composite) && throw(err_no_real_model(mach)) - mach.model in propertynames(composite) + mach.model in propertynames(composite) || + throw(err_missing_model(mach.model)) getproperty(composite, mach.model) else mach.model @@ -967,7 +975,7 @@ A machine returned by `serializable` is characterized by the property See also [`restore!`](@ref), [`MLJBase.save`](@ref). """ -function serializable(mach::Machine{<:Any, C}, model=mach.model; verbosity=1) where C +function serializable(mach::Machine{<:Any,<:Any,C}, model=mach.model; verbosity=1) where C isdefined(mach, :fitresult) || throw(ERR_SERIALIZING_UNTRAINED) mach.state == -1 && return mach diff --git a/src/operations.jl b/src/operations.jl index 9fab3999..d42689f2 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -74,12 +74,12 @@ for operation in OPERATIONS operation == :inverse_transform && continue ex = quote - function $(operation)(mach::Machine{<:Model,false}; rows=:) + function $(operation)(mach::Machine{<:Model,<:Any,false}; rows=:) # catch deserialized machine with no data: isempty(mach.args) && throw(err_serialized($operation)) return ($operation)(mach, mach.args[1](rows=rows)) end - function $(operation)(mach::Machine{<:Model,true}; rows=:) + function $(operation)(mach::Machine{<:Model,<:Any,true}; rows=:) # catch deserialized machine with no data: isempty(mach.args) && throw(err_serialized($operation)) model = last_model(mach) @@ -92,8 +92,10 @@ for operation in OPERATIONS end # special case of Static models (no training arguments): - $operation(mach::Machine{<:Static,true}; rows=:) = throw(ERR_ROWS_NOT_ALLOWED) - $operation(mach::Machine{<:Static,false}; rows=:) = throw(ERR_ROWS_NOT_ALLOWED) + $operation(mach::Machine{<:Static,<:Any,true}; rows=:) = 
+ throw(ERR_ROWS_NOT_ALLOWED) + $operation(mach::Machine{<:Static,<:Any,false}; rows=:) = + throw(ERR_ROWS_NOT_ALLOWED) end eval(ex) diff --git a/src/resampling.jl b/src/resampling.jl index a4afc2fa..3759e136 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -1106,7 +1106,7 @@ end @static if VERSION >= v"1.3.0-DEV.573" # determines if an instantiated machine caches data: -_caches_data(::Machine{M, C}) where {M, C} = C +_caches_data(::Machine{<:Any,<:Any,C}) where C = C function _evaluate!(func, mach, accel::CPUThreads, nfolds, verbosity) diff --git a/test/machines.jl b/test/machines.jl index c78aa06d..3f67b9b8 100644 --- a/test/machines.jl +++ b/test/machines.jl @@ -272,7 +272,6 @@ end X = ones(2, 3) mach = @test_logs machine(Scale(2)) - @test mach isa Machine{Scale, false} transform(mach, X) # triggers training of `mach`, ie is mutating @test report(mach) in [nothing, NamedTuple()] @test isnothing(fitted_params(mach)) From 190de707c32fe50ea0526797a99c1d2b6429edbc Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 9 Apr 2024 08:41:45 +1200 Subject: [PATCH 2/2] annotate type of operation field in Node type --- src/composition/learning_networks/nodes.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/composition/learning_networks/nodes.jl b/src/composition/learning_networks/nodes.jl index 5ede32aa..0733b211 100644 --- a/src/composition/learning_networks/nodes.jl +++ b/src/composition/learning_networks/nodes.jl @@ -27,9 +27,9 @@ See also [`node`](@ref), [`Source`](@ref), [`origins`](@ref), [`sources`](@ref), [`fit!`](@ref). """ -struct Node{T<:Union{Machine, Nothing}} <: AbstractNode +struct Node{T<:Union{Machine, Nothing},Oper} <: AbstractNode - operation # eg, `predict` or a static operation, such as `exp` + operation::Oper # eg, `predict` or a static operation, such as `exp` machine::T # is `nothing` for static operations # nodes called to get args for `operation(model, ...) 
` or @@ -43,9 +43,11 @@ struct Node{T<:Union{Machine, Nothing}} <: AbstractNode # order consistent with extended graph, excluding self nodes::Vector{AbstractNode} - function Node(operation, - machine::T, - args::AbstractNode...) where T<:Union{Machine, Nothing} + function Node( + operation::Oper, + machine::T, + args::AbstractNode..., + ) where {T<:Union{Machine, Nothing}, Oper} # check the number of arguments: # if machine === nothing && isempty(args) @@ -70,7 +72,7 @@ struct Node{T<:Union{Machine, Nothing}} <: AbstractNode vcat(nodes_, (nodes(n) for n in machine.args)...) |> unique end - return new{T}(operation, machine, args, origins_, nodes_) + return new{T,Oper}(operation, machine, args, origins_, nodes_) end end