Skip to content

Commit

Permalink
Merge pull request #612 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.18.18 release
  • Loading branch information
ablaom authored Aug 17, 2021
2 parents 0350200 + f7185dd commit c157b5a
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 36 deletions.
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.18.17"
version = "0.18.18"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down Expand Up @@ -33,14 +33,14 @@ ComputationalResources = "0.3"
Distributions = "0.22, 0.23, 0.24, 0.25"
InvertedIndices = "1"
LossFunctions = "0.5, 0.6, 0.7"
MLJModelInterface = "1.1.3"
MLJModelInterface = "1.2"
Missings = "0.4, 1"
OrderedCollections = "1.1"
Parameters = "0.12"
PrettyTables = "1"
ProgressMeter = "1.7.1"
ScientificTypes = "2"
StatisticalTraits = "2"
ScientificTypes = "2.1"
StatisticalTraits = "2.1"
StatsBase = "0.32, 0.33"
Tables = "0.2, 1.0"
julia = "1"
Expand Down
4 changes: 2 additions & 2 deletions src/composition/learning_networks/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ function machine(model::Surrogate, _sources::Source...; pair_itr...)
end

if model isa Supervised
length(_sources) > 1 || _throw_supervised_arg_error()
length(_sources) > 1 || throw(err_supervised_nargs())
elseif model isa Unsupervised
length(_sources) < 2 || _throw_unsupervised_arg_error()
length(_sources) < 2 || throw(err_unsupervised_nargs())
else
throw(DomainError)
end
Expand Down
96 changes: 66 additions & 30 deletions src/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,53 +74,91 @@ end
# In these checks the args are abstract nodes but `full=true` only
# makes sense if they are actually source nodes.

_throw_supervised_arg_error() = throw(ArgumentError(
# Build (but do not throw) the `ArgumentError` used when a `Supervised` model
# is given fewer than two training arguments. Callers wrap the returned
# exception in `throw(...)` themselves.
err_supervised_nargs() = ArgumentError(
"`Supervised` models should have at least two "*
"training arguments. "*
"Use `machine(model, X, y; ...)` or "*
"`machine(model, X, y, extras...; ...)`. "))
"`machine(model, X, y, extras...; ...)`. ")

_throw_unsupervised_arg_error() = throw(ArgumentError(
# Build (but do not throw) the `ArgumentError` used when an `Unsupervised`
# model is given more than one training argument. Callers wrap the returned
# exception in `throw(...)` themselves.
err_unsupervised_nargs() = ArgumentError(
"`Unsupervised` models should have one "*
"training argument, except `Static` models, which have none. "*
"Use `machine(model, X; ...)` (usual case) or "*
"`machine(model; ...)` (static case). "))
"`machine(model; ...)` (static case). ")

# Warning text for a `Supervised` model whose input `X` has an element scitype
# that is not accepted by `input_scitype(model)`. Returns the message only;
# logging it is left to the caller.
function warn_scitype(model::Supervised, X)
    provided = elscitype(X)
    expected = input_scitype(model)
    return "The scitype of `X`, in `machine(model, X, ...)` is incompatible with " *
        "`model=$model`:\nscitype(X) = $provided\n" *
        "input_scitype(model) = $expected."
end

# Warning text for a generic model: `S` is the scitype of the arguments
# actually provided and `F` is the scitype the model's `fit` method expects.
# Returns the message only; logging it is left to the caller.
function warn_generic_scitype_mismatch(S, F)
    headline = "The scitype of `args` in `machine(model, args...; kwargs)` " *
        "does not match the scitype expected by model's `fit` method.\n"
    details = " provided: $S\n expected by fit: $F"
    return headline * details
end

# Warning text for a `Supervised` model whose target `y` has an element
# scitype that is not accepted by `target_scitype(model)`. `X` is part of the
# signature but is not used in the message. Returns the message only.
function warn_scitype(model::Supervised, X, y)
    provided = elscitype(y)
    expected = target_scitype(model)
    return "The scitype of `y`, in `machine(model, X, y, ...)` is incompatible with " *
        "`model=$model`:\nscitype(y) = $provided\n" *
        "target_scitype(model) = $expected."
end

# Warning text for an `Unsupervised` model whose input `X` has an element
# scitype that is not accepted by `input_scitype(model)`. Returns the message
# only; logging it is left to the caller.
function warn_scitype(model::Unsupervised, X)
    provided = elscitype(X)
    expected = input_scitype(model)
    return "The scitype of `X`, in `machine(model, X)` is incompatible with " *
        "`model=$model`:\nscitype(X) = $provided\n" *
        "input_scitype(model) = $expected."
end

# Build (but do not throw) the `DimensionMismatch` reported when input and
# target have differing numbers of observations. `model` is used for dispatch
# only.
function err_length_mismatch(model::Supervised)
    msg = "Differing number of observations in input and target. "
    return DimensionMismatch(msg)
end

