diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 173e5b30..f3b2d158 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,6 +15,7 @@ jobs: test: name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} runs-on: ${{ matrix.os }} + timeout-minutes: 60 strategy: fail-fast: false matrix: diff --git a/Project.toml b/Project.toml index 98e422ae..a4cf83e5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJBase" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" authors = ["Anthony D. Blaom "] -version = "0.18.25" +version = "0.18.26" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" diff --git a/src/MLJBase.jl b/src/MLJBase.jl index 21293ade..7e28a218 100644 --- a/src/MLJBase.jl +++ b/src/MLJBase.jl @@ -186,7 +186,8 @@ export mav, mae, mape, rms, rmsl, rmslp1, rmsp, l1, l2, log_cosh, RMS, rmse, RootMeanSquaredError, root_mean_squared_error, RootMeanSquaredLogError, RMSL, root_mean_squared_log_error, rmsl, rmsle, RootMeanSquaredLogProportionalError, rmsl1, RMSLP, - MAPE, MeanAbsoluteProportionalError, log_cosh_loss, LogCosh, LogCoshLoss + MAPE, MeanAbsoluteProportionalError, log_cosh_loss, LogCosh, LogCoshLoss, + RSquared, rsq, rsquared # measures/confusion_matrix.jl: export confusion_matrix, confmat, ConfusionMatrix diff --git a/src/composition/models/from_network.jl b/src/composition/models/from_network.jl index ad663c7e..f0689d63 100644 --- a/src/composition/models/from_network.jl +++ b/src/composition/models/from_network.jl @@ -36,7 +36,7 @@ _insert_subtyping(ex, subtype_ex) = # create the exported type symbol, e.g. 
abstract_type(T) == Unsupervised # would result in :UnsupervisedComposite -_exported_type(T::Model) = Symbol(abstract_type(T), :Composite) +_exported_type(T::Model) = Symbol(nameof(abstract_type(T)), :Composite) function eval_and_reassign(modl, ex) s = gensym() diff --git a/src/measures/README.md b/src/measures/README.md index 991fb303..5a9f3778 100644 --- a/src/measures/README.md +++ b/src/measures/README.md @@ -1,5 +1,9 @@ ## Adding new measures +This document assumes familiarity with the traits provided for +measures. For a summary, query the docstring for +`MLJBase.metadata_measure`. + A measure is ordinarily called on data directly, as in ```julia @@ -15,7 +19,7 @@ julia> m(ŷ, y) 0.019067038457889922 ``` -Recall that to call a measure without performing dimension or pool checks, one +To call a measure without performing dimension or pool checks, one uses `MLJBase.call` instead: ```julia diff --git a/src/measures/continuous.jl b/src/measures/continuous.jl index d1ae819c..e6997378 100644 --- a/src/measures/continuous.jl +++ b/src/measures/continuous.jl @@ -66,6 +66,38 @@ call(::RootMeanSquaredError, w::Arr{<:Real}) = (y .- ŷ).^2 .* w |> skipinvalid |> mean |> sqrt +# ------------------------------------------------------------------------- +# R-squared (coefficient of determination) + +struct RSquared <: Aggregated end + +metadata_measure(RSquared; + instances = ["rsq", "rsquared"], + target_scitype = InfiniteArrMissing, + prediction_type = :deterministic, + orientation = :score, + supports_weights = false) + +const RSQ = RSquared +@create_aliases RSquared + +@create_docs(RSquared, +body= +""" +The R² (also known as R-squared or coefficient of determination) is suitable for interpreting linear regression analysis (Chicco et al., [2021](https://doi.org/10.7717/peerj-cs.623)).
+ +Let ``\\overline{y}`` denote the mean of ``y``, then + +``\\text{R}^2 = 1 - \\frac{∑ (\\hat{y} - y)^2}{∑ (\\overline{y} - y)^2}.`` +""") + +function call(::RSquared, ŷ::ArrMissing{<:Real}, y::ArrMissing{<:Real}) + num = (ŷ .- y).^2 |> skipinvalid |> sum + mean_y = y |> skipinvalid |> mean # filter first: `mean(y)` is `missing` if any target is missing + denom = (mean_y .- y).^2 |> skipinvalid |> sum + return 1 - (num / denom) +end + # ------------------------------------------------------------------- # LP @@ -180,7 +212,7 @@ metadata_measure(RootMeanSquaredProportionalError; orientation = :loss, aggregation = RootMeanSquare()) - const RMSP = RootMeanSquaredProportionalError +const RMSP = RootMeanSquaredProportionalError @create_aliases RMSP @create_docs(RootMeanSquaredProportionalError, diff --git a/src/measures/measures.jl b/src/measures/measures.jl index d172c085..9f00fde0 100644 --- a/src/measures/measures.jl +++ b/src/measures/measures.jl @@ -110,6 +110,7 @@ function call(measure::Unaggregated, yhat, y, weight_given_class::AbstractDict) return w .* unweighted end + # ## Top level function (measure::Measure)(args...) diff --git a/src/measures/meta_utilities.jl b/src/measures/meta_utilities.jl index e38425b0..29407bdc 100644 --- a/src/measures/meta_utilities.jl +++ b/src/measures/meta_utilities.jl @@ -90,7 +90,73 @@ end """ metadata_measure(T; kw...) -Helper function to write the metadata for a single measure. +Helper function to write the metadata (trait definitions) for a single +measure. + +### Compulsory keyword arguments + +- `target_scitype`: The allowed scientific type of `y` in `measure(ŷ, + y, ...)`. This is typically some abstract array. E.g., in single + target variable regression this is typically + `AbstractArray{<:Union{Missing,Continuous}}`. For a binary + classification metric insensitive to class order, this would + typically be `Union{AbstractArray{<:Union{Missing,Multiclass{2}}}, + AbstractArray{<:Union{Missing,OrderedFactor{2}}}}`, which has the + alias `FiniteArrMissing`. + +- `orientation`: Orientation of the measure.
Use `:loss` when lower is + better and `:score` when higher is better. For example, set + `:loss` for root mean square and `:score` for area under the ROC + curve. + +- `prediction_type`: Refers to `ŷ` in `measure(ŷ, y, ...)` and should + be one of: `:deterministic` (`ŷ` has the same type as `y`), + `:probabilistic` or `:interval`. + + +### Optional keyword arguments + +The following have meaningful defaults but may still require +overloading: + +- `instances`: A vector of strings naming the built-in instances of + the measurement type provided by the implementation, which are + usually just common aliases for the default instance. E.g., + `RSquared` has `instances = ["rsq", "rsquared"]`, which are both + defined as `RSquared()` in the implementation. `MulticlassFScore` + has `instances = ["macro_f1score", "micro_f1score", + "multiclass_f1score"]`, where `micro_f1score = + MulticlassFScore(average=micro_avg)`, etc. Default is `String[]`. + +- `aggregation`: Aggregation method for measurements, typically + `Mean()` (for, e.g., mean absolute error) or `Sum()` (for number + of true positives). Default is `Mean()`. Must subtype + `StatisticalTraits.AggregationMode`. It is used to: + + - aggregate measurements in resampling (e.g., cross-validation) + + - aggregate per-observation measurements returned by `single` in + the fallback definition of `call` for `Unaggregated` measures + (such as area under the ROC curve). + +- `supports_weights`: Whether the measure can be called with + per-observation weights `w`, as in `l2(ŷ, y, w)`. Default is `true`. + +- `supports_class_weights`: Whether the measure can be called with a + class weight dictionary `w`, as in `micro_f1score(ŷ, y, w)`. Default + is `false`. + +- `human_name`: Ordinary name of measure. Used in the full + auto-generated docstring, which begins "A measure type for + \$human_name ...". E.g., the `human_name` for `TruePositive` is `number + of true positives`.
Default is snake-case version of type name, with + underscores replaced by spaces; so `MeanAbsoluteError` becomes "mean + absolute error". + +- `docstring`: An abbreviated docstring, displayed by + `info(measure)`. Fallback uses `human_name` and lists the + `instances`. + """ function metadata_measure(T; name::String="", human_name="", diff --git a/test/measures/continuous.jl b/test/measures/continuous.jl index 5f3a3164..3e645845 100644 --- a/test/measures/continuous.jl +++ b/test/measures/continuous.jl @@ -11,6 +11,7 @@ rng = StableRNG(666899) @test isapprox(mae(yhat, y, w), (1*3 + 2*1 + 4*1 + 3*3)/4) @test isapprox(rms(yhat, y), sqrt(5)) @test isapprox(rms(yhat, y, w), sqrt((1*3^2 + 2*1^2 + 4*1^2 + 3*3^2)/4)) + @test rsq(yhat, y) == -3 @test isapprox(mean(skipinvalid(l1(yhat, y))), 2) @test isapprox(mean(skipinvalid(l1(yhat, y, w))), mae(yhat, y, w)) @test isapprox(mean(skipinvalid(l2(yhat, y))), 5) diff --git a/test/resampling.jl b/test/resampling.jl index 35738ffb..f8c5737f 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -105,7 +105,7 @@ MLJBase.prediction_type(::typeof(dummy_measure_interval)) = :interval end @testset "_feature_dependencies_exist" begin - measures = Any[rms, log_loss, brier_score] + measures = Any[rms, rsq, log_loss, brier_score] @test !MLJBase._feature_dependencies_exist(measures) my_feature_dependent_loss(ŷ, X, y) = sum(abs.(ŷ - y) .* X.penalty)/sum(X.penalty); @@ -359,7 +359,7 @@ end for cache in [true, false] model = Models.DeterministicConstantRegressor() mach = machine(model, X, y, cache=cache) - result = evaluate!(mach, resampling=cv, measure=[rms, rmslp1], + result = evaluate!(mach, resampling=cv, measure=[rms, rsq, rmslp1], acceleration=accel, verbosity=verb) @test result.per_fold[1] ≈ [1/2, 3/4, 1/2, 3/4, 1/2] diff --git a/test/runtests.jl b/test/runtests.jl index 110fbaba..feba4857 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,7 +5,7 @@ using MLJBase if !MLJBase.TESTING error("To test MLJBase, the 
environment variable "* "`TEST_MLJBASE` must be set to `\"true\"`\n"* - "You can do this in the REPL with `ENV[\"TEST_MLJBASE\"]=\"true\"") + "You can do this in the REPL with `ENV[\"TEST_MLJBASE\"]=\"true\"`") end @info "nprocs() = $(nprocs())"