Skip to content

Commit

Permalink
Merge pull request #397 from alan-turing-institute/dev
Browse files Browse the repository at this point in the history
For a 0.14.8 release
  • Loading branch information
ablaom authored Aug 21, 2020
2 parents 8e25843 + 953a1d2 commit c1cb718
Show file tree
Hide file tree
Showing 12 changed files with 184 additions and 87 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.14.7"
version = "0.14.8"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
2 changes: 1 addition & 1 deletion src/MLJBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ export matrix, int, classes, decoder, table,
nrows, selectrows, selectcols, select

# re-exports from (MLJ)ScientificTypes
export Scientific, Found, Unknown, Known, Finite, Infinite,
export Unknown, Known, Finite, Infinite,
OrderedFactor, Multiclass, Count, Continuous, Textual,
Binary, ColorImage, GrayImage, Image, Table
export scitype, scitype_union, elscitype, nonmissing, trait
Expand Down
10 changes: 8 additions & 2 deletions src/composition/abstract_types.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,35 @@
# true composite models:
abstract type DeterministicComposite <: Deterministic end
abstract type ProbabilisticComposite <: Probabilistic end
abstract type JointProbabilisticComposite <: JointProbabilistic end
abstract type IntervalComposite <: Interval end
abstract type UnsupervisedComposite <: Unsupervised end
abstract type StaticComposite <: Static end

# surrogate composite models:
struct DeterministicSurrogate <: Deterministic end
struct ProbabilisticSurrogate <: Probabilistic end
struct JointProbabilisticSurrogate <: JointProbabilistic end
struct IntervalSurrogate <: Interval end
struct UnsupervisedSurrogate <: Unsupervised end
struct StaticSurrogate <: Static end

Deterministic() = DeterministicSurrogate()
Probabilistic() = ProbabilisticSurrogate()
JointProbabilistic() = JointProbabilisticSurrogate()
Interval() = IntervalSurrogate()
Unsupervised() = UnsupervisedSurrogate()
Static() = StaticSurrogate()

const SupervisedComposite =
Union{DeterministicComposite,ProbabilisticComposite,IntervalComposite}
Union{DeterministicComposite,ProbabilisticComposite,JointProbabilisticComposite,IntervalComposite}

const SupervisedSurrogate =
Union{DeterministicSurrogate,ProbabilisticSurrogate,IntervalSurrogate}
Union{DeterministicSurrogate,ProbabilisticSurrogate,JointProbabilisticSurrogate,IntervalSurrogate}

