diff --git a/src/resampling.jl b/src/resampling.jl index dd317092..da8eac72 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -31,10 +31,6 @@ const ERR_INVALID_OPERATION = ArgumentError( _ambiguous_operation(model, measure) = "`$measure` does not support a `model` with "* "`prediction_type(model) == :$(prediction_type(model))`. " -err_ambiguous_operation(model, measure) = ArgumentError( - _ambiguous_operation(model, measure)* - "\nUnable to infer an appropriate operation for `$measure`. "* - "Explicitly specify `operation=...` or `operations=...`. ") err_incompatible_prediction_types(model, measure) = ArgumentError( _ambiguous_operation(model, measure)* "If your model is truly making probabilistic predictions, try explicitly "* @@ -65,11 +61,25 @@ ERR_MEASURES_DETERMINISTIC(measure) = ArgumentError( "and so is not supported by `$measure`. "*LOG_AVOID ) -# ================================================================== -## MODEL TYPES THAT CAN BE EVALUATED +err_ambiguous_operation(model, measure) = ArgumentError( + _ambiguous_operation(model, measure)* + "\nUnable to infer an appropriate operation for `$measure`. "* + "Explicitly specify `operation=...` or `operations=...`. "* + "Possible value(s) are: $PREDICT_OPERATIONS_STRING. " +) + +ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError( + """ + + The `prediction_type` of your model needs to be one of: `:deterministic`, + `:probabilistic`, or `:interval`. Does your model implement one of these operations: + $PREDICT_OPERATIONS_STRING? If so, you can try explicitly specifying `operation=...` + or `operations=...` (and consider posting an issue to have the model review it's + definition of `MLJModelInterface.prediction_type`). Otherwise, performance + evaluation is not supported. -# not exported: -const Measurable = Union{Supervised, Annotator} + """ +) # ================================================================== ## RESAMPLING STRATEGIES @@ -987,7 +997,7 @@ function _actual_operations(operation::Nothing, throw(err_ambiguous_operation(model, m)) end else - throw(err_ambiguous_operation(model, m)) + throw(ERR_UNSUPPORTED_PREDICTION_TYPE) end end end @@ -1137,7 +1147,7 @@ See also [`evaluate`](@ref), [`PerformanceEvaluation`](@ref), """ function evaluate!( - mach::Machine{<:Measurable}; + mach::Machine; resampling=CV(), measures=nothing, measure=measures, @@ -1235,7 +1245,7 @@ Returns a [`PerformanceEvaluation`](@ref) object. See also [`evaluate!`](@ref). """ -evaluate(model::Measurable, args...; cache=true, kwargs...) = +evaluate(model::Model, args...; cache=true, kwargs...) = evaluate!(machine(model, args...; cache=cache); kwargs...) # ------------------------------------------------------------------- diff --git a/test/resampling.jl b/test/resampling.jl index fbf26777..ecfd4d3d 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -25,6 +25,8 @@ end struct DummyInterval <: Interval end dummy_interval=DummyInterval() +struct GoofyTransformer <: Unsupervised end + dummy_measure_det(yhat, y) = 42 API.@trait( typeof(dummy_measure_det), @@ -115,6 +117,12 @@ API.@trait( MLJBase.err_ambiguous_operation(dummy_interval, LogLoss()), MLJBase._actual_operations(nothing, [LogLoss(), ], dummy_interval, 1)) + + # model not have a valid `prediction_type`: + @test_throws( + MLJBase.ERR_UNSUPPORTED_PREDICTION_TYPE, + MLJBase._actual_operations(nothing, [LogLoss(),], GoofyTransformer(), 0), + ) end @everywhere begin @@ -935,7 +943,23 @@ end end end -# DUMMY LOGGER + +# # TRANSFORMER WITH PREDICT + +struct PredictingTransformer <:Unsupervised end +MLJBase.fit(::PredictingTransformer, verbosity, X, y) = (mean(y), nothing, nothing) +MLJBase.predict(::PredictingTransformer, fitresult, X) = fill(fitresult, nrows(X)) +MLJBase.prediction_type(::Type{<:PredictingTransformer}) = :deterministic + +@testset "`Unsupervised` model with a predict" begin + X = rand(10) + y = fill(42.0, 10) + e = evaluate(PredictingTransformer(), X, y, resampling=Holdout(), measure=l2) + @test e.measurement[1] ≈ 0 +end + + +# # DUMMY LOGGER struct DummyLogger end