Skip to content

Commit

Permalink
Merge pull request #553 from alan-turing-institute/properties
Browse files Browse the repository at this point in the history
Make MLJType `show` method, and `fitted_params`/`report` for composites, property-based instead of field-based
  • Loading branch information
ablaom authored May 15, 2021
2 parents 4ed3923 + 62228cd commit 4e857fc
Show file tree
Hide file tree
Showing 10 changed files with 137 additions and 128 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.18.5"
version = "0.18.6"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down Expand Up @@ -33,14 +33,14 @@ ComputationalResources = "0.3"
Distributions = "0.22, 0.23, 0.24, 0.25"
InvertedIndices = "1"
LossFunctions = "0.5, 0.6"
MLJModelInterface = "0.4.1, 1.0"
MLJModelInterface = "0.4.1, 1.1"
MLJScientificTypes = "0.4.1"
Missings = "0.4, 1"
OrderedCollections = "1.1"
Parameters = "0.12"
PrettyTables = "1"
ProgressMeter = "1.3"
StatisticalTraits = "0.1.1, 1.0"
StatisticalTraits = "1.1"
StatsBase = "0.32, 0.33"
Tables = "0.2, 1.0"
julia = "1"
Expand Down
36 changes: 20 additions & 16 deletions src/composition/learning_networks/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,23 +157,24 @@ MLJModelInterface.fitted_params(mach::Machine{<:Surrogate}) =

## CONSTRUCTING THE RETURN VALUE FOR A COMPOSITE FIT METHOD

# Identify which fields of `model` have, as values, a model in the
# Identify which properties of `model` have, as values, a model in the
# learning network wrapped by `mach`, and check that no two such
# fields have have identical values (#377). Return the field name
# properties have identical values (#377). Return the property name
# associated with each model in the network (in the order appearing in
# `models(glb(mach))`) using `nothing` when the model is not
# associated with any field.
function network_model_names(model::M, mach::Machine{<:Surrogate}) where M<:Model
# associated with any property.
function network_model_names(model::M,
mach::Machine{<:Surrogate}) where M<:Model

signature = mach.fitresult
network_model_ids = objectid.(MLJBase.models(glb(mach)))

names = fieldnames(M)
names = propertynames(model)

# initialize dict to detect duplicates a la #377:
name_given_id = Dict{UInt64,Vector{Symbol}}()