const Surrogate = Union{DeterministicSurrogate,
ProbabilisticSurrogate,
JointProbabilisticSurrogate,
IntervalSurrogate,
UnsupervisedSurrogate,
StaticSurrogate}
Expand All @@ -35,10 +39,12 @@ const Composite = Union{SupervisedComposite,UnsupervisedComposite,

for T in [:DeterministicComposite,
:ProbabilisticComposite,
:JointProbabilisticComposite,
:IntervalComposite,
:UnsupervisedComposite,
:StaticComposite,
:DeterministicSurrogate,
:JointProbabilisticSurrogate,
:ProbabilisticSurrogate,
:IntervalSurrogate,
:UnsupervisedSurrogate,
Expand Down
42 changes: 26 additions & 16 deletions src/composition/learning_networks/nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,24 +103,34 @@ color(N::Node) = (N.machine.frozen ? :red : :green)
# constructor for static operations:
Node(operation, args::AbstractNode...) = Node(operation, nothing, args...)

# make nodes callable:
(y::Node)(; rows=:) =
(y.operation)(y.machine, [arg(rows=rows) for arg in y.args]...)
function (y::Node)(Xnew)
length(y.origins) == 1 ||
error("Node $y has multiple origins and cannot be called "*
_check(y::Node) = nothing
_check(y::Node{Nothing}) = length(y.origins) == 1 ? nothing :
error("Node $y has multiple origins and cannot be called "*
"on new data. ")
return (y.operation)(y.machine, [arg(Xnew) for arg in y.args]...)
end

# and for the special case of static operations:
(y::Node{Nothing})(; rows=:) =
(y.operation)([arg(rows=rows) for arg in y.args]...)
function (y::Node{Nothing})(Xnew)
length(y.origins) == 1 ||
error("Node $y has multiple origins and cannot be called "*
"on new data. ")
return (y.operation)([arg(Xnew) for arg in y.args]...)
# make nodes callable:
(y::Node)(; rows=:) = _apply((y, y.machine); rows=rows)
(y::Node)(Xnew) = (_check(y); _apply((y, y.machine), Xnew))
(y::Node{Nothing})(; rows=:) = _apply((y, ); rows=rows)
(y::Node{Nothing})(Xnew) = (_check(y); _apply((y, ), Xnew))

function _apply(y_plus, input...; kwargs...)
y = y_plus[1]
mach = y_plus[2:end] # in static case this is ()
raw_args = map(y.args) do arg
arg(input...; kwargs...)
end
try
(y.operation)(mach..., raw_args...)
catch exception
@error "Failed "*
"to apply the operation `$(y.operation)` to the machine "*
"$(y.machine), which receives its data arguments from one or more "*
"nodes in a learning network. Possibly, one of these nodes "*
"is delivering data that is incompatible with the machine's model.\n"*
diagnostics(y, input...; kwargs...)
throw(exception)
end
end

ScientificTypes.elscitype(N::Node) = Unknown
Expand Down
2 changes: 1 addition & 1 deletion src/interface/model_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ predict_median(m, fitresult, Xnew, ::Val{<:BadMedianTypes}) =

# not in MLJModelInterface as methodswith requires InteractiveUtils
MLJModelInterface.implemented_methods(::FI, M::Type{<:MLJType}) =
getfield.(methodswith(M), :name)
getfield.(methodswith(M), :name) |> unique

# serialization fallbacks:
# Here `file` can be `String` or `IO` (eg, `file=IOBuffer()`).
Expand Down
70 changes: 45 additions & 25 deletions src/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,60 +49,65 @@ end
# makes sense if they are actually source nodes.

function check(model::Supervised, args... ; full=false)

nowarns = true

nargs = length(args)
if nargs == 2
X, y = args
elseif nargs > 2
supports_weights(model) ||
@info("$(typeof(model)) does not support sample weights and " *
"the supplied weights will be ignored in training.\n" *
"However, supplied weights will be passed to " *
"weight-supporting measures on calls to `evaluate!` " *
supports_weights(model) || elscitype(args[3]) <: Unknown ||
@info("$(typeof(model)) does not support sample weights and "*
"the supplied weights will be ignored in training.\n"*
"However, supplied weights will be passed to "*
"weight-supporting measures on calls to `evaluate!` "*
"and in tuning. ")
X, y, w = args
else
throw(ArgumentError("Use `machine(model, X, y)` or " *
"`machine(model, X, y, w)` for a supervised " *
throw(ArgumentError("Use `machine(model, X, y)` or "*
"`machine(model, X, y, w)` for a supervised "*
"model."))
end

if full
# checks on input type:
input_scitype(model) <: Unknown ||
elscitype(X) <: input_scitype(model) ||
@warn "The scitype of `X`, in `machine(model, X, ...)` "
"is incompatible with " *
"`model`:\nscitype(X) = $(elscitype(X))\n" *
"input_scitype(model) = $(input_scitype(model))."
(@warn("The scitype of `X`, in `machine(model, X, ...)` "*
"is incompatible with "*
"`model=$model`:\nscitype(X) = $(elscitype(X))\n"*
"input_scitype(model) = $(input_scitype(model))."); nowarns=false)

# checks on target type:
target_scitype(model) <: Unknown ||
elscitype(y) <: target_scitype(model) ||
@warn "The scitype of `y`, in `machine(model, X, y, ...)` " *
"is incompatible with " *
"`model`:\nscitype(y) = $(elscitype(y))\n" *
"target_scitype(model) = $(target_scitype(model))."
(@warn("The scitype of `y`, in `machine(model, X, y, ...)` "*
"is incompatible with "*
"`model=$model`:\nscitype(y) = $(elscitype(y))\ntarget_scitype(model) "*
"= $(target_scitype(model))."); nowarns=false)

# checks on dimension matching:
X() === nothing || # model fits a distribution to y
nrows(X()) == nrows(y()) ||
throw(DimensionMismatch("Differing number of observations "*
"in input and target. "))
if nargs > 2
w.data isa AbstractVector{<:Real} || w.data === nothing ||
if nargs > 2 && !(w.data isa Nothing)
w.data isa AbstractVector{<:Real} ||
throw(ArgumentError("Weights must be real."))
nrows(w()) == nrows(y()) ||
throw(DimensionMismatch("Weights and target "*
"differ in length."))
end
end
return nothing
return nowarns
end

function check(model::Unsupervised, args...; full=false)
nowarns = true

nargs = length(args)
nargs <= 1 ||
throw(ArgumentError("Wrong number of arguments. Use " *
throw(ArgumentError("Wrong number of arguments. Use "*
"`machine(model, X)` for an unsupervised model, "*
"or `machine(model)` if there are no training "*
"arguments (`Static` transformers).) "))
Expand All @@ -111,11 +116,11 @@ function check(model::Unsupervised, args...; full=false)
# check input scitype
input_scitype(model) <: Unknown ||
elscitype(X) <: input_scitype(model) ||
@warn "The scitype of `X`, in `machine(model, X)` is "*
"incompatible with `model`:\nscitype(X) = $(elscitype(X))\n" *
"input_scitype(model) = $(input_scitype(model))."
(@warn("The scitype of `X`, in `machine(model, X)` is "*
"incompatible with `model=$model`:\nscitype(X) = $(elscitype(X))\n"*
"input_scitype(model) = $(input_scitype(model))."); nowarns=false)
end
return nothing
return nowarns
end


Expand Down Expand Up @@ -414,7 +419,22 @@ function fit_only!(mach::Machine; rows=nothing, verbosity=1, force=false)
# fit the model:
fitlog(mach, :train, verbosity)
mach.fitresult, mach.cache, mach.report =
fit(mach.model, verbosity, raw_args...)
try
fit(mach.model, verbosity, raw_args...)
catch exception
@error "Problem fitting the machine $mach, "*
"possibly because an upstream node in a learning "*
"network is providing data of incompatible scitype. "
_sources = sources(glb(mach.args...))
length(_sources) > 2 ||
mach.model isa Composite ||
all((!isempty).(_sources)) ||
@warn "Some learning network source nodes are empty. "
@info "Running type checks... "
check(mach.model, source.(raw_args)... ; full=true) &&
@info "Type checks okay. "
throw(exception)
end
elseif mach.model != mach.old_model
# update the model:
fitlog(mach, :update, verbosity)
Expand Down Expand Up @@ -448,7 +468,7 @@ end
fit!(mach::Machine, rows=nothing, verbosity=1, force=false)
Fit the machine `mach`. In the case that `mach` has `Node` arguments,
first train all other machines on which `mach` depends.
first train all other machines on which `mach` depends.
To attempt to fit a machine without touching any other machine, use
`fit_only!`. For more on the internal logic of fitting see
Expand Down
39 changes: 13 additions & 26 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
## TODO: need to add checks on the arguments of
## predict(::Machine, ) and transform(::Machine, )

const OPERATIONS = (:predict, :predict_mean, :predict_mode, :predict_median,
const OPERATIONS = (:predict, :predict_mean, :predict_mode, :predict_median, :predict_joint,
:transform, :inverse_transform)

for operation in OPERATIONS
Expand Down Expand Up @@ -98,37 +98,24 @@ end

## SURROGATE AND COMPOSITE MODELS

for operation in [:predict, :transform, :inverse_transform]
for operation in [:predict, :predict_joint, :transform, :inverse_transform]
ex = quote
$operation(model::Union{Composite,Surrogate}, fitresult,X) =
fitresult.$operation(X)
end
eval(ex)
end

function predict_mode(m::Union{ProbabilisticComposite,ProbabilisticSurrogate},
fitresult,
Xnew)
if haskey(fitresult, :predict_mode)
return fitresult.predict_mode(Xnew)
end
return mode.(predict(m, fitresult, Xnew))
end

function predict_mean(m::Union{ProbabilisticComposite,ProbabilisticSurrogate},
fitresult,
Xnew)
if haskey(fitresult, :predict_mean)
return fitresult.predict_mean(Xnew)
end
return mean.(predict(m, fitresult, Xnew))
end

function predict_median(m::Union{ProbabilisticComposite,ProbabilisticSurrogate},
fitresult,
Xnew)
if haskey(fitresult, :predict_median)
return fitresult.predict_median(Xnew)
for (operation, fallback) in [(:predict_mode, :mode), (:predict_mean, :mean), (:predict_median, :median)]
ex = quote
function $(operation)(m::Union{ProbabilisticComposite,ProbabilisticSurrogate},
fitresult,
Xnew)
if haskey(fitresult, $(QuoteNode(operation)))
return fitresult.$(operation)(Xnew)
end
return $(fallback).(predict(m, fitresult, Xnew))
end
end
return median.(predict(m, fitresult, Xnew))
eval(ex)
end
35 changes: 35 additions & 0 deletions src/sources.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,41 @@ function (X::Source)(; rows=:)
end
(X::Source)(Xnew) = Xnew

# return a string of diagnostics for the call `X(input...; kwargs...)`
diagnostic_table_sources(X::AbstractNode) =
"Learning network sources:\n"*
"source\tscitype\n"*
"-------------------------------------------\n"*
reduce(*, ("$s\t$(scitype(s()))\n" for s in sources(X)))

function diagnostics(X::AbstractNode, input...; kwargs...)
raw_args = map(X.args) do arg
arg(input...; kwargs...)
end
_sources = sources(X)
scitypes = scitype.(raw_args)
mach = X.machine
model = mach.model
_input = input_scitype(model)
_target = target_scitype(model)
_output = output_scitype(model)

table1 = "Incoming data:\n"*
"arg of $(X.operation)\tscitype\n"*
"-------------------------------------------\n"*
reduce(*, ("$(X.args[j])\t$(scitypes[j])\n" for j in eachindex(X.args)))

table2 = diagnostic_table_sources(X)
return """
Model ($model):
input_scitype = $_input
target_scitype = $_target
output_scitype = $_output
$table1
$table2"""
end

"""
rebind!(s, X)
Expand Down
8 changes: 4 additions & 4 deletions test/_models/Constant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,28 +112,28 @@ metadata_pkg.((ConstantRegressor, ConstantClassifier,
is_wrapper=false)

metadata_model(ConstantRegressor,
input=MLJBase.Table(MLJBase.Scientific),
input=MLJBase.Table,
target=AbstractVector{MLJBase.Continuous},
weights=false,
descr="Constant regressor (Probabilistic).",
path="MLJModels.ConstantRegressor")

metadata_model(DeterministicConstantRegressor,
input=MLJBase.Table(MLJBase.Scientific),
input=MLJBase.Table,
target=AbstractVector{MLJBase.Continuous},
weights=false,
descr="Constant regressor (Deterministic).",
path="MLJModels.DeterministicConstantRegressor")

metadata_model(ConstantClassifier,
input=MLJBase.Table(MLJBase.Scientific),
input=MLJBase.Table,
target=AbstractVector{<:MLJBase.Finite},
weights=true,
descr="Constant classifier (Probabilistic).",
path="MLJModels.ConstantClassifier")

metadata_model(DeterministicConstantClassifier,
input=MLJBase.Table(MLJBase.Scientific),
input=MLJBase.Table,
target=AbstractVector{<:MLJBase.Finite},
weights=false,
descr="Constant classifier (Deterministic).",
Expand Down
Loading

0 comments on commit c1cb718

Please sign in to comment.