Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Address some predict/transform type instabilities #969

Merged
merged 2 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/composition/learning_networks/nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ See also [`node`](@ref), [`Source`](@ref), [`origins`](@ref),
[`sources`](@ref), [`fit!`](@ref).

"""
struct Node{T<:Union{Machine, Nothing}} <: AbstractNode
struct Node{T<:Union{Machine, Nothing},Oper} <: AbstractNode

operation # eg, `predict` or a static operation, such as `exp`
operation::Oper # eg, `predict` or a static operation, such as `exp`
machine::T # is `nothing` for static operations

# nodes called to get args for `operation(model, ...) ` or
Expand All @@ -43,9 +43,11 @@ struct Node{T<:Union{Machine, Nothing}} <: AbstractNode
# order consistent with extended graph, excluding self
nodes::Vector{AbstractNode}

function Node(operation,
machine::T,
args::AbstractNode...) where T<:Union{Machine, Nothing}
function Node(
operation::Oper,
machine::T,
args::AbstractNode...,
) where {T<:Union{Machine, Nothing}, Oper}

# check the number of arguments:
# if machine === nothing && isempty(args)
Expand All @@ -70,7 +72,7 @@ struct Node{T<:Union{Machine, Nothing}} <: AbstractNode
vcat(nodes_, (nodes(n) for n in machine.args)...) |> unique
end

return new{T}(operation, machine, args, origins_, nodes_)
return new{T,Oper}(operation, machine, args, origins_, nodes_)
end
end

Expand Down
32 changes: 20 additions & 12 deletions src/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ caches_data_by_default(m) = caches_data_by_default(typeof(m))
caches_data_by_default(::Type) = true
caches_data_by_default(::Type{<:Symbol}) = false

mutable struct Machine{M,C} <: MLJType
mutable struct Machine{M,OM,C} <: MLJType

model::M
old_model # for remembering the model used in last call to `fit!`
old_model::OM # for remembering the model used in last call to `fit!`
fitresult
cache

Expand All @@ -77,8 +77,11 @@ mutable struct Machine{M,C} <: MLJType
function Machine(
model::M, args::AbstractNode...;
cache=caches_data_by_default(model),
) where M
mach = new{M,cache}(model)
) where M
# In the case of symbolic model, machine cannot know the type of model to be fit
# at time of construction:
OM = M == Symbol ? Any : M
mach = new{M,OM,cache}(model)
mach.frozen = false
mach.state = 0
mach.args = args
Expand Down Expand Up @@ -115,7 +118,7 @@ any upstream dependencies in a learning network):
replace(mach, :args => (), :data => (), :data_resampled_data => (), :cache => nothing)

"""
function Base.replace(mach::Machine{<:Any,C}, field_value_pairs::Pair...) where C
function Base.replace(mach::Machine{<:Any,<:Any,C}, field_value_pairs::Pair...) where C
# determined new `model` and `args` and build replacement dictionary:
newfield_given_old = Dict(field_value_pairs) # to be extended
fields_to_be_replaced = keys(newfield_given_old)
Expand Down Expand Up @@ -436,8 +439,8 @@ machines(::Source) = Machine[]

## DISPLAY

_cache_status(::Machine{<:Any,true}) = "caches model-specific representations of data"
_cache_status(::Machine{<:Any,false}) = "does not cache data"
_cache_status(::Machine{<:Any,<:Any,true}) = "caches model-specific representations of data"
_cache_status(::Machine{<:Any,<:Any,false}) = "does not cache data"

