Merge pull request #417 from alan-turing-institute/dev
For a 0.15 release
ablaom authored Aug 28, 2020
2 parents cc09983 + 0804910 commit 0251875
Showing 34 changed files with 833 additions and 996 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/CompatHelper.yml
@@ -13,7 +13,7 @@ jobs:
version: 1.3
- name: Pkg.add("CompatHelper")
run: julia -e 'using Pkg; Pkg.add("CompatHelper")'
- name: CompatHelper.main(; master_branch = "dev")
- name: CompatHelper.main
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: julia -e 'using CompatHelper; CompatHelper.main()'
run: julia -e 'using CompatHelper; CompatHelper.main(; master_branch = "dev")'
2 changes: 1 addition & 1 deletion .travis.yml
@@ -6,7 +6,7 @@ env:
- JULIA_NUM_THREADS=30
julia:
- 1.0
- 1.4
- 1.5
- nightly
jobs:
allow_failures:
8 changes: 3 additions & 5 deletions Project.toml
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.14.9"
version = "0.15.0"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -25,7 +25,6 @@ Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
@@ -39,14 +38,13 @@ InvertedIndices = "^1"
JLSO = "^2.1,^2.2"
JSON = "^0.21"
LossFunctions = "0.5, 0.6"
MLJModelInterface = "^0.3.4"
MLJScientificTypes = "^0.1,^0.2"
MLJModelInterface = "^0.3.5"
MLJScientificTypes = "^0.3"
Missings = "^0.4"
OrderedCollections = "^1.1"
Parameters = "^0.12"
PrettyTables = "^0.8,^0.9"
ProgressMeter = "^1.3"
ScientificTypes = "^0.7, 0.8"
StatsBase = "^0.32,^0.33"
Tables = "^0.2,^1.0"
julia = "1"
16 changes: 7 additions & 9 deletions src/MLJBase.jl
@@ -7,8 +7,7 @@ import Base: ==, precision, getindex, setindex!
import Base.+, Base.*

# Scitype
import ScientificTypes: TRAIT_FUNCTION_GIVEN_NAME
import ScientificTypes
import MLJScientificTypes.ScientificTypes: TRAIT_FUNCTION_GIVEN_NAME
using MLJScientificTypes
using MLJModelInterface

@@ -47,7 +46,7 @@ import StatsBase
import StatsBase: fit!, mode, countmap
import Missings: levels
import Distributions
import Distributions: pdf, sampler
import Distributions: pdf, logpdf, sampler
const Dist = Distributions

# from Standard Library:
@@ -57,7 +56,7 @@ using Statistics, LinearAlgebra, Random, InteractiveUtils
## EXPORTS

# -------------------------------------------------------------------
# re-exports from MLJModelInterface, (MLJ)ScientificTypes
# re-exports from MLJModelInterface, MLJScientificTypes
# NOTE: MLJBase does **not** re-export UnivariateFinite to avoid
# ambiguities between the raw constructor (MLJBase.UnivariateFinite)
# and the general method (MLJModelInterface.UnivariateFinite)
@@ -100,7 +99,7 @@ export input_scitype, output_scitype, target_scitype,
export matrix, int, classes, decoder, table,
nrows, selectrows, selectcols, select

# re-exports from (MLJ)ScientificTypes
# re-exports from MLJScientificTypes
export Unknown, Known, Finite, Infinite,
OrderedFactor, Multiclass, Count, Continuous, Textual,
Binary, ColorImage, GrayImage, Image, Table
@@ -155,8 +154,8 @@ export machine, Machine, fit!, report, fit_only!
export make_blobs, make_moons, make_circles, make_regression

# composition:
export machines, sources, anonymize!, @from_network, fitresults, @pipeline,
glb, @tuple, node, @node, sources, origins,
export machines, sources, anonymize!, @from_network, @pipeline,
glb, @tuple, node, @node, sources, origins, return!,
nrows_at_source, machine,
rebind!, nodes, freeze!, thaw!, models, Node, AbstractNode,
DeterministicSurrogate, ProbabilisticSurrogate, UnsupervisedSurrogate,
@@ -226,7 +225,7 @@ export TruePositive, TrueNegative, FalsePositive, FalseNegative,
# re-export from Random, StatsBase, Statistics, Distributions,
# CategoricalArrays, InvertedIndices:
export pdf, sampler, mode, median, mean, shuffle!, categorical, shuffle,
levels, levels!, std, Not, support
levels, levels!, std, Not, support, logpdf


# ===================================================================
@@ -286,7 +285,6 @@ include("composition/models/methods.jl")
include("composition/models/from_network.jl")
include("composition/models/inspection.jl")
include("composition/models/pipelines.jl")
include("composition/models/deprecated.jl")
include("composition/models/_wrapped_function.jl")

include("operations.jl")
206 changes: 176 additions & 30 deletions src/composition/learning_networks/machines.jl
@@ -1,6 +1,5 @@
## LEARNING NETWORK MACHINES

# ***
surrogate(::Type{<:Deterministic}) = Deterministic()
surrogate(::Type{<:Probabilistic}) = Probabilistic()
surrogate(::Type{<:Unsupervised}) = Unsupervised()
@@ -25,6 +24,8 @@ If a supertype cannot be deduced, `nothing` is returned.
If the network with given `signature` is not exportable, this method
will not error, but it will not give a meaningful return value either.
**Private method.**
"""
function model_supertype(signature)

@@ -67,16 +68,9 @@ function machine(model::Surrogate, _sources::Source...; pair_itr...)
end

if model isa Supervised
length(_sources) in [2, 3] ||
error("Incorrect number of source nodes specified.\n"*
"Use `machine(model, X, y; ...)` or "*
"`machine(model, X, y, w; ...)` when "*
"`model isa Supervised`. ")
length(_sources) > 1 || _throw_supervised_arg_error()
elseif model isa Unsupervised
length(_sources) == 1 ||
error("Incorrect number of source nodes specified.\n"*
"Use `machine(model, X; ...)` when "*
"`model isa Unsupervised. ` (even if `Static`). ")
length(_sources) < 2 || _throw_unsupervised_arg_error()
else
throw(DomainError)
end
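For orientation, here is a hedged sketch of how these source-count rules play out when constructing a learning network machine; `X`, `y`, `yhat`, and `W` are placeholders invented for this illustration, not names from the commit:
```
Xs = source(X)   # wrap training data in source nodes
ys = source(y)

# Supervised surrogates require at least two source nodes:
mach = machine(Deterministic(), Xs, ys; predict=yhat)

# Unsupervised surrogates accept at most one:
# machine(Unsupervised(), Xs; transform=W)
```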
@@ -113,18 +107,17 @@ function machine(sources::Source...; pair_itr...)
end

"""
    anonymize!(sources)

Returns a named tuple `(sources=..., data=....)` whose values are the
provided source nodes and their contents respectively, and clears the
contents of those source nodes.

**Private method.**
"""
function anonymize!(sources)
    data = Tuple(s.data for s in sources)
    [MLJBase.rebind!(s, nothing) for s in sources]
    return (sources=sources, data=data)
end

"""
    N = glb(mach::Machine{<:Surrogate})

A greatest lower bound for the nodes appearing in the signature of
`mach`.

**Private method.**
"""
glb(mach::Machine{<:Union{Composite,Surrogate}}) =
    glb(values(mach.fitresult)...)
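A hedged usage sketch: assuming `mach` is a learning network machine as constructed above, `glb` collapses its signature into a single node, so training that node trains the whole network:
```
N = glb(mach)          # greatest lower bound of the signature's nodes
fit!(N, verbosity=1)   # trains every machine in the network once
```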


"""
fit!(mach::Machine{<:Surrogate};
@@ -147,27 +140,180 @@ See also [`machine`](@ref)
"""
function fit!(mach::Machine{<:Surrogate}; kwargs...)

signature = mach.fitresult
glb_node = glb(values(signature)...) # greatest lower bound node
glb_node = glb(mach)
fit!(glb_node; kwargs...)

# mach.cache = anonymize!(mach.args)
mach.state += 1
mach.report = report(glb_node)
return mach

end
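A hedged sketch of the workflow this method supports, continuing the placeholder names from above:
```
mach = machine(Deterministic(), Xs, ys; predict=yhat)
fit!(mach, verbosity=0)   # delegates to fit!(glb(mach))
report(mach)              # report assembled from the glb node
```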

# make learning network machines callable for use in manual export of
# learning networks:
function (mach::Machine{<:Surrogate})()
# anonymize sources:
mach.cache = anonymize!(mach.args)
return mach.fitresult, mach.cache, mach.report
end
MLJModelInterface.fitted_params(mach::Machine{<:Surrogate}) =
fitted_params(glb(mach))


## CONSTRUCTING THE RETURN VALUE FOR A COMPOSITE FIT METHOD

# identify which fields of `model` have, as values, a model in the
# learning network wrapped by `mach`, and check that no two such
# fields have identical values (#377):
function fields_in_network(model::M, mach::Machine{<:Surrogate}) where M<:Model

signature = mach.fitresult
network_model_ids = objectid.(models(glb(mach)))

names = fieldnames(M)

# initialize dict for detecting duplicates, a la #377:
name_given_id = Dict{UInt64,Vector{Symbol}}()

# identify location of fields whose values are models in the
# learning network, and build name_given_id:
mask = map(names) do name
id = objectid(getproperty(model, name))
is_network_model_field = id in network_model_ids
if is_network_model_field
if haskey(name_given_id, id)
push!(name_given_id[id], name)
else
name_given_id[id] = [name,]
end
end
return is_network_model_field
end |> collect

# perform #377 check:
no_duplicates = all(values(name_given_id)) do name
length(name) == 1
end
if !no_duplicates
for (id, name) in name_given_id
if length(name) > 1
@error "The fields $name of $model have identical model "*
"instances as values. "
end
end
throw(ArgumentError(
"Two distinct fields of a composite model that are both "*
"associated with models in the underlying learning "*
"network (eg, any two fields of a `@pipeline` model) "*
"cannot have identical values, although they can be `==` "*
"(corresponding nested fields are `==`). "*
"Consider constructing instances "*
"separately or use `deepcopy`. "))
end

return names[mask]

end
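To make the #377 check concrete, here is a hypothetical two-field composite; `MyDouble` and `KNNRegressor` are stand-ins invented for this illustration:
```
mutable struct MyDouble <: DeterministicComposite
    one
    two
end

knn = KNNRegressor()
bad = MyDouble(knn, knn)            # two fields, one instance: when fit,
                                    # fields_in_network throws ArgumentError
ok  = MyDouble(knn, deepcopy(knn))  # `==` but distinct instances: accepted
```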

MLJModelInterface.fitted_params(mach::Machine{<:Surrogate}) =
fitted_params(glb(values(mach.fitresult)...))

"""
return!(mach::Machine{<:Surrogate}, model, verbosity)
The last call in custom code defining the `MLJBase.fit` method for a
new composite model type. Here `model` is the instance of the new type
appearing in the `MLJBase.fit` signature, while `mach` is a learning
network machine constructed using `model`. Not relevant when defining
composite models using `@pipeline` or `@from_network`.
For usage, see the example given below. Specifically, the call does the
following:
- Determines which fields of `model` point to model instances in the
learning network wrapped by `mach`, for recording in an object
called `cache`, for passing on to the MLJ logic that handles smart
updating (namely, an `MLJBase.update` fallback for composite models).
- Calls `fit!(mach, verbosity=verbosity)`.
- Moves any data in source nodes of the learning network into `cache`
(for data-anonymization purposes).
- Records a copy of `model` in `cache`.
- Returns `cache` and outcomes of training in an appropriate form
(specifically, `(mach.fitresult, cache, mach.report)`; see [Adding
Models for General
Use](https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/)
for technical details).
### Example
The following code defines, "by hand", a new model type `MyComposite`
for composing standardization (whitening) with a deterministic
regressor:
```
mutable struct MyComposite <: DeterministicComposite
    regressor
end

function MLJBase.fit(model::MyComposite, verbosity, X, y)
    Xs = source(X)
    ys = source(y)

    mach1 = machine(Standardizer(), Xs)
    Xwhite = transform(mach1, Xs)

    mach2 = machine(model.regressor, Xwhite, ys)
    yhat = predict(mach2, Xwhite)

    mach = machine(Deterministic(), Xs, ys; predict=yhat)
    return!(mach, model, verbosity)
end
```
"""
function return!(mach::Machine{<:Surrogate},
model::Union{Model,Nothing},
verbosity)

network_model_fields = fields_in_network(model, mach)

verbosity isa Nothing || fit!(mach, verbosity=verbosity)

# anonymize the data:
sources = mach.args
data = Tuple(s.data for s in sources)
[MLJBase.rebind!(s, nothing) for s in sources]

# record the field values
old_model = deepcopy(model)

cache = (sources = sources,
data=data,
network_model_fields=network_model_fields,
old_model=old_model)

return mach.fitresult, cache, mach.report

end


# legacy code:
function (mach::Machine{<:Surrogate})()
Base.depwarn("Calling a learning network machine `mach` "*
"with no arguments, as in"*
"`mach()`, is "*
"deprecated and could lead "*
"to unexpected behaviour for `Composite` models "*
"with fields that are not models. "*
"Instead of `fit!(mach, verbosity=verbosity); return mach()` "*
"use `return!(mach, model, verbosity)`, "*
"where `model` is the `Model` instance appearing in your "*
"`MLJBase.fit` signature. Query the `return!` doc-string "*
"for details. ",
nothing)

return!(mach, nothing, nothing)
end
fields_in_network(model::Nothing, mach::Machine{<:Surrogate}) =
nothing
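The migration the deprecation warning asks for, as a hedged before/after sketch; `OldComposite`, `NewComposite`, and the helper `build_network` are invented for illustration:
```
function MLJBase.fit(model::OldComposite, verbosity, X, y)
    mach = build_network(model, X, y)  # returns learning network machine
    fit!(mach, verbosity=verbosity)    # deprecated pattern
    return mach()
end

function MLJBase.fit(model::NewComposite, verbosity, X, y)
    mach = build_network(model, X, y)
    return!(mach, model, verbosity)    # replacement
end
```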


## DUPLICATING AND REPLACING PARTS OF A LEARNING NETWORK MACHINE
6 changes: 3 additions & 3 deletions src/composition/learning_networks/nodes.jl
@@ -133,8 +133,8 @@ function _apply(y_plus, input...; kwargs...)
end
end

ScientificTypes.elscitype(N::Node) = Unknown
function ScientificTypes.elscitype(
MLJScientificTypes.elscitype(N::Node) = Unknown
function MLJScientificTypes.elscitype(
N::Node{<:Machine{<:Union{Deterministic,Unsupervised}}})
if N.operation == MLJBase.predict
return target_scitype(N.machine.model)
@@ -150,7 +150,7 @@ end
# https://github.com/alan-turing-institute/ScientificTypes.jl/issues/102 :
# Add Probabilistic case to above

ScientificTypes.scitype(N::Node) = CallableReturning{elscitype(N)}
MLJScientificTypes.scitype(N::Node) = CallableReturning{elscitype(N)}
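A hedged sketch of what these trait methods report, reusing `mach2` and `Xwhite` from the `MyComposite` example earlier; the concrete scitype depends on the wrapped model's `target_scitype`:
```
yhat = predict(mach2, Xwhite)  # a Node over a Deterministic machine
elscitype(yhat)                # the wrapped model's target_scitype
scitype(yhat)                  # CallableReturning{elscitype(yhat)}
```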


## FITTING A NODE
(diffs for the remaining changed files not shown)