Skip to content

Commit

Permalink
Merge pull request #510 from alan-turing-institute/dev
Browse files Browse the repository at this point in the history
For a 0.17.1 release
  • Loading branch information
ablaom authored Jan 29, 2021
2 parents 4ccfa15 + 19d7fcf commit 11ef61d
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 14 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.17.0"
version = "0.17.1"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down Expand Up @@ -39,7 +39,7 @@ InvertedIndices = "^1"
JLSO = "^2.1,^2.2"
JSON = "^0.21"
LossFunctions = "0.5, 0.6"
MLJModelInterface = "^0.3.8,^0.4"
MLJModelInterface = "^0.4"
MLJScientificTypes = "^0.4.1"
Missings = "^0.4"
OrderedCollections = "^1.1"
Expand Down
10 changes: 1 addition & 9 deletions src/MLJBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ using StatisticalTraits
for trait in StatisticalTraits.TRAITS
eval(:(export $trait))
end
export implemented_methods # defined here and not in StatisticalTraits

# MLJ model hierarchy
export MLJType, Model, Supervised, Unsupervised,
Expand All @@ -105,15 +106,6 @@ export fit, update, update_data, transform, inverse_transform,
predict_mode, predict_mean, predict_median, predict_joint,
evaluate, clean!

# model traits
export input_scitype, output_scitype, target_scitype,
is_pure_julia, package_name, package_license,
load_path, package_uuid, package_url,
is_wrapper, supports_weights, supports_online,
docstring, name, is_supervised,
prediction_type, implemented_methods, hyperparameters,
hyperparameter_types, hyperparameter_ranges

# data operations
export matrix, int, classes, decoder, table,
nrows, selectrows, selectcols, select
Expand Down
12 changes: 9 additions & 3 deletions src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,9 @@ end

@static if VERSION >= v"1.3.0-DEV.573"


_caches_data(::Machine{M, C}) where {M, C} = C # determines if an instantiated machine caches data

function _evaluate!(func, mach, accel::CPUThreads, nfolds, verbosity)

nthreads = Threads.nthreads()
Expand Down Expand Up @@ -735,9 +738,12 @@ function _evaluate!(func, mach, accel::CPUThreads, nfolds, verbosity)
end
end
clean!(mach.model)
#One tmach for each task:
machines = [mach, [machine(mach.model, mach.args...) for
_ in 2:length(partitions)]...]
#One tmach for each task:
machines = vcat(mach, [
machine(mach.model, mach.args...; cache = _caches_data(mach))
for _ in 2:length(partitions)
])

@sync for (i, parts) in enumerate(partitions)
Threads.@spawn begin
results[i] = mapreduce(vcat, parts) do k
Expand Down

0 comments on commit 11ef61d

Please sign in to comment.