function Base.show(io::IO, mach::Machine)
model = mach.model
Expand Down Expand Up @@ -502,8 +505,8 @@ end
# for getting model specific representation of the row-restricted
# training data from a machine, according to the value of the machine
# type parameter `C` (`true` or `false`):
_resampled_data(mach::Machine{<:Any,true}, model, rows) = mach.resampled_data
function _resampled_data(mach::Machine{<:Any,false}, model, rows)
_resampled_data(mach::Machine{<:Any,<:Any,true}, model, rows) = mach.resampled_data
function _resampled_data(mach::Machine{<:Any,<:Any,false}, model, rows)
raw_args = map(N -> N(), mach.args)
data = MMI.reformat(model, raw_args...)
return selectrows(model, rows, data...)
Expand All @@ -518,6 +521,10 @@ err_no_real_model(mach) = ErrorException(
"""
)

err_missing_model(model) = ErrorException(
"Specified `composite` model does not have `:$(model)` as a field."
)

"""
last_model(mach::Machine)

Expand Down Expand Up @@ -605,7 +612,7 @@ more on these lower-level training methods.

"""
function fit_only!(
mach::Machine{<:Any,cache_data};
mach::Machine{<:Any,<:Any,cache_data};
rows=nothing,
verbosity=1,
force=false,
Expand All @@ -628,7 +635,8 @@ function fit_only!(
# `getproperty(composite, mach.model)`:
model = if mach.model isa Symbol
isnothing(composite) && throw(err_no_real_model(mach))
mach.model in propertynames(composite)
mach.model in propertynames(composite) ||
throw(err_missing_model(model))
getproperty(composite, mach.model)
else
mach.model
Expand Down Expand Up @@ -967,7 +975,7 @@ A machine returned by `serializable` is characterized by the property
See also [`restore!`](@ref), [`MLJBase.save`](@ref).

"""
function serializable(mach::Machine{<:Any, C}, model=mach.model; verbosity=1) where C
function serializable(mach::Machine{<:Any,<:Any,C}, model=mach.model; verbosity=1) where C

isdefined(mach, :fitresult) || throw(ERR_SERIALIZING_UNTRAINED)
mach.state == -1 && return mach
Expand Down
10 changes: 6 additions & 4 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@ for operation in OPERATIONS
operation == :inverse_transform && continue

ex = quote
function $(operation)(mach::Machine{<:Model,false}; rows=:)
function $(operation)(mach::Machine{<:Model,<:Any,false}; rows=:)
# catch deserialized machine with no data:
isempty(mach.args) && throw(err_serialized($operation))
return ($operation)(mach, mach.args[1](rows=rows))
end
function $(operation)(mach::Machine{<:Model,true}; rows=:)
function $(operation)(mach::Machine{<:Model,<:Any,true}; rows=:)
# catch deserialized machine with no data:
isempty(mach.args) && throw(err_serialized($operation))
model = last_model(mach)
Expand All @@ -92,8 +92,10 @@ for operation in OPERATIONS
end

# special case of Static models (no training arguments):
$operation(mach::Machine{<:Static,true}; rows=:) = throw(ERR_ROWS_NOT_ALLOWED)
$operation(mach::Machine{<:Static,false}; rows=:) = throw(ERR_ROWS_NOT_ALLOWED)
$operation(mach::Machine{<:Static,<:Any,true}; rows=:) =
throw(ERR_ROWS_NOT_ALLOWED)
$operation(mach::Machine{<:Static,<:Any,false}; rows=:) =
throw(ERR_ROWS_NOT_ALLOWED)
end
eval(ex)

Expand Down
2 changes: 1 addition & 1 deletion src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1106,7 +1106,7 @@ end
@static if VERSION >= v"1.3.0-DEV.573"

# determines if an instantiated machine caches data:
_caches_data(::Machine{M, C}) where {M, C} = C
_caches_data(::Machine{<:Any,<:Any,C}) where C = C

function _evaluate!(func, mach, accel::CPUThreads, nfolds, verbosity)

Expand Down
1 change: 0 additions & 1 deletion test/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,6 @@ end
X = ones(2, 3)

mach = @test_logs machine(Scale(2))
@test mach isa Machine{Scale, false}
transform(mach, X) # triggers training of `mach`, ie is mutating
@test report(mach) in [nothing, NamedTuple()]
@test isnothing(fitted_params(mach))
Expand Down
Loading