Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix problem with serialization of nested models when component model overload save/restore #960

Merged
merged 4 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "1.1.1"
version = "1.1.2"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
24 changes: 14 additions & 10 deletions src/composition/learning_networks/replace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ the `model` and `args` field values as derived from the provided dictionaries. I
the returned machine is hooked into the new learning network defined by the values of
`newnode_given_old`.

If `serializable=true`, return a serializable copy instead (namely,
`serializable(node.mach)`) and ignore the `newmodel_given_old` dictionary (no model
replacement).
If `serializable=true`, return a serializable copy instead, but make no model replacement.
The `newmodel_given_old` dictionary is still used, but now to look up the concrete model
corresponding to the symbolic one stored in `node`'s machine.

See also [`serializable`](@ref).

Expand All @@ -26,9 +26,10 @@ function machine_replacement(
newnode_given_old,
serializable
)
# the `replace` called here is defined in src/machines.jl:
mach = serializable ? MLJBase.serializable(N.machine) :
replace(N.machine, :model => newmodel_given_old[N.machine.model])
# the `replace` called below is defined in src/machines.jl.
newmodel = newmodel_given_old[N.machine.model]
mach = serializable ? MLJBase.serializable(N.machine, newmodel) :
replace(N.machine, :model => newmodel)
mach.args = Tuple(newnode_given_old[arg] for arg in N.machine.args)
return mach
end
Expand All @@ -38,6 +39,7 @@ end
newnode_given_old,
newmach_given_old,
newmodel_given_old,
serializable,
node::AbstractNode)

**Private method.**
Expand Down Expand Up @@ -86,9 +88,11 @@ const DOC_REPLACE_OPTIONS =
- `copy_unspecified_deeply=true`: If `false`, models or sources not listed for
replacement are identically equal in the original and returned node.

- `serializable=false`: If `true`, all machines in the new network are serializable.
However, all `model` replacements are ignored, and unspecified sources are always
replaced with empty ones.
- `serializable=false`: If `true`, all machines in the new network are made
serializable and the specified model replacements are only used for serialization
purposes: for each pair `s => model` (`s` assumed to be a symbolic model) each
machine with model `s` is replaced with `serializable(mach, model)`. All unspecified
sources are always replaced with empty ones.

"""

Expand Down Expand Up @@ -192,7 +196,7 @@ function _replace(

# Instantiate model dictionary:
model_pairs = filter(collect(pairs)) do pair
first(pair) isa Model
first(pair) isa Model || first(pair) isa Symbol
end
models_ = models(W)
models_to_copy = setdiff(models_, first.(model_pairs))
Expand Down
25 changes: 20 additions & 5 deletions src/composition/models/network_composite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,33 @@ MLJModelInterface.fitted_params(composite::NetworkComposite, signature) =
MLJModelInterface.reporting_operations(::Type{<:NetworkComposite}) = OPERATIONS

# here `fitresult` has type `Signature`.
save(model::NetworkComposite, fitresult) = replace(fitresult, serializable=true)
function save(model::NetworkComposite, fitresult)
# The network includes machines with symbolic models. These machines need to be
# replaced by serializable versions, but we cannot naively use `serializable(mach)`,
# because the absence of the concrete model means this just returns `mach` (because
# `save(::Symbol, fitresult)` returns `fitresult`). We need to use the special
# `serialiable(mach, model)` instead. This is what `replace` below does, because we
# pass it the flag `serializable=true` but we must also pass `symbol =>
# concrete_model` replacements, which we calculate first:

greatest_lower_bound = MLJBase.glb(fitresult)
machines_given_model = MLJBase.machines_given_model(greatest_lower_bound)
atomic_models = keys(machines_given_model)
pairs = [atom => getproperty(model, atom) for atom in atomic_models]

replace(fitresult, pairs...; serializable=true)
end

function MLJModelInterface.restore(model::NetworkComposite, serializable_fitresult)
greatest_lower_bound = MLJBase.glb(serializable_fitresult)
machines_given_model = MLJBase.machines_given_model(greatest_lower_bound)
models = keys(machines_given_model)
atomic_models = keys(machines_given_model)

# the following indirectly mutates `serialiable_fiteresult`, returning it to
# usefulness:
for model in models
for mach in machines_given_model[model]
mach.fitresult = restore(model, mach.fitresult)
for atom in atomic_models
for mach in machines_given_model[atom]
mach.fitresult = MLJBase.restore(getproperty(model, atom), mach.fitresult)
mach.state = 1
end
end
Expand Down
14 changes: 7 additions & 7 deletions src/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -975,17 +975,17 @@ A machine returned by `serializable` is characterized by the property
See also [`restore!`](@ref), [`MLJBase.save`](@ref).

"""
function serializable(mach::Machine{<:Any, C}; verbosity=1) where C
function serializable(mach::Machine{<:Any, C}, model=mach.model; verbosity=1) where C

isdefined(mach, :fitresult) || throw(ERR_SERIALIZING_UNTRAINED)
mach.state == -1 && return mach

# The next line of code makes `serializable` recursive, in the case that `mach.model`
# is a `Composite` model: `save` duplicates the underlying learning network, which
# involves calls to `serializable` on the old machines in the network to create the
# new ones.

