diff --git a/Project.toml b/Project.toml index dd959479..dad9d227 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJBase" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" authors = ["Anthony D. Blaom "] -version = "0.17.0" +version = "0.17.1" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" @@ -39,7 +39,7 @@ InvertedIndices = "^1" JLSO = "^2.1,^2.2" JSON = "^0.21" LossFunctions = "0.5, 0.6" -MLJModelInterface = "^0.3.8,^0.4" +MLJModelInterface = "^0.4" MLJScientificTypes = "^0.4.1" Missings = "^0.4" OrderedCollections = "^1.1" diff --git a/src/MLJBase.jl b/src/MLJBase.jl index ad3bb857..548c0ea6 100644 --- a/src/MLJBase.jl +++ b/src/MLJBase.jl @@ -81,6 +81,7 @@ using StatisticalTraits for trait in StatisticalTraits.TRAITS eval(:(export $trait)) end +export implemented_methods # defined here and not in StatisticalTraits # MLJ model hierarchy export MLJType, Model, Supervised, Unsupervised, @@ -105,15 +106,6 @@ export fit, update, update_data, transform, inverse_transform, predict_mode, predict_mean, predict_median, predict_joint, evaluate, clean! -# model traits -export input_scitype, output_scitype, target_scitype, - is_pure_julia, package_name, package_license, - load_path, package_uuid, package_url, - is_wrapper, supports_weights, supports_online, - docstring, name, is_supervised, - prediction_type, implemented_methods, hyperparameters, - hyperparameter_types, hyperparameter_ranges - # data operations export matrix, int, classes, decoder, table, nrows, selectrows, selectcols, select diff --git a/src/resampling.jl b/src/resampling.jl index 2ce0587e..e0677e04 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -706,6 +706,9 @@ end @static if VERSION >= v"1.3.0-DEV.573" + +_caches_data(::Machine{M, C}) where {M, C} = C # determines if an instantiated machine caches data + function _evaluate!(func, mach, accel::CPUThreads, nfolds, verbosity) nthreads = Threads.nthreads() @@ -735,9 +738,12 @@ function _evaluate!(func, mach, accel::CPUThreads, nfolds, verbosity) end end clean!(mach.model) - #One tmach for each task: - machines = [mach, [machine(mach.model, mach.args...) for - _ in 2:length(partitions)]...] + #One tmach for each task: + machines = vcat(mach, [ + machine(mach.model, mach.args...; cache = _caches_data(mach)) + for _ in 2:length(partitions) + ]) + @sync for (i, parts) in enumerate(partitions) Threads.@spawn begin results[i] = mapreduce(vcat, parts) do k