From 8e40890118a68f1e58563ccbe60a042662b1afcc Mon Sep 17 00:00:00 2001 From: Okon Samuel <39421418+OkonSamuel@users.noreply.github.com> Date: Tue, 26 Jan 2021 14:05:20 +0100 Subject: [PATCH 1/6] add cache option for multi-threaded evaluation --- src/resampling.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/resampling.jl b/src/resampling.jl index 2ce0587e..7339873a 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -736,7 +736,7 @@ function _evaluate!(func, mach, accel::CPUThreads, nfolds, verbosity) end clean!(mach.model) #One tmach for each task: - machines = [mach, [machine(mach.model, mach.args...) for + machines = [mach, [machine(mach.model, mach.args...;mach.cache) for _ in 2:length(partitions)]...] @sync for (i, parts) in enumerate(partitions) Threads.@spawn begin From f749cdb0636f67c91a513d5445550f1311000867 Mon Sep 17 00:00:00 2001 From: OkonSamuel Date: Tue, 26 Jan 2021 15:07:12 +0100 Subject: [PATCH 2/6] add cache option for multi-threaded evaluation --- src/resampling.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/resampling.jl b/src/resampling.jl index 7339873a..44948740 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -735,9 +735,12 @@ function _evaluate!(func, mach, accel::CPUThreads, nfolds, verbosity) end end clean!(mach.model) - #One tmach for each task: - machines = [mach, [machine(mach.model, mach.args...;mach.cache) for - _ in 2:length(partitions)]...] + #One tmach for each task: + machines = vcat(mach, [ + machine(mach.model, mach.args...; cache = mach.cache) + for _ in 2:length(partitions) + ]) + @sync for (i, parts) in enumerate(partitions) Threads.@spawn begin results[i] = mapreduce(vcat, parts) do k From 8d1788982659219d1271a6b5f9bbb387021e8faf Mon Sep 17 00:00:00 2001 From: Okon Samuel <39421418+OkonSamuel@users.noreply.github.com> Date: Tue, 26 Jan 2021 16:53:17 +0100 Subject: [PATCH 3/6] add _caches_data method to fix failing tests --- src/resampling.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/resampling.jl b/src/resampling.jl index 44948740..37ade4ee 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -706,6 +706,9 @@ end @static if VERSION >= v"1.3.0-DEV.573" + +_caches_data(::Machine{M, C}) = C # determines if an instantiated machine caches data + function _evaluate!(func, mach, accel::CPUThreads, nfolds, verbosity) nthreads = Threads.nthreads() @@ -737,7 +740,7 @@ function _evaluate!(func, mach, accel::CPUThreads, nfolds, verbosity) clean!(mach.model) #One tmach for each task: machines = vcat(mach, [ - machine(mach.model, mach.args...; cache = mach.cache) + machine(mach.model, mach.args...; cache = _caches_data(mach)) for _ in 2:length(partitions) ]) From 455c7be4a908152c6479e60ef3aead34290f34a2 Mon Sep 17 00:00:00 2001 From: Okon Samuel <39421418+OkonSamuel@users.noreply.github.com> Date: Tue, 26 Jan 2021 16:59:38 +0100 Subject: [PATCH 4/6] typo fix --- src/resampling.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/resampling.jl b/src/resampling.jl index 37ade4ee..e0677e04 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -707,7 +707,7 @@ end @static if VERSION >= v"1.3.0-DEV.573" -_caches_data(::Machine{M, C}) = C # determines if an instantiated machine caches data +_caches_data(::Machine{M, C}) where {M, C} = C # determines if an instantiated machine caches data function _evaluate!(func, mach, accel::CPUThreads, nfolds, verbosity) From 2685c409a0b329a8698faeb6e10189d7e621126b Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 29 Jan 2021 17:38:40 +1300 Subject: [PATCH 5/6] bump [compat] MLJModelInterface = "^0.4"; rm redundant exports --- Project.toml | 2 +- src/MLJBase.jl | 10 +--------- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index dd959479..8cb7a024 100644 --- a/Project.toml +++ b/Project.toml @@ -39,7 +39,7 @@ InvertedIndices = "^1" JLSO = "^2.1,^2.2" JSON = "^0.21" LossFunctions = "0.5, 0.6" -MLJModelInterface = "^0.3.8,^0.4" +MLJModelInterface = "^0.4" MLJScientificTypes = "^0.4.1" Missings = "^0.4" OrderedCollections = "^1.1" diff --git a/src/MLJBase.jl b/src/MLJBase.jl index ad3bb857..548c0ea6 100644 --- a/src/MLJBase.jl +++ b/src/MLJBase.jl @@ -81,6 +81,7 @@ using StatisticalTraits for trait in StatisticalTraits.TRAITS eval(:(export $trait)) end +export implemented_methods # defined here and not in StatisticalTraits # MLJ model hierarchy export MLJType, Model, Supervised, Unsupervised, @@ -105,15 +106,6 @@ export fit, update, update_data, transform, inverse_transform, predict_mode, predict_mean, predict_median, predict_joint, evaluate, clean! -# model traits -export input_scitype, output_scitype, target_scitype, - is_pure_julia, package_name, package_license, - load_path, package_uuid, package_url, - is_wrapper, supports_weights, supports_online, - docstring, name, is_supervised, - prediction_type, implemented_methods, hyperparameters, - hyperparameter_types, hyperparameter_ranges - # data operations export matrix, int, classes, decoder, table, nrows, selectrows, selectcols, select From a086dadcc7283b4a6b6364205f70342e9e8ffdeb Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 29 Jan 2021 17:41:49 +1300 Subject: [PATCH 6/6] bump 0.17.1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8cb7a024..dad9d227 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJBase" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" authors = ["Anthony D. Blaom "] -version = "0.17.0" +version = "0.17.1" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"