# identify location of fields whose values are models in the
# identify location of properties whose values are models in the
# learning network, and build name_given_id:
for name in names
id = objectid(getproperty(model, name))
Expand All @@ -193,16 +194,18 @@ function network_model_names(model::M, mach::Machine{<:Surrogate}) where M<:Mode
if !no_duplicates
for (id, name) in name_given_id
if length(name) > 1
@error "The fields $name of $model have identical model "*
@error "The hyperparameters $name of "*
"$model have identical model "*
"instances as values. "
end
end
throw(ArgumentError(
"Two distinct fields of a composite model that are both "*
"Two distinct hyper-parameters of a "*
"composite model that are both "*
"associated with models in the underlying learning "*
"network (eg, any two fields of a `@pipeline` model) "*
"network (eg, any two components of a `@pipeline` model) "*
"cannot have identical values, although they can be `==` "*
"(corresponding nested fields are `==`). "*
"(corresponding nested properties are `==`). "*
"Consider constructing instances "*
"separately or use `deepcopy`. "))
end
Expand Down Expand Up @@ -231,10 +234,11 @@ composite models using `@pipeline` or `@from_network`.
For usage, see the example given below. Specifically, the call does the
following:
- Determines which fields of `model` point to model instances in the
learning network wrapped by `mach`, for recording in an object
called `cache`, for passing onto the MLJ logic that handles smart
updating (namely, an `MLJBase.update` fallback for composite models).
- Determines which hyper-parameters of `model` point to model
instances in the learning network wrapped by `mach`, for recording
in an object called `cache`, for passing onto the MLJ logic that
handles smart updating (namely, an `MLJBase.update` fallback for
composite models).
- Calls `fit!(mach, verbosity=verbosity)`.
Expand Down Expand Up @@ -290,7 +294,7 @@ function return!(mach::Machine{<:Surrogate},
data = Tuple(s.data for s in sources)
[MLJBase.rebind!(s, nothing) for s in sources]

# record the field values
# record the current hyper-parameter values:
old_model = deepcopy(model)

cache = (sources = sources,
Expand All @@ -310,7 +314,7 @@ function (mach::Machine{<:Surrogate})()
"`mach()`, is "*
"deprecated and could lead "*
"to unexpected behaviour for `Composite` models "*
"with fields that are not models. "*
"with hyper-parameters that are not models. "*
"Instead of `fit!(mach, verbosity=verbosity); return mach()` "*
"use `return!(mach, model, verbosity)`, "*
"where `model` is the `Model` instance appearing in your "*
Expand Down
5 changes: 2 additions & 3 deletions src/composition/models/from_network.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
## EXPORTING LEARNING NETWORKS AS MODELS WITH @from_network


# closure to generate the fit methods for exported composite. Here `mach`
function fit_method(mach, models...)

signature = mach.fitresult
mach_args = mach.args

function _fit(model::M, verbosity::Integer, args...) where M
function _fit(model, verbosity::Integer, args...)
length(args) > length(mach_args) &&
throw(ArgumentError("$M does not support more than "*
"$(length(mach_args)) training arguments"))
replacement_models = [getproperty(model, fld)
for fld in fieldnames(M)]
for fld in propertynames(model)]
model_replacements = [models[j] => replacement_models[j]
for j in eachindex(models)]
source_replacements = [mach_args[i] => source(args[i])
Expand Down
18 changes: 9 additions & 9 deletions src/composition/models/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ function update(model::M,
# This method falls back to `fit` to force rebuilding the
# underlying learning network if, since the last fit:
#
# (i) Any field associated with a model in the learning network
# (i) Any hyper-parameter associated with a model in the learning network
# has been replaced with a new model instance (and not merely
# mutated), OR

# (ii) Any OTHER field has changed it's value (in the sense
# (ii) Any OTHER hyper-parameter has changed its value (in the sense
# of `==`).

# Otherwise, a "smart" fit is carried out by calling `fit!` on a
Expand Down Expand Up @@ -69,16 +69,16 @@ end

# helper for preceding method (where logic is explained):
function fallback(model::M, old_model, network_model_names, glb_node) where M
# check the fields corresponding to models:
# check the hyper-parameters corresponding to models:
network_models = MLJBase.models(glb_node)
for j in eachindex(network_models)
name = network_model_names[j]
name === nothing ||
objectid(network_models[j])===objectid(getproperty(model, name)) ||
return true
end
# check any other fields:
for name in fieldnames(M)
# check any other hyper-parameter:
for name in propertynames(model)
if !(name in network_model_names)
old_value = getproperty(old_model, name)
value = getproperty(model, name)
Expand All @@ -95,13 +95,13 @@ function update(model::Composite,
cache,
args...)

# If any `model` field has been replaced (and not just mutated)
# then we actually need to fit rather than update (which will
# force build of a new learning network). If `model` has been
# If any `model` hyper-parameter has been replaced (and not just
# mutated) then we actually need to fit rather than update (which
# will force build of a new learning network). If `model` has been
# created using a learning network export macro, the test used
# below is perfect. In any other case it is at least conservative:
network_model_ids = objectid.(models(yhat))
fields = [getproperty(model, name) for name in fieldnames(typeof(model))]
fields = [getproperty(model, name) for name in propertynames(model)]
submodels = filter(f->f isa Model, fields)
submodel_ids = objectid.(submodels)
if !issubset(submodel_ids, network_model_ids)
Expand Down
16 changes: 8 additions & 8 deletions src/parameter_inspection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ istransparent(::MLJType) = true
params(m::MLJType)
Recursively convert any transparent object `m` into a named tuple,
keyed on the fields of `m`. An object is *transparent* if
keyed on the property names of `m`. An object is *transparent* if
`MLJBase.istransparent(m) == true`. The named tuple is possibly nested
because `params` is recursively applied to the field values, which
because `params` is recursively applied to the property values, which
themselves might be transparent.
For most `MLJType` objects, properties are synonymous with fields, but
this is not a hard requirement.
Most objects of type `MLJType` are transparent.
julia> params(EnsembleModel(atom=ConstantClassifier()))
Expand All @@ -24,10 +27,7 @@ Most objects of type `MLJType` are transparent.
params(m) = params(m, Val(istransparent(m)))
params(m, ::Val{false}) = m
function params(m, ::Val{true})
fields = fieldnames(typeof(m))
NamedTuple{fields}(Tuple([params(getfield(m, field)) for field in fields]))
fields = propertynames(m)
NamedTuple{fields}(Tuple([params(getproperty(m, field))
for field in fields]))
end




34 changes: 18 additions & 16 deletions src/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ end
function Base.show(stream::IO, object::MLJType)
repr = simple_repr(typeof(object))
str = "$repr $(handle(object))"
if !isempty(fieldnames(typeof(object)))
if !isempty(propertynames(object))
printstyled(IOContext(stream, :color=> SHOW_COLOR),
str, bold=false, color=:blue)
else
Expand Down Expand Up @@ -167,14 +167,14 @@ end
fancy(stream::IO, object) = fancy(stream, object, 0,
DEFAULT_AS_CONSTRUCTED_SHOW_DEPTH, 0)
fancy(stream, object, current_depth, depth, n) = show(stream, object)
function fancy(stream, object::M, current_depth, depth, n) where M<:MLJType
function fancy(stream, object::MLJType, current_depth, depth, n)
if current_depth == depth
show(stream, object)
else
prefix = MLJModelInterface.name(object)
anti = max(length(prefix) - INDENT)
print(stream, prefix, "(")
names = fieldnames(M)
names = propertynames(object)
n_names = length(names)
for k in eachindex(names)
value = getproperty(object, names[k])
Expand Down Expand Up @@ -211,7 +211,7 @@ Base.show(object::MLJType, depth::Int) = show(stdout, object, depth)
@more
Entered at the REPL, equivalent to `show(ans, 100)`. Use to get a
recursive description of all fields of the last REPL value.
recursive description of all properties of the last REPL value.
"""
macro more()
Expand All @@ -236,7 +236,7 @@ istoobig(str::AbstractString) = length(str) > 50

## THE `_show` METHOD

# Note: The `_show` method controls how field values are displayed in
# Note: The `_show` method controls how properties are displayed in
# the table generated by `_recursive_show`. See top of file.

# _show fallback:
Expand Down Expand Up @@ -308,13 +308,13 @@ _show(stream::IO, ::Nothing) = println(stream, "nothing")
"""
_recursive_show(stream, object, current_depth, depth)
Generate a table of the field values of the `MLJType` object,
dislaying each value by calling the method `_show` on it. The
behaviour of `_show(stream, f)` is as follows:
Generate a table of the properties of the `MLJType` object, displaying
each property value by calling the method `_show` on it. The behaviour
of `_show(stream, f)` is as follows:
1. If `f` is itself a `MLJType` object, then its short form is shown
and `_recursive_show` generates a separate table for each of its
field values (and so on, up to a depth of argument `depth`).
properties (and so on, up to a depth of argument `depth`).
2. Otherwise `f` is displayed as "(omitted T)" where `T = typeof(f)`,
unless `istoobig(f)` is false (the `istoobig` fall-back for arbitrary
Expand All @@ -324,10 +324,10 @@ overload the `_show` method for the type in question.
"""
function _recursive_show(stream::IO, object::MLJType, current_depth, depth)
if depth == 0 || isempty(fieldnames(typeof(object)))
if depth == 0 || isempty(propertynames(object))
println(stream, object)
elseif current_depth <= depth
fields = fieldnames(typeof(object))
fields = propertynames(object)
print(stream, "#"^current_depth, " ")
show(stream, object)
println(stream, ": ")
Expand All @@ -337,10 +337,11 @@ function _recursive_show(stream::IO, object::MLJType, current_depth, depth)
return
end
for fld in fields
fld_string = string(fld)*" "^(max(0,COLUMN_WIDTH - length(string(fld))))*"=> "
fld_string = string(fld)*
" "^(max(0,COLUMN_WIDTH - length(string(fld))))*"=> "
print(stream, fld_string)
if isdefined(object, fld)
_show(stream, getfield(object, fld))
_show(stream, getproperty(object, fld))
# println(stream)
else
println(stream, "(undefined)")
Expand All @@ -350,9 +351,10 @@ function _recursive_show(stream::IO, object::MLJType, current_depth, depth)
println(stream)
for fld in fields
if isdefined(object, fld)
subobject = getfield(object, fld)
if isa(subobject, MLJType) && !isempty(fieldnames(typeof(subobject)))
_recursive_show(stream, getfield(object, fld),
subobject = getproperty(object, fld)
if isa(subobject, MLJType) &&
!isempty(propertynames(subobject))
_recursive_show(stream, getproperty(object, fld),
current_depth + 1, depth)
end
end
Expand Down
71 changes: 0 additions & 71 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,77 +7,6 @@ function finaltypes(T::Type)
end
end


# # NOTE: deprecated, see @mlj_model
# """

# @set_defaults ModelType(args...)
# @set_defaults ModelType args

# Create a keyword constructor for any type `ModelType<:MLJBase.Model`,
# using as default values those listed in `args`. These must include a
# value for every field, and in the order appearing in
# `fieldnames(ModelType)`.

# The constructor does not call `MLJBase.clean!(model)` on the
# instantiated object `model`. This method is for internal use only (by
# `@from_network macro`) as it is deprecated by the `@mlj_model` macro.

# ### Example

# mutable struct Foo
# x::Int
# y
# end

# @set_defaults Foo(1,2)

# julia> Foo()
# Foo(1, 2)

# julia> Foo(x=1, y="house")
# Foo(1, "house")

# @set_defaults Foo [4, 5]

# julia> Foo()
# Foo(4, 5)

# """
# macro set_defaults(ex)
# T_ex = ex.args[1]
# value_exs = ex.args[2:end]
# values = [__module__.eval(ex) for ex in value_exs]
# set_defaults_(__module__, T_ex, values)
# return nothing
# end

# macro set_defaults(T_ex, values_ex)
# values =__module__.eval(values_ex)
# set_defaults_(__module__, T_ex, values)
# return nothing
# end

# function set_defaults_(mod, T_ex, values)
# T = mod.eval(T_ex)
# fields = fieldnames(T)
# isempty(fields) && return nothing
# length(fields) == length(values) ||
# error("Provide the same number of default values as fields. ")

# equality_pair_exs = [Expr(:kw, fields[i], values[i]) for i in
# eachindex(values)]

# program = quote
# $T_ex(; $(equality_pair_exs...)) =
# $T_ex($(fields...))
# end
# mod.eval(program)

# return nothing
# end


"""
flat_values(t::NamedTuple)
Expand Down
Loading

0 comments on commit 4e857fc

Please sign in to comment.