# Fallback for a first argument that is not a `Model`: always throws an
# `ArgumentError` naming the offending object.
#
# Keyword arguments are slurped with `kwargs...`. Written as a bare `kwargs`,
# the signature declares a single *required* keyword literally named `kwargs`,
# so `check(x)` and `check(x; full=true)` would fail with a keyword-related
# error before the intended `ArgumentError` could be thrown.
check(model::Any, args...; kwargs...) =
    throw(ArgumentError("Expected a `Model` instance, got $model. "))

# Generic argument check for any `Model`: compare the scitype of the supplied
# training arguments with `fit_data_scitype(model)`.
#
# Returns `true` if no warning was issued and `false` otherwise. When the
# model declares `Unknown` as its fit-data scitype there is nothing to compare
# and `true` is returned immediately. The `full` keyword is accepted but is
# not consulted by this method.
function check(model::Model, args...; full=false)
    nowarns = true

    F = fit_data_scitype(model)
    F == Unknown && return true

    S = Tuple{elscitype.(args)...}
    if !(S <: F)
        @warn warn_generic_scitype_mismatch(S, F)
        nowarns = false
    end

    # Explicit return: without it the function yields the value of the `if`
    # block, which is `nothing` (not `true`) whenever the scitypes match.
    return nowarns
end

function check(model::Supervised, args... ; full=false)

nowarns = true

nargs = length(args)
nargs > 1 || _throw_supervised_arg_error()
nargs > 1 || throw(err_supervised_nargs())

full || return nowarns

X, y = args[1:2]

# checks on input type:
input_scitype(model) <: Unknown ||
elscitype(X) <: input_scitype(model) ||
(@warn("The scitype of `X`, in `machine(model, X, ...)` "*
"is incompatible with "*
"`model=$model`:\nscitype(X) = $(elscitype(X))\n"*
"input_scitype(model) = $(input_scitype(model)).");
nowarns=false)
elscitype(X) <: input_scitype(model) || begin
@warn warn_scitype(model, X)
nowarns=false
end

# checks on target type:
target_scitype(model) <: Unknown ||
elscitype(y) <: target_scitype(model) ||
(@warn("The scitype of `y`, in `machine(model, X, y, ...)` "*
"is incompatible with "*
"`model=$model`:\nscitype(y) = "*
"$(elscitype(y))\ntarget_scitype(model) "*
"= $(target_scitype(model)).");
nowarns=false)
elscitype(y) <: target_scitype(model) || begin
@warn warn_scitype(model, X, y)
nowarns=false
end

# checks on dimension matching:

scitype(X) == CallableReturning{Nothing} || nrows(X()) == nrows(y()) ||
throw(DimensionMismatch("Differing number of observations "*
"in input and target. "))
throw(err_length_mismatch(model))

return nowarns

Expand All @@ -130,24 +168,22 @@ function check(model::Unsupervised, args...; full=false)
nowarns = true

nargs = length(args)
nargs <= 1 ||
throw(ArgumentError("Wrong number of arguments. Use "*
"`machine(model, X)` for an unsupervised model, "*
"or `machine(model)` if there are no training "*
"arguments (`Static` transformers).) "))
nargs <= 1 || throw(err_unsupervised_nargs())

if full && nargs == 1
X = args[1]
# check input scitype
input_scitype(model) <: Unknown ||
elscitype(X) <: input_scitype(model) ||
(@warn("The scitype of `X`, in `machine(model, X)` is "*
"incompatible with `model=$model`:\nscitype(X) = $(elscitype(X))\n"*
"input_scitype(model) = $(input_scitype(model))."); nowarns=false)
elscitype(X) <: input_scitype(model) || begin
@warn warn_scitype(model, X)
nowarns=false
end
end
return nowarns
end



"""
machine(model, args...; cache=true)
Expand Down
18 changes: 18 additions & 0 deletions test/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,24 @@ freeze!(stand)
machine(tree, (x=categorical(1:N),), y))
end

# Minimal generic model (neither `Supervised` nor `Unsupervised`) used to
# exercise the generic argument check.
struct FooBar <: Model
end

# `FooBar` accepts either a single `Count` vector, or a `Count` vector
# followed by a `Continuous` vector, as training arguments.
function MLJBase.fit_data_scitype(::Type{<:FooBar})
    one_arg = Tuple{AbstractVector{Count}}
    two_args = Tuple{AbstractVector{Count},AbstractVector{Continuous}}
    return Union{one_arg,two_args}
end

@testset "machine argument check for generic model" begin
    counts = [1, 2, 3, 4]
    target = rand(4)
    foobar = FooBar()

    # argument combinations that are expected to construct silently:
    @test_logs machine(foobar, counts, target)
    @test_logs machine(foobar, counts)

    # `target` on its own is expected to trigger the generic mismatch warning:
    expected_msg = MLJBase.warn_generic_scitype_mismatch(
        Tuple{scitype(target)},
        fit_data_scitype(foobar),
    )
    @test_logs (:warn, expected_msg) machine(foobar, target)
end

@testset "weights" begin
yraw = ["Perry", "Antonia", "Perry", "Skater"]
X = (x=rand(4),)
Expand Down

0 comments on commit c157b5a

Please sign in to comment.