isdefined(mach, :fitresult) || throw(ERR_SERIALIZING_UNTRAINED)
mach.state == -1 && return mach

serializable_fitresult = save(mach.model, mach.fitresult)
serializable_fitresult = save(model, mach.fitresult)

# Duplication currenty needs to happen in two steps for this to work in case of
# `Composite` models.
Expand Down Expand Up @@ -1017,9 +1017,9 @@ useable form.
For an example see [`serializable`](@ref).

"""
function restore!(mach::Machine)
function restore!(mach::Machine, model=mach.model)
mach.state != -1 && return mach
mach.fitresult = restore(mach.model, mach.fitresult)
mach.fitresult = restore(model, mach.fitresult)
mach.state = 1
return mach
end
Expand Down
74 changes: 74 additions & 0 deletions test/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,80 @@ end
rm(filename)
end

# define a model with non-persistent fitresult:
thing = []
struct EphemeralTransformer <: Unsupervised end
function MLJModelInterface.fit(::EphemeralTransformer, verbosity, X)
view = pointer(thing)
fitresult = (thing, view)
return fitresult, nothing, NamedTuple()
end
function MLJModelInterface.transform(::EphemeralTransformer, fitresult, X)
thing, view = fitresult
return view == pointer(thing) ? X : throw(ErrorException("dead fitresult"))
end
function MLJModelInterface.save(::EphemeralTransformer, fitresult)
thing, _ = fitresult
return thing
end
function MLJModelInterface.restore(::EphemeralTransformer, serialized_fitresult)
view = pointer(thing)
return (thing, view)
end

# commented out code just tests the transformer above has desired properties for testing:

# # test model transforms:
# model = EphemeralTransformer()
# mach = machine(model, 42) |> fit!
# @test MLJBase.transform(mach, 27) == 27

# # direct serialization fails:
# io = IOBuffer()
# serialize(io, mach)
# seekstart(io)
# mach2 = deserialize(io)
# @test_throws ErrorException("dead fitresult") transform(mach2, 42)

@testset "serialization for model with non-persistent fitresult" begin
X = (; x=randn(5))
mach = machine(EphemeralTransformer(), X)
fit!(mach, verbosity=0)
v = MLJBase.transform(mach, X).x
io = IOBuffer()
serialize(io, serializable(mach))
seekstart(io)
mach2 = restore!(deserialize(io))
@test MLJBase.transform(mach2, X).x == v

# using `save`/`machine`:
MLJBase.save(io, mach)
seekstart(io)
mach2 = machine(io)
@test MLJBase.transform(mach2, X).x == v
end

@testset "serialization for model with non-persistent fitresult in pipeline" begin
# https://github.com/JuliaAI/MLJBase.jl/issues/927
X = (; x=randn(5))
pipe = Standardizer |> EphemeralTransformer
X = (; x=randn(5))
mach = machine(pipe, X)
fit!(mach, verbosity=0)
v = MLJBase.transform(mach, X).x
io = IOBuffer()
serialize(io, serializable(mach))
seekstart(io)
mach2 = restore!(deserialize(io))
@test MLJBase.transform(mach2, X).x == v

# using `save`/`machine`:
MLJBase.save(io, mach)
seekstart(io)
mach2 = machine(io)
@test MLJBase.transform(mach2, X).x == v
end

struct ReportingDynamic <: Unsupervised end
MLJBase.fit(::ReportingDynamic, _, X) = nothing, 16, NamedTuple()
MLJBase.transform(::ReportingDynamic,_, X) = (X, (news=42,))
Expand Down
34 changes: 17 additions & 17 deletions test/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,22 +117,22 @@ API.@trait(
[LogLoss(), ], dummy_interval, 1))
end

@everywhere begin
nfolds = 6
nmeasures = 2
func(mach, k) = (
(sleep(MLJBase.PROG_METER_DT*rand(rng)); fill(1:k, nmeasures)),
:fitted_params,
:report,
)
end
@testset_accelerated "dispatch of resources and progress meter" accel begin

@info "Checking progress bars:"

X = (x = [1, ],)
y = [2.0, ]

@everywhere begin
nfolds = 6
nmeasures = 2
func(mach, k) = (
(sleep(MLJBase.PROG_METER_DT*rand(rng)); fill(1:k, nmeasures)),
:fitted_params,
:report,
)
end
mach = machine(ConstantRegressor(), X, y)
if accel isa CPUThreads
result = MLJBase._evaluate!(
Expand Down Expand Up @@ -643,15 +643,15 @@ end

struct DummyResamplingStrategy <: MLJBase.ResamplingStrategy end

@testset_accelerated "custom strategy depending on X, y" accel begin
function MLJBase.train_test_pairs(resampling::DummyResamplingStrategy,
rows, X, y)
train = filter(rows) do j
y[j] == y[1]
function MLJBase.train_test_pairs(resampling::DummyResamplingStrategy,
rows, X, y)
train = filter(rows) do j
y[j] == y[1]
end
test = setdiff(rows, train)
return [(train, test),]
end
test = setdiff(rows, train)
return [(train, test),]
end
@testset_accelerated "custom strategy depending on X, y" accel begin

X = (x = rand(rng,8), )
y = categorical(string.([:x, :y, :x, :x, :y, :x, :x, :y]))
Expand Down
Loading