From 140fb7bcf3f46d48120c6a35835c96992da6bbac Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 1 Jul 2024 13:37:46 +1200 Subject: [PATCH] add catch for missing target in resampling --- src/resampling.jl | 16 +++++++++++++++- test/resampling.jl | 6 ++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/resampling.jl b/src/resampling.jl index da8eac72..0c4ffda9 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -68,7 +68,7 @@ err_ambiguous_operation(model, measure) = ArgumentError( "Possible value(s) are: $PREDICT_OPERATIONS_STRING. " ) -ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError( +const ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError( """ The `prediction_type` of your model needs to be one of: `:deterministic`, @@ -81,6 +81,18 @@ ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError( """ ) +const ERR_NEED_TARGET = ArgumentError( + """ + + To evaluate a model's performance you must provide a target variable `y`, as in + `evaluate(model, X, y; options...)` or + + mach = machine(model, X, y) + evaluate!(mach; options...) + + """ +) + # ================================================================== ## RESAMPLING STRATEGIES @@ -1170,6 +1182,8 @@ function evaluate!( # weights, measures, operations, and dispatches a # strategy-specific `evaluate!` + length(mach.args) > 1 || throw(ERR_NEED_TARGET) + repeats > 0 || error("Need `repeats > 0`. ") if resampling isa TrainTestPairs diff --git a/test/resampling.jl b/test/resampling.jl index ecfd4d3d..a91c28c5 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -948,7 +948,9 @@ end struct PredictingTransformer <:Unsupervised end MLJBase.fit(::PredictingTransformer, verbosity, X, y) = (mean(y), nothing, nothing) +MLJBase.fit(::PredictingTransformer, verbosity, X) = (nothing, nothing, nothing) MLJBase.predict(::PredictingTransformer, fitresult, X) = fill(fitresult, nrows(X)) +MLJBase.predict(::PredictingTransformer, ::Nothing, X) = nothing MLJBase.prediction_type(::Type{<:PredictingTransformer}) = :deterministic @testset "`Unsupervised` model with a predict" begin @@ -956,6 +958,10 @@ MLJBase.prediction_type(::Type{<:PredictingTransformer}) = :deterministic y = fill(42.0, 10) e = evaluate(PredictingTransformer(), X, y, resampling=Holdout(), measure=l2) @test e.measurement[1] ≈ 0 + @test_throws( + MLJBase.ERR_NEED_TARGET, + evaluate(PredictingTransformer(), X, measure=l2), + ) end