From 246b84d30f8a147ee2a664d7c095aa92a77a86f9 Mon Sep 17 00:00:00 2001 From: aadesh Date: Thu, 30 Apr 2020 12:45:45 +0530 Subject: [PATCH 01/14] adds cross-validation for time series Cross-validation for time series by implementing train_test_pairs and TimeSeriesCV. one can find regarding explanation over. https://robjhyndman.com/hyndsight/tscv/ see #256 --- data/sunspot.csv | 290 +++++++++++++++++++++++++++++++++++++++++++ src/MLJBase.jl | 6 +- src/data/datasets.jl | 7 ++ src/resampling.jl | 62 +++++++-- 4 files changed, 353 insertions(+), 12 deletions(-) create mode 100644 data/sunspot.csv diff --git a/data/sunspot.csv b/data/sunspot.csv new file mode 100644 index 00000000..6222af26 --- /dev/null +++ b/data/sunspot.csv @@ -0,0 +1,290 @@ +sunspot_number +5.0 +11.0 +16.0 +23.0 +36.0 +58.0 +29.0 +20.0 +10.0 +8.0 +3.0 +0.0 +0.0 +2.0 +11.0 +27.0 +47.0 +63.0 +60.0 +39.0 +28.0 +26.0 +22.0 +11.0 +21.0 +40.0 +78.0 +122.0 +103.0 +73.0 +47.0 +35.0 +11.0 +5.0 +16.0 +34.0 +70.0 +81.0 +111.0 +101.0 +73.0 +40.0 +20.0 +16.0 +5.0 +11.0 +22.0 +40.0 +60.0 +80.9 +83.4 +47.7 +47.8 +30.7 +12.2 +9.6 +10.2 +32.4 +47.6 +54.0 +62.9 +85.9 +61.2 +45.1 +36.4 +20.9 +11.4 +37.8 +69.8 +106.1 +100.8 +81.6 +66.5 +34.8 +30.6 +7.0 +19.8 +92.5 +154.4 +125.9 +84.8 +68.1 +38.5 +22.8 +10.2 +24.1 +82.9 +32.0 +130.9 +118.1 +89.9 +66.6 +60.0 +46.9 +41.0 +21.3 +16.0 +6.4 +4.1 +6.8 +14.5 +34.0 +45.0 +43.1 +7.5 +42.2 +28.1 +10.1 +8.1 +2.5 +0.0 +1.4 +5.0 +12.2 +13.9 +35.4 +45.8 +41.1 +30.1 +23.9 +15.6 +6.6 +4.0 +1.8 +8.5 +16.6 +36.3 +49.6 +64.2 +67.0 +70.9 +47.8 +27.5 +8.5 +13.2 +56.9 +121.5 +138.3 +103.2 +85.7 +64.6 +36.7 +24.2 +10.7 +15.0 +40.1 +61.5 +98.5 +124.7 +96.3 +66.6 +64.5 +54.1 +39.0 +20.6 +6.7 +4.3 +22.7 +54.8 +93.8 +95.8 +77.2 +59.1 +44.0 +47.0 +30.5 +16.3 +7.3 +37.6 +74.0 +139.0 +111.2 +101.6 +66.2 +44.7 +17.0 +11.3 +12.4 +3.4 +6.0 +32.3 +54.3 +59.7 +63.7 +63.5 +52.2 +25.4 +13.1 +6.8 +6.3 +7.1 +35.6 +73.0 +85.1 +78.0 +64.0 +41.8 +26.2 +26.7 +12.1 +9.5 +2.7 +5.0 +24.4 +42.0 +63.5 +53.8 +62.0 +48.5 +43.9 +18.6 +5.7 +3.6 +1.4 +9.6 +47.4 +57.1 +103.9 +80.6 +63.6 +37.6 +26.1 +14.2 +5.8 +16.7 +44.3 +63.9 +69.0 +77.8 +64.9 +35.7 +21.2 +11.1 +5.7 +8.7 +36.1 +79.7 +114.4 +109.6 +88.8 +67.8 +47.5 +30.6 +16.3 +9.6 +33.2 +92.6 +151.6 +136.3 +134.7 +83.9 +69.4 +31.5 +13.9 +4.4 +38.0 +141.7 +190.2 +184.8 +159.0 +112.3 +53.9 +37.5 +27.9 +10.2 +15.1 +47.0 +93.8 +105.9 +105.5 +104.5 +66.6 +68.9 +38.0 +34.5 +15.5 +12.6 +27.5 +92.5 +155.4 +154.7 +140.5 +115.9 +66.6 +45.9 +17.9 +13.4 +29.2 +100.2 diff --git a/src/MLJBase.jl b/src/MLJBase.jl index 253f23a4..c6aa4557 100644 --- a/src/MLJBase.jl +++ b/src/MLJBase.jl @@ -130,9 +130,9 @@ export SupervisedTask, UnsupervisedTask, MLJTask, export info_dict # datasets.jl: -export load_boston, load_ames, load_iris, +export load_boston, load_ames, load_iris, load_sunspot, load_reduced_ames, load_crabs, load_smarket, - @load_boston, @load_ames, @load_iris, + @load_boston, @load_ames, @load_iris, @load_reduced_ames, @load_crabs, @load_smarket # machines.jl: @@ -152,7 +152,7 @@ export machines, sources, anonymize!, @from_network, fitresults export @pipeline # resampling.jl: -export ResamplingStrategy, Holdout, CV, StratifiedCV, +export ResamplingStrategy, Holdout, CV, StratifiedCV, TimeSeriesCV, evaluate!, Resampler, PerformanceEvaluation # openml.jl: diff --git a/src/data/datasets.jl b/src/data/datasets.jl index 44ea6fc2..87c24572 100644 --- a/src/data/datasets.jl +++ b/src/data/datasets.jl @@ -154,6 +154,9 @@ const COERCE_SMARKET = ( :Today=>Continuous, :Direction=>Multiclass{2}) +const COERCE_SUNSPOT = ( + (:sunspot_number=>Continuous),) + """ load_dataset(fpath, coercions) @@ -195,6 +198,10 @@ function load_smarket() return merge(data1, (Year=Dates.Date.(data1.Year),)) end +"""Load a well-known sunspots time series with nominal features. +[https://www.sws.bom.gov.au/Educational/2/3/6]](https://www.sws.bom.gov.au/Educational/2/3/6) +""" +load_sunspot() = load_dataset("sunspot.csv", COERCE_SUNSPOT) """Load a well-known public regression dataset with `Continuous` features.""" macro load_boston() diff --git a/src/resampling.jl b/src/resampling.jl index a52b4e58..2ac6cf3c 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -163,6 +163,50 @@ function train_test_pairs(cv::CV, rows) end end +# ---------------------------------------------------------------- +# Cross-validation (TimeSeriesCV) + +""" + TimeSeriesCV = TimeSeriesCV(;nsteps=4) + +Cross-validation resampling strategy, for use in `evaluate!`, +`evaluate` and tuning. + + train_test_pairs(TimeSeriesCV, rows) + +Returns an iterator of `(train, test)` pairs of vectors. (row indices) +The series of `test` sets, each consisting of a single observation & +corresponding `train` set consists only of observations +that occurred prior to the observation that forms the `test` set. +Thus, no future observations can be used in construction. +""" +struct TimeSeriesCV <: ResamplingStrategy + nsteps::Int + function TimeSeriesCV(nsteps) + nsteps >= 1 || error("Must have nsteps >= 1. ") + return new(nsteps) + end +end + +# Constructor with keywords +TimeSeriesCV(;nsteps::Int=4) = + TimeSeriesCV(nsteps) + +function train_test_pairs(TimeSeriesCV::TimeSeriesCV, rows) + + n_obs = length(rows) + nsteps = TimeSeriesCV.nsteps + + n_obs > nsteps || error("Inusufficient data for $nsteps-step cross-validation.\n"* + "Try reducing nsteps. ") + + ret = map(1:n_obs-nsteps) do k + return (rows[1:k], # trainrows + [rows[k + nsteps]]) # testrows + end + return ret +end + # ---------------------------------------------------------------- # Cross-validation (stratified; for `Finite` targets) @@ -550,26 +594,26 @@ evaluate(model::Supervised, args...; kwargs...) = # machines has only one element: function _evaluate!(func, machines, ::CPU1, nfolds, channel, verbosity) - + ret = mapreduce(vcat, 1:nfolds) do k r = func(machines[1], k) - verbosity < 1 || put!(channel, true);yield() + verbosity < 1 || put!(channel, true);yield() r end - + verbosity < 1 || put!(channel, false) return ret end # machines has only one element: function _evaluate!(func, machines, ::CPUProcesses, nfolds, channel, verbosity) - + ret = @distributed vcat for k in 1:nfolds r = func(machines[1], k) verbosity < 1 || put!(channel, true); r end - + verbosity < 1 || put!(channel, false) return ret end @@ -577,10 +621,10 @@ end @static if VERSION >= v"1.3.0-DEV.573" # one machine for each thread; cycle through available threads: function _evaluate!(func, machines, ::CPUThreads, nfolds, channel, verbosity) - + if Threads.nthreads() == 1 return _evaluate!(func, machines, CPU1(), nfolds, channel, verbosity) - end + end tasks= (Threads.@spawn begin id = Threads.threadid() if !haskey(machines, id) @@ -594,7 +638,7 @@ function _evaluate!(func, machines, ::CPUThreads, nfolds, channel, verbosity) for k in 1:nfolds) ret = reduce(vcat, fetch.(tasks)) - + verbosity < 1 || put!(channel, false) return ret end @@ -626,7 +670,7 @@ function evaluate!(mach::Machine, resampling, weights, nfolds = length(resampling) nmeasures = length(measures) - + # For multithreading we need a clone of `mach` for each thread # doing work. These are instantiated as needed except for # threadid=1. From 63df421cb20dd64bc08f00d612b762c4fbd8ee96 Mon Sep 17 00:00:00 2001 From: aadesh Date: Sun, 10 May 2020 01:19:41 +0530 Subject: [PATCH 02/14] requested changes --- src/resampling.jl | 93 +++++++++++++++++++++++++++++----------------- test/resampling.jl | 10 +++++ 2 files changed, 68 insertions(+), 35 deletions(-) diff --git a/src/resampling.jl b/src/resampling.jl index acf560ff..4d9ea398 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -167,44 +167,67 @@ end # ---------------------------------------------------------------- # Cross-validation (TimeSeriesCV) - """ - TimeSeriesCV = TimeSeriesCV(;nsteps=4) +tscv = TimeSeriesCV(;folds=4) Cross-validation resampling strategy, for use in `evaluate!`, -`evaluate` and tuning. +`evaluate` and tuning, when observations are chronological and not +expected to be independent. - train_test_pairs(TimeSeriesCV, rows) +train_test_pairs(tscv, rows) -Returns an iterator of `(train, test)` pairs of vectors. (row indices) -The series of `test` sets, each consisting of a single observation & -corresponding `train` set consists only of observations -that occurred prior to the observation that forms the `test` set. -Thus, no future observations can be used in construction. -""" -struct TimeSeriesCV <: ResamplingStrategy - nsteps::Int - function TimeSeriesCV(nsteps) - nsteps >= 1 || error("Must have nsteps >= 1. ") - return new(nsteps) - end -end +Return an iterator of `(train, test)` pairs of vectors in which `train` progressively +grows in `nfolds` and `test` consists of a `1+nfolds` indexs in `rows`. +Specifically, -# Constructor with keywords -TimeSeriesCV(;nsteps::Int=4) = - TimeSeriesCV(nsteps) +train[k] = rows[1:k] +test[k] = [rows[k + 1, ] +for k in 2:(nfold). -function train_test_pairs(TimeSeriesCV::TimeSeriesCV, rows) +# Examples - n_obs = length(rows) - nsteps = TimeSeriesCV.nsteps - - n_obs > nsteps || error("Inusufficient data for $nsteps-step cross-validation.\n"* - "Try reducing nsteps. ") - - ret = map(1:n_obs-nsteps) do k - return (rows[1:k], # trainrows - [rows[k + nsteps]]) # testrows +```julia-repl +julia> tscv = TimeSeriesCV(nfolds=3) +julia> MLJBase.train_test_pairs(tscv, collect(1:2:15)) +3-element Array{Tuple{Array{Int64,1},Array{Int64,1}},1}: + ([1, 3], [5, 7]) + ([1, 3, 5, 7], [9, 11]) + ([1, 3, 5, 7, 9, 11], [13, 15]) +``` +""" +struct TimeSeriesCV <: ResamplingStrategy + nfolds::Int + function TimeSeriesCV(nfolds) + nfolds > 1 || error("Must have nfolds > 1. ") + return new(nfolds) + end + end + # Constructor with keywords + TimeSeriesCV(;nfolds::Int=4) = + TimeSeriesCV(nfolds) + + function train_test_pairs(tscv::TimeSeriesCV, rows) + if rows != sort(rows) + @warn("TimeSeriesCV being applied to `rows` not in sequence. ") + end + n_obs = length(rows) + nfolds = tscv.nfolds + # number of observations per fold + k = floor(Int, n_obs/nfolds) + k > 0 || error("Inusufficient data for $nfolds-fold cross-validation.\n"* + "Try reducing nfolds. ") + # define the (trainrows, testrows) pairs: + firsts = 1:k:((nfolds)*k + 1) # itr of first `test` rows index + seconds = k:k:((nfolds)*k) + ret = map(2:nfolds+1) do k + f = firsts[k] + if k == nfolds + 1 + s = n_obs + else + s = seconds[k] + end + return (rows[1:f-1], # trainrows + rows[f:s]) # testrows end return ret end @@ -670,11 +693,11 @@ function _evaluate!(func, machines, ::CPUThreads, nfolds,verbosity) desc = "Evaluating over $nfolds folds: ", barglyphs = BarGlyphs("[=> ]"), barlen = 25, - color = :yellow)) - - @sync begin - - @sync for parts in Iterators.partition(1:nfolds, max(1,cld(nfolds, n_threads))) + color = :yellow)) + + @sync begin + + @sync for parts in Iterators.partition(1:nfolds, max(1,cld(nfolds, n_threads))) Threads.@spawn begin for k in parts id = Threads.threadid() diff --git a/test/resampling.jl b/test/resampling.jl index 7ee56d78..69156e7e 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -213,6 +213,16 @@ end @test shuffled.measurement[1] != result.measurement[1] end +@testset "TimeSeriesCV" begin + tscv = TimeSeriesCV(;nfolds=3) + @test MLJBase.train_test_pairs(tscv, collect(1:2:15)) == + [([1, 3], [5, 7]) + ([1, 3, 5, 7], [9, 11]) + ([1, 3, 5, 7, 9, 11], [13, 15])] + @test_logs((:warn, r"TimeSeriesCV being applied to `rows` not in sequence. "), + MLJBase.train_test_pairs(tscv, reverse(1:10))) +end + @testset "stratified_cv" begin # check in explicit example: From a5838389e72143b37c112a23e6416f7234c4cfc7 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Wed, 26 May 2021 16:08:17 +1200 Subject: [PATCH 03/14] add @load_sunspot for consistency --- src/MLJBase.jl | 2 +- src/data/datasets.jl | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/MLJBase.jl b/src/MLJBase.jl index d101ea3b..f7fa8ad5 100644 --- a/src/MLJBase.jl +++ b/src/MLJBase.jl @@ -149,7 +149,7 @@ export average, UnivariateFiniteArray, UnivariateFiniteVector # datasets.jl: export load_boston, load_ames, load_iris, load_sunspot, load_reduced_ames, load_crabs, load_smarket, - @load_boston, @load_ames, @load_iris, + @load_boston, @load_ames, @load_iris, @load_sunspot, @load_reduced_ames, @load_crabs, @load_smarket # sources.jl: diff --git a/src/data/datasets.jl b/src/data/datasets.jl index 1e43078f..8135ba4b 100644 --- a/src/data/datasets.jl +++ b/src/data/datasets.jl @@ -235,6 +235,13 @@ macro load_iris() end end +"""Load a well-known sunspot dataset (single table with one column).""" +macro load_sunspot() + quote + load_sunspot() + end +end + """Load a well-known crab classification dataset with nominal features.""" macro load_crabs() quote From 9b6ebf9af094667c9f606efaee9bed1928d4b5ef Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Wed, 26 May 2021 16:10:36 +1200 Subject: [PATCH 04/14] fix mistake in merge conflict resolution --- src/resampling.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/resampling.jl b/src/resampling.jl index 17f9631f..fd439231 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -1166,11 +1166,6 @@ function MLJModelInterface.fit(resampler::Resampler, verbosity::Int, args...) verbosity, resampler.check_measure) - weights, measures = - _process_weights_measures(resampler.weights, resampler.measure, - mach, resampler.operation, - verbosity, resampler.check_measure) - _acceleration = _process_accel_settings(resampler.acceleration) e = evaluate!(mach, From 6c5c8e00fa985d2f4555f28986e909150d5b8e7d Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Wed, 26 May 2021 16:12:13 +1200 Subject: [PATCH 05/14] fix sunspot doc-string --- src/data/datasets.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/data/datasets.jl b/src/data/datasets.jl index 8135ba4b..42939e02 100644 --- a/src/data/datasets.jl +++ b/src/data/datasets.jl @@ -198,7 +198,7 @@ function load_smarket() return merge(data1, (Year=Dates.Date.(data1.Year),)) end -"""Load a well-known sunspots time series with nominal features. +"""Load a well-known sunspots time series (table with one column). [https://www.sws.bom.gov.au/Educational/2/3/6]](https://www.sws.bom.gov.au/Educational/2/3/6) """ load_sunspot() = load_dataset("sunspot.csv", COERCE_SUNSPOT) @@ -235,7 +235,7 @@ macro load_iris() end end -"""Load a well-known sunspot dataset (single table with one column).""" +"""Load a well-known sunspot timeseries (single table with one column).""" macro load_sunspot() quote load_sunspot() From dd0a9f66cbd479002cd44f3a4410392ed2de5ac5 Mon Sep 17 00:00:00 2001 From: Cameron Bieganek <8310743+CameronBieganek@users.noreply.github.com> Date: Sat, 26 Jun 2021 17:57:58 -0500 Subject: [PATCH 06/14] Update unit tests for TimeSeriesCV. --- test/resampling.jl | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/test/resampling.jl b/test/resampling.jl index c203b167..93974e0d 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -275,13 +275,26 @@ end end @testset "TimeSeriesCV" begin - tscv = TimeSeriesCV(;nfolds=3) - @test MLJBase.train_test_pairs(tscv, collect(1:2:15)) == - [([1, 3], [5, 7]) - ([1, 3, 5, 7], [9, 11]) - ([1, 3, 5, 7, 9, 11], [13, 15])] - @test_logs((:warn, r"TimeSeriesCV being applied to `rows` not in sequence. "), - MLJBase.train_test_pairs(tscv, reverse(1:10))) + tscv = TimeSeriesCV(; nfolds=3) + + pairs = MLJBase.train_test_pairs(tscv, 1:10) + @test pairs = [ + (1:4, [5, 6]), + (1:6, [7, 8]), + (1:8, [9, 10]) + ] + + pairs = MLJBase.train_test_pairs(tscv, 1:2:15) + @test pairs == [ + ([1, 3], [5, 7]) + ([1, 3, 5, 7], [9, 11]) + ([1, 3, 5, 7, 9, 11], [13, 15]) + ] + + @test_logs( + (:warn, "TimeSeriesCV is being applied to `rows` not in sequence."), + MLJBase.train_test_pairs(tscv, reverse(1:10)) + ) end @testset "stratified_cv" begin From f378d8a9ce02ba83e44b7e454ed9d733a40077d8 Mon Sep 17 00:00:00 2001 From: Cameron Bieganek <8310743+CameronBieganek@users.noreply.github.com> Date: Sat, 26 Jun 2021 17:58:25 -0500 Subject: [PATCH 07/14] Update train_test_pairs(::TimeSeriesCV, rows) and change some error()s to ArgumentErrors. --- src/resampling.jl | 103 ++++++++++++++++++++++++--------------------- test/resampling.jl | 11 ++++- 2 files changed, 65 insertions(+), 49 deletions(-) diff --git a/src/resampling.jl b/src/resampling.jl index fd439231..4d1d4b01 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -124,7 +124,7 @@ struct CV <: ResamplingStrategy shuffle::Bool rng::Union{Int,AbstractRNG} function CV(nfolds, shuffle, rng) - nfolds > 1 || error("Must have nfolds > 1. ") + nfolds > 1 || throw(ArgumentError("Must have nfolds > 1. ")) return new(nfolds, shuffle, rng) end end @@ -143,8 +143,13 @@ function train_test_pairs(cv::CV, rows) end n, r = divrem(n_obs, n_folds) - n > 0 || error("Inusufficient data for $n_folds-fold cross-validation.\n"* - "Try reducing nfolds. ") + + if n < 1 + throw(ArgumentError( + """Inusufficient data for $n_folds-fold cross-validation. + Try reducing nfolds. """ + )) + end m = n + 1 # number of observations in first r folds @@ -187,49 +192,50 @@ for k in 2:(nfold). # Examples ```julia-repl -julia> tscv = TimeSeriesCV(nfolds=3) -julia> MLJBase.train_test_pairs(tscv, collect(1:2:15)) -3-element Array{Tuple{Array{Int64,1},Array{Int64,1}},1}: - ([1, 3], [5, 7]) - ([1, 3, 5, 7], [9, 11]) - ([1, 3, 5, 7, 9, 11], [13, 15]) +julia> MLJBase.train_test_pairs(TimeSeriesCV(nfolds=3), 1:10) +3-element Vector{Tuple{UnitRange{Int64}, UnitRange{Int64}}}: + (1:4, 5:6) + (1:6, 7:8) + (1:8, 9:10) ``` """ struct TimeSeriesCV <: ResamplingStrategy - nfolds::Int - function TimeSeriesCV(nfolds) - nfolds > 1 || error("Must have nfolds > 1. ") - return new(nfolds) - end - end - # Constructor with keywords - TimeSeriesCV(;nfolds::Int=4) = - TimeSeriesCV(nfolds) - - function train_test_pairs(tscv::TimeSeriesCV, rows) - if rows != sort(rows) - @warn("TimeSeriesCV being applied to `rows` not in sequence. ") - end - n_obs = length(rows) - nfolds = tscv.nfolds - # number of observations per fold - k = floor(Int, n_obs/nfolds) - k > 0 || error("Inusufficient data for $nfolds-fold cross-validation.\n"* - "Try reducing nfolds. ") - # define the (trainrows, testrows) pairs: - firsts = 1:k:((nfolds)*k + 1) # itr of first `test` rows index - seconds = k:k:((nfolds)*k) - ret = map(2:nfolds+1) do k - f = firsts[k] - if k == nfolds + 1 - s = n_obs - else - s = seconds[k] - end - return (rows[1:f-1], # trainrows - rows[f:s]) # testrows + nfolds::Int + function TimeSeriesCV(nfolds) + nfolds > 0 || throw(ArgumentError("Must have nfolds > 0. ")) + return new(nfolds) + end +end + +# Constructor with keywords +TimeSeriesCV(; nfolds::Int=4) = TimeSeriesCV(nfolds) + +function train_test_pairs(tscv::TimeSeriesCV, rows) + if rows != sort(rows) + @warn "TimeSeriesCV is being applied to `rows` not in sequence. " + end + + n_obs = length(rows) + n_folds = tscv.nfolds + + m, r = divrem(n_obs, n_folds + 1) + + if m < 1 + throw(ArgumentError( + "Inusufficient data for $n_folds-fold " * + "time-series cross-validation.\n" * + "Try reducing nfolds. " + )) + end + + test_folds = Iterators.partition( m+r+1 : n_obs , m) + + return map(test_folds) do test_indices + train_indices = 1 : first(test_indices)-1 + train_rows = rows[train_indices] + test_rows = rows[test_indices] + (train_rows, test_rows) end - return ret end # ---------------------------------------------------------------- @@ -275,7 +281,7 @@ struct StratifiedCV <: ResamplingStrategy shuffle::Bool rng::Union{Int,AbstractRNG} function StratifiedCV(nfolds, shuffle, rng) - nfolds > 1 || error("Must have nfolds > 1. ") + nfolds > 1 || throw(ArgumentError("Must have nfolds > 1. ")) return new(nfolds, shuffle, rng) end end @@ -319,9 +325,12 @@ StratifiedCV(; nfolds::Int=6, shuffle=nothing, rng=nothing) = function train_test_pairs(stratified_cv::StratifiedCV, rows, y) st = scitype(y) - st <: AbstractArray{<:Finite} || - error("Supplied target has scitpye $st but stratified "* - "cross-validation applies only to classification problems. ") + if !(st <: AbstractArray{<:Finite}) + throw(ArgumentError( + "Supplied target has scitpye $st but stratified " * + "cross-validation applies only to classification problems. " + )) + end if stratified_cv.shuffle rows=shuffle!(stratified_cv.rng, collect(rows)) @@ -1072,7 +1081,7 @@ On subsequent calls to `fit!(mach)` new train/test pairs of row indices are only regenerated if `resampling`, `repeats` or `cache` fields of `resampler` have changed. The evolution of an RNG field of `resampler` does *not* constitute a change (`==` for `MLJType` objects -is not sensitive to such changes; see [`is_same_except'](@ref)). +is not sensitive to such changes; see [`is_same_except'](@ref)). If there is single train/test pair, then warm-restart behavior of the wrapped model `resampler.model` will extend to warm-restart behaviour diff --git a/test/resampling.jl b/test/resampling.jl index 93974e0d..a1ad58de 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -70,6 +70,7 @@ end @testset "train test pairs" begin cv = CV(nfolds=5) + pairs = MLJBase.train_test_pairs(cv, 1:24) @test pairs == [ (6:24, 1:5), @@ -78,6 +79,9 @@ end ([1:15..., 21:24...], 16:20), (1:20, 21:24) ] + + # Not enough data for the number of folds. + @test_throws ArgumentError MLJBase.train_test_pairs(cv, 1:4) end @testset "checking measure/model compatibility" begin @@ -292,9 +296,12 @@ end ] @test_logs( - (:warn, "TimeSeriesCV is being applied to `rows` not in sequence."), + (:warn, "TimeSeriesCV is being applied to `rows` not in sequence. "), MLJBase.train_test_pairs(tscv, reverse(1:10)) ) + + # Not enough data for the number of folds. + @test_throws ArgumentError MLJBase.train_test_pairs(TimeSeriesCV(10), 1:8) end @testset "stratified_cv" begin @@ -344,7 +351,7 @@ end @test pairs != pairs_random # wrong target type throws error: - @test_throws(Exception, + @test_throws(ArgumentError, MLJBase.train_test_pairs(scv, rows, CategoricalArrays.unwrap.(y))) From 0c832e87ad35fb14303307dbd7568fe1940f82ef Mon Sep 17 00:00:00 2001 From: Cameron Bieganek <8310743+CameronBieganek@users.noreply.github.com> Date: Sat, 26 Jun 2021 19:44:29 -0500 Subject: [PATCH 08/14] Fix usage of = instead of == in a unit test. --- test/resampling.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/resampling.jl b/test/resampling.jl index a1ad58de..3e55309b 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -282,7 +282,7 @@ end tscv = TimeSeriesCV(; nfolds=3) pairs = MLJBase.train_test_pairs(tscv, 1:10) - @test pairs = [ + @test pairs == [ (1:4, [5, 6]), (1:6, [7, 8]), (1:8, [9, 10]) From e8ddd1ec878e03ff10caf5211e6dc2cc1f113069 Mon Sep 17 00:00:00 2001 From: Cameron Bieganek <8310743+CameronBieganek@users.noreply.github.com> Date: Mon, 28 Jun 2021 19:33:53 -0500 Subject: [PATCH 09/14] Simplify the last few lines of the time-series train_test_pairs(). --- src/resampling.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/resampling.jl b/src/resampling.jl index 4d1d4b01..b659ec61 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -232,9 +232,7 @@ function train_test_pairs(tscv::TimeSeriesCV, rows) return map(test_folds) do test_indices train_indices = 1 : first(test_indices)-1 - train_rows = rows[train_indices] - test_rows = rows[test_indices] - (train_rows, test_rows) + rows[train_indices], rows[test_indices] end end From af14e12c33364fb8a95f61c2a2ccc94432640d5b Mon Sep 17 00:00:00 2001 From: Cameron Bieganek <8310743+CameronBieganek@users.noreply.github.com> Date: Mon, 28 Jun 2021 20:34:52 -0500 Subject: [PATCH 10/14] Update the documentation for time series cross-validation. --- src/data/datasets.jl | 4 ++-- src/resampling.jl | 40 +++++++++++++++++++++++++++++++++------- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/src/data/datasets.jl b/src/data/datasets.jl index 42939e02..03f0713b 100644 --- a/src/data/datasets.jl +++ b/src/data/datasets.jl @@ -198,7 +198,7 @@ function load_smarket() return merge(data1, (Year=Dates.Date.(data1.Year),)) end -"""Load a well-known sunspots time series (table with one column). +"""Load a well-known sunspot time series (table with one column). [https://www.sws.bom.gov.au/Educational/2/3/6]](https://www.sws.bom.gov.au/Educational/2/3/6) """ load_sunspot() = load_dataset("sunspot.csv", COERCE_SUNSPOT) @@ -235,7 +235,7 @@ macro load_iris() end end -"""Load a well-known sunspot timeseries (single table with one column).""" +"""Load a well-known sunspot time series (single table with one column).""" macro load_sunspot() quote load_sunspot() diff --git a/src/resampling.jl b/src/resampling.jl index b659ec61..512d3e49 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -173,7 +173,7 @@ end # ---------------------------------------------------------------- # Cross-validation (TimeSeriesCV) """ -tscv = TimeSeriesCV(;folds=4) +tscv = TimeSeriesCV(; nfolds=4) Cross-validation resampling strategy, for use in `evaluate!`, `evaluate` and tuning, when observations are chronological and not @@ -181,13 +181,17 @@ expected to be independent. train_test_pairs(tscv, rows) -Return an iterator of `(train, test)` pairs of vectors in which `train` progressively -grows in `nfolds` and `test` consists of a `1+nfolds` indexs in `rows`. -Specifically, +Returns an `nfolds`-length iterator of `(train, test)` pairs of +vectors (row indices), where each `train` and `test` is a sub-vector +of `rows`. The rows are partitioned sequentially into `nfolds + 1` +approximately equal length partitions, where the first partition is the first +train set, and the second partition is the first test set. The second +train set consists of the first two partitions, and the second test set +consists of the third partition, and so on for each fold. -train[k] = rows[1:k] -test[k] = [rows[k + 1, ] -for k in 2:(nfold). +The first partition (which is the first train set) has length `n + r`, +where `n, r = divrem(length(rows), nfolds + 1)`, and the remaining partitions +(all of the test folds) have length `n`. # Examples @@ -197,6 +201,28 @@ julia> MLJBase.train_test_pairs(TimeSeriesCV(nfolds=3), 1:10) (1:4, 5:6) (1:6, 7:8) (1:8, 9:10) + +julia> model = (@load RidgeRegressor pkg=MultivariateStats verbosity=0)(); + +julia> data = @load_sunspot; + +julia> X = (lag1 = data.sunspot_number[2:end-1], + lag2 = data.sunspot_number[1:end-2]); + +julia> y = data.sunspot_number[3:end]; + +julia> tscv = TimeSeriesCV(nfolds=3); + +julia> evaluate(model, X, y, resampling=tscv, measure=rmse, verbosity=0) +┌───────────────────────────┬───────────────┬────────────────────┐ +│ _.measure │ _.measurement │ _.per_fold │ +├───────────────────────────┼───────────────┼────────────────────┤ +│ RootMeanSquaredError @753 │ 21.7 │ [25.4, 16.3, 22.4] │ +└───────────────────────────┴───────────────┴────────────────────┘ +_.per_observation = [missing] +_.fitted_params_per_fold = [ … ] +_.report_per_fold = [ … ] +_.train_test_rows = [ … ] ``` """ struct TimeSeriesCV <: ResamplingStrategy From 28a5a2ba543e110bffd16104c2edc2ca285cdf89 Mon Sep 17 00:00:00 2001 From: Cameron Bieganek <8310743+CameronBieganek@users.noreply.github.com> Date: Mon, 28 Jun 2021 20:47:32 -0500 Subject: [PATCH 11/14] Change "sunspot" data set name to "sunspots". --- data/{sunspot.csv => sunspots.csv} | 0 src/MLJBase.jl | 4 ++-- src/data/datasets.jl | 8 ++++---- test/data/datasets.jl | 3 +++ 4 files changed, 9 insertions(+), 6 deletions(-) rename data/{sunspot.csv => sunspots.csv} (100%) diff --git a/data/sunspot.csv b/data/sunspots.csv similarity index 100% rename from data/sunspot.csv rename to data/sunspots.csv diff --git a/src/MLJBase.jl b/src/MLJBase.jl index f7fa8ad5..f71ddd6a 100644 --- a/src/MLJBase.jl +++ b/src/MLJBase.jl @@ -147,9 +147,9 @@ export HANDLE_GIVEN_ID, @more, @constant, @bind, color_on, color_off export average, UnivariateFiniteArray, UnivariateFiniteVector # datasets.jl: -export load_boston, load_ames, load_iris, load_sunspot, +export load_boston, load_ames, load_iris, load_sunspots, load_reduced_ames, load_crabs, load_smarket, - @load_boston, @load_ames, @load_iris, @load_sunspot, + @load_boston, @load_ames, @load_iris, @load_sunspots, @load_reduced_ames, @load_crabs, @load_smarket # sources.jl: diff --git a/src/data/datasets.jl b/src/data/datasets.jl index 03f0713b..9e84b75c 100644 --- a/src/data/datasets.jl +++ b/src/data/datasets.jl @@ -154,7 +154,7 @@ const COERCE_SMARKET = ( :Today=>Continuous, :Direction=>Multiclass{2}) -const COERCE_SUNSPOT = ( +const COERCE_SUNSPOTS = ( (:sunspot_number=>Continuous),) """ @@ -201,7 +201,7 @@ end """Load a well-known sunspot time series (table with one column). [https://www.sws.bom.gov.au/Educational/2/3/6]](https://www.sws.bom.gov.au/Educational/2/3/6) """ -load_sunspot() = load_dataset("sunspot.csv", COERCE_SUNSPOT) +load_sunspots() = load_dataset("sunspots.csv", COERCE_SUNSPOTS) """Load a well-known public regression dataset with `Continuous` features.""" macro load_boston() @@ -236,9 +236,9 @@ macro load_iris() end """Load a well-known sunspot time series (single table with one column).""" -macro load_sunspot() +macro load_sunspots() quote - load_sunspot() + load_sunspots() end end diff --git a/test/data/datasets.jl b/test/data/datasets.jl index d23d6e72..fd897fb2 100644 --- a/test/data/datasets.jl +++ b/test/data/datasets.jl @@ -54,5 +54,8 @@ X, y = @load_smarket @test schema(X).names == (:Year, :Lag1, :Lag2, :Lag3, :Lag4, :Lag5, :Volume, :Today) @test scitype(y) == AbstractVector{Multiclass{2}} +X = @load_sunspots +@test schema(X).names = (:sunspot_number, ) + end # module true From 656909edacbd70e9da0308e014fb1b8699411bcb Mon Sep 17 00:00:00 2001 From: Cameron Bieganek <8310743+CameronBieganek@users.noreply.github.com> Date: Mon, 28 Jun 2021 21:00:41 -0500 Subject: [PATCH 12/14] Fix usage of = instead of == again. --- test/data/datasets.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/data/datasets.jl b/test/data/datasets.jl index fd897fb2..e83a11ab 100644 --- a/test/data/datasets.jl +++ b/test/data/datasets.jl @@ -55,7 +55,7 @@ X, y = @load_smarket @test scitype(y) == AbstractVector{Multiclass{2}} X = @load_sunspots -@test schema(X).names = (:sunspot_number, ) +@test schema(X).names == (:sunspot_number, ) end # module true From bed6a5f4ff4bc6ee433e3c849d5e2a4035c5bc00 Mon Sep 17 00:00:00 2001 From: Cameron Bieganek <8310743+CameronBieganek@users.noreply.github.com> Date: Mon, 28 Jun 2021 21:32:08 -0500 Subject: [PATCH 13/14] Fix typo in TimeSeriesCV docstring. --- src/resampling.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/resampling.jl b/src/resampling.jl index 512d3e49..872404d0 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -204,7 +204,7 @@ julia> MLJBase.train_test_pairs(TimeSeriesCV(nfolds=3), 1:10) julia> model = (@load RidgeRegressor pkg=MultivariateStats verbosity=0)(); -julia> data = @load_sunspot; +julia> data = @load_sunspots; julia> X = (lag1 = data.sunspot_number[2:end-1], lag2 = data.sunspot_number[1:end-2]); From 8dd104435a01e8a4528690713107e3deb60839ab Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 29 Jun 2021 16:16:24 +1200 Subject: [PATCH 14/14] bump 0.18.13 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 98cde088..fd09c85d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJBase" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" authors = ["Anthony D. Blaom "] -version = "0.18.12" +version = "0.18.13" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"