Skip to content

Commit

Permalink
add catch for missing target in resampling
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Jul 1, 2024
1 parent 2209563 commit 140fb7b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
16 changes: 15 additions & 1 deletion src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ err_ambiguous_operation(model, measure) = ArgumentError(
"Possible value(s) are: $PREDICT_OPERATIONS_STRING. "
)

ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError(
const ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError(
"""
The `prediction_type` of your model needs to be one of: `:deterministic`,
Expand All @@ -81,6 +81,18 @@ ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError(
"""
)

const ERR_NEED_TARGET = ArgumentError(
"""
To evaluate a model's performance you must provide a target variable `y`, as in
`evaluate(model, X, y; options...)` or
mach = machine(model, X, y)
evaluate!(mach; options...)
"""
)

# ==================================================================
## RESAMPLING STRATEGIES

Expand Down Expand Up @@ -1170,6 +1182,8 @@ function evaluate!(
# weights, measures, operations, and dispatches a
# strategy-specific `evaluate!`

length(mach.args) > 1 || throw(ERR_NEED_TARGET)

repeats > 0 || error("Need `repeats > 0`. ")

if resampling isa TrainTestPairs
Expand Down
6 changes: 6 additions & 0 deletions test/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -948,14 +948,20 @@ end

struct PredictingTransformer <:Unsupervised end
MLJBase.fit(::PredictingTransformer, verbosity, X, y) = (mean(y), nothing, nothing)
MLJBase.fit(::PredictingTransformer, verbosity, X) = (nothing, nothing, nothing)
MLJBase.predict(::PredictingTransformer, fitresult, X) = fill(fitresult, nrows(X))
MLJBase.predict(::PredictingTransformer, ::Nothing, X) = nothing
MLJBase.prediction_type(::Type{<:PredictingTransformer}) = :deterministic

@testset "`Unsupervised` model with a predict" begin
X = rand(10)
y = fill(42.0, 10)
e = evaluate(PredictingTransformer(), X, y, resampling=Holdout(), measure=l2)
@test e.measurement[1] 0
@test_throws(
MLJBase.ERR_NEED_TARGET,
evaluate(PredictingTransformer(), X, measure=l2),
)
end


Expand Down

0 comments on commit 140fb7b

Please sign in to comment.