Skip to content

Commit

Permalink
relax restrictions on model type in resampling
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Jun 30, 2024
1 parent 370b3da commit 2209563
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 12 deletions.
32 changes: 21 additions & 11 deletions src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ const ERR_INVALID_OPERATION = ArgumentError(
_ambiguous_operation(model, measure) =
"`$measure` does not support a `model` with "*
"`prediction_type(model) == :$(prediction_type(model))`. "
err_ambiguous_operation(model, measure) = ArgumentError(
_ambiguous_operation(model, measure)*
"\nUnable to infer an appropriate operation for `$measure`. "*
"Explicitly specify `operation=...` or `operations=...`. ")
err_incompatible_prediction_types(model, measure) = ArgumentError(
_ambiguous_operation(model, measure)*
"If your model is truly making probabilistic predictions, try explicitly "*
Expand Down Expand Up @@ -65,11 +61,25 @@ ERR_MEASURES_DETERMINISTIC(measure) = ArgumentError(
"and so is not supported by `$measure`. "*LOG_AVOID
)

# ==================================================================
## MODEL TYPES THAT CAN BE EVALUATED
err_ambiguous_operation(model, measure) = ArgumentError(
_ambiguous_operation(model, measure)*
"\nUnable to infer an appropriate operation for `$measure`. "*
"Explicitly specify `operation=...` or `operations=...`. "*
"Possible value(s) are: $PREDICT_OPERATIONS_STRING. "
)

ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError(
"""
The `prediction_type` of your model needs to be one of: `:deterministic`,
`:probabilistic`, or `:interval`. Does your model implement one of these operations:
$PREDICT_OPERATIONS_STRING? If so, you can try explicitly specifying `operation=...`
or `operations=...` (and consider posting an issue to have the model review it's
definition of `MLJModelInterface.prediction_type`). Otherwise, performance
evaluation is not supported.
# not exported:
const Measurable = Union{Supervised, Annotator}
"""
)

# ==================================================================
## RESAMPLING STRATEGIES
Expand Down Expand Up @@ -987,7 +997,7 @@ function _actual_operations(operation::Nothing,
throw(err_ambiguous_operation(model, m))
end
else
throw(err_ambiguous_operation(model, m))
throw(ERR_UNSUPPORTED_PREDICTION_TYPE)
end
end
end
Expand Down Expand Up @@ -1137,7 +1147,7 @@ See also [`evaluate`](@ref), [`PerformanceEvaluation`](@ref),
"""
function evaluate!(
mach::Machine{<:Measurable};
mach::Machine;
resampling=CV(),
measures=nothing,
measure=measures,
Expand Down Expand Up @@ -1235,7 +1245,7 @@ Returns a [`PerformanceEvaluation`](@ref) object.
See also [`evaluate!`](@ref).
"""
evaluate(model::Measurable, args...; cache=true, kwargs...) =
evaluate(model::Model, args...; cache=true, kwargs...) =
evaluate!(machine(model, args...; cache=cache); kwargs...)

# -------------------------------------------------------------------
Expand Down
26 changes: 25 additions & 1 deletion test/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ end
struct DummyInterval <: Interval end
dummy_interval=DummyInterval()

struct GoofyTransformer <: Unsupervised end

dummy_measure_det(yhat, y) = 42
API.@trait(
typeof(dummy_measure_det),
Expand Down Expand Up @@ -115,6 +117,12 @@ API.@trait(
MLJBase.err_ambiguous_operation(dummy_interval, LogLoss()),
MLJBase._actual_operations(nothing,
[LogLoss(), ], dummy_interval, 1))

# model not have a valid `prediction_type`:
@test_throws(
MLJBase.ERR_UNSUPPORTED_PREDICTION_TYPE,
MLJBase._actual_operations(nothing, [LogLoss(),], GoofyTransformer(), 0),
)
end

@everywhere begin
Expand Down Expand Up @@ -935,7 +943,23 @@ end
end
end

# DUMMY LOGGER

# # TRANSFORMER WITH PREDICT

struct PredictingTransformer <:Unsupervised end
MLJBase.fit(::PredictingTransformer, verbosity, X, y) = (mean(y), nothing, nothing)
MLJBase.predict(::PredictingTransformer, fitresult, X) = fill(fitresult, nrows(X))
MLJBase.prediction_type(::Type{<:PredictingTransformer}) = :deterministic

@testset "`Unsupervised` model with a predict" begin
X = rand(10)
y = fill(42.0, 10)
e = evaluate(PredictingTransformer(), X, y, resampling=Holdout(), measure=l2)
@test e.measurement[1] 0
end


# # DUMMY LOGGER

struct DummyLogger end

Expand Down

0 comments on commit 2209563

Please sign in to comment.