Skip to content

Commit

Permalink
Merge pull request #284 from alan-turing-institute/dev
Browse files Browse the repository at this point in the history
For a 0.10.3 release
  • Loading branch information
ablaom authored May 2, 2020
2 parents aa0086a + eb94be1 commit a0a15b2
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 76 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.13.2"
version = "0.13.3"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
9 changes: 7 additions & 2 deletions src/interface/univariate_finite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@ abstract type NonEuclidean <: Dist.ValueSupport end

const UnivariateFiniteSuper = Dist.Distribution{Dist.Univariate,NonEuclidean}

struct UnivariateFinite{L,U,T<:Real} <: UnivariateFiniteSuper
decoder::CategoricalDecoder{L,U}
# Legend for the type parameters appearing below:
#
# C - original type (eg, Char in `categorical(['a', 'b'])`)
# U - reference type <: Unsigned
# T - raw probability type
# L - subtype of CategoricalValue, eg CategoricalValue{Char,UInt32}
#     (note: `L` is not a type parameter of `UnivariateFinite` itself)

# `UnivariateFinite{C,U,T}` is a subtype of `UnivariateFiniteSuper`, with
# `T` constrained to be a `Real` type.
struct UnivariateFinite{C,U,T<:Real} <: UnivariateFiniteSuper
# a `CategoricalDecoder` parameterized by the original type `C` and the
# reference type `U`; presumably used to convert references back to
# categorical values -- TODO confirm against `CategoricalDecoder`
decoder::CategoricalDecoder{C,U}
# dictionary with keys of type `U` (the reference type) and values of type
# `T` (the raw probability type)
prob_given_class::LittleDict{U,T}
end

Expand Down
151 changes: 95 additions & 56 deletions src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ end
shuffle=nothing,
rng=nothing)
Holdout resampling strategy, for use in `evaluate!`, `evaluate` and in tuning.
Holdout resampling strategy, for use in `evaluate!`, `evaluate` and in
tuning.
train_test_pairs(holdout, rows)
Expand Down Expand Up @@ -104,9 +105,9 @@ of `rows`. The `test` vectors are mutually exclusive and exhaust
`test` vector. With no row pre-shuffling, the order of `rows` is
preserved, in the sense that `rows` coincides precisely with the
concatenation of the `test` vectors, in the order they are
generated. The first `r` test vectors have length `n + 1`, where
`n, r = divrem(length(rows), nfolds)`, and the remaining test vectors
have length `n`.
generated. The first `r` test vectors have length `n + 1`, where `n, r
= divrem(length(rows), nfolds)`, and the remaining test vectors have
length `n`.
Pre-shuffling of `rows` is controlled by `rng` and `shuffle`. If `rng`
is an integer, then the `CV` keyword constructor resets it to
Expand Down Expand Up @@ -163,6 +164,7 @@ function train_test_pairs(cv::CV, rows)
end
end


# ----------------------------------------------------------------
# Cross-validation (stratified; for `Finite` targets)

Expand Down Expand Up @@ -549,54 +551,108 @@ evaluate(model::Supervised, args...; kwargs...) =
# Here `func` is always going to be `get_measurements`; see later

# machines has only one element:
function _evaluate!(func, machines, ::CPU1, nfolds, channel, verbosity)

ret = mapreduce(vcat, 1:nfolds) do k
r = func(machines[1], k)
verbosity < 1 || put!(channel, true);yield()
r
end

verbosity < 1 || put!(channel, false)
# Serial (`CPU1`) evaluation: call `func(machines[1], k)` for each fold
# `k` in `1:nfolds` and combine the per-fold results with `vcat`. Only
# `machines[1]` is used. Unless `verbosity < 1`, a progress bar is
# created up front and advanced by one after each fold.
function _evaluate!(func, machines, ::CPU1, nfolds, verbosity)
    quiet = verbosity < 1

    if !quiet
        p = Progress(
            nfolds,
            dt = 0,
            desc = "Evaluating over $nfolds folds: ",
            barglyphs = BarGlyphs("[=> ]"),
            barlen = 25,
            color = :yellow,
        )
    end

    # evaluate a single fold and, when not quiet, tick the progress bar:
    function measure_fold(k)
        r = func(machines[1], k)
        if !quiet
            p.counter += 1
            ProgressMeter.updateProgress!(p)
        end
        return r
    end

    return mapreduce(measure_fold, vcat, 1:nfolds)
end

# machines has only one element:
function _evaluate!(func, machines, ::CPUProcesses, nfolds, channel, verbosity)
function _evaluate!(func, machines, ::CPUProcesses, nfolds, verbosity) #where T<:AbstractWorkerPool

ret = @distributed vcat for k in 1:nfolds
#verbosity < 1 || update!(p,0)
local ret
@sync begin
channel = RemoteChannel(()->Channel{Bool}(min(1000, nfolds)), 1)
verbosity < 1 || (p = Progress(nfolds,
dt = 0,
desc = "Evaluating over $nfolds folds: ",
barglyphs = BarGlyphs("[=> ]"),
barlen = 25,
color = :yellow))
# printing the progress bar
verbosity < 1 || @async begin
while take!(channel)
next!(p)
end
end


@sync begin
ret = @distributed vcat for k in 1:nfolds
r = func(machines[1], k)
verbosity < 1 || put!(channel, true);
verbosity < 1 || begin
put!(channel, true)
yield()
end
r
end

verbosity < 1 || put!(channel, false)
end
close(channel)
end

return ret
end

@static if VERSION >= v"1.3.0-DEV.573"
# one machine for each thread; cycle through available threads:
function _evaluate!(func, machines, ::CPUThreads, nfolds, channel, verbosity)
function _evaluate!(func, machines, ::CPUThreads, nfolds,verbosity)
n_threads = Threads.nthreads()

if n_threads == 1
return _evaluate!(func, machines, CPU1(), nfolds, verbosity)
end

results = Array{Any, 1}(undef, nfolds)
loc = ReentrantLock()
verbosity < 1 || (p = Progress(nfolds,
dt = 0,
desc = "Evaluating over $nfolds folds: ",
barglyphs = BarGlyphs("[=> ]"),
barlen = 25,
color = :yellow))

if Threads.nthreads() == 1
return _evaluate!(func, machines, CPU1(), nfolds, channel, verbosity)
end
tasks= (Threads.@spawn begin
@sync begin

@sync for parts in Iterators.partition(1:nfolds, max(1,floor(Int, nfolds/n_threads)))
Threads.@spawn begin
for k in parts
id = Threads.threadid()
if !haskey(machines, id)
machines[id] =
machine(machines[1].model, machines[1].args...)
machines[id] =
machine(machines[1].model, machines[1].args...)
end
results[k] = func(machines[id], k)
verbosity < 1 || (begin
lock(loc)do
p.counter +=1
ProgressMeter.updateProgress!(p)
end

end)
end
r = func(machines[id], k)
verbosity < 1 || put!(channel, true); yield()
r
end
for k in 1:nfolds)
end

ret = reduce(vcat, fetch.(tasks))

verbosity < 1 || put!(channel, false)
return ret
end

return reduce(vcat, results)
end

end
Expand Down Expand Up @@ -632,16 +688,6 @@ function evaluate!(mach::Machine, resampling, weights,
# threadid=1.
machines = Dict(1 => mach)

# set up progress meter and a remote channel for communication
verbosity < 1 || (p = Progress(nfolds,
dt = 0,
desc = "Evaluating over $nfolds folds: ",
barglyphs = BarGlyphs("[=> ]"),
barlen = 25,
color = :yellow))

channel = acceleration isa CPU1 ? RemoteChannel(()->Channel{Bool}(1) , 1) : RemoteChannel(()->Channel{Bool}(min(1000, nfolds)), 1)

function get_measurements(mach, k)
train, test = resampling[k]
fit!(mach; rows=train, verbosity=verbosity-1, force=force)
Expand Down Expand Up @@ -669,23 +715,14 @@ function evaluate!(mach::Machine, resampling, weights,
"using $(Threads.nthreads()) threads."
end
end

@sync begin
# printing the progress bar
verbosity < 1 || @async while take!(channel)
next!(p)
end

global measurements_flat =

measurements_flat =
_evaluate!(get_measurements,
machines,
acceleration,
nfolds,
channel, verbosity)
end

verbosity)

close(channel)

# in the following rows=folds, columns=measures:
measurements_matrix = permutedims(
Expand Down Expand Up @@ -721,11 +758,12 @@ function evaluate!(mach::Machine, resampling, weights,
measurement=per_measure,
per_fold=per_fold,
per_observation=per_observation)

return ret

end


# ----------------------------------------------------------------
# Evaluation when `resampling` is a ResamplingStrategy

Expand Down Expand Up @@ -831,12 +869,12 @@ function MLJBase.fit(resampler::Resampler, verbosity::Int, args...)
_process_weights_measures(resampler.weights, resampler.measure,
mach, resampler.operation,
verbosity, resampler.check_measure)


fitresult = evaluate!(mach, resampler.resampling,
weights, nothing, verbosity - 1, resampler.repeats,
measures, resampler.operation,
resampler.acceleration, false)

cache = (mach, deepcopy(resampler.resampling))
report = NamedTuple()

Expand Down Expand Up @@ -898,3 +936,4 @@ function evaluate(machine::AbstractMachine{<:Resampler})
throw(error("$machine has not been trained."))
end
end

18 changes: 1 addition & 17 deletions test/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,28 +32,12 @@ end
end

machines = Dict(1 => machine(ConstantRegressor(), X, y))

channel = RemoteChannel(()->Channel{Bool}(nfolds) , 1)
p = Progress(nfolds, dt=0)

@sync begin

# printing the progress bar
t1 = @async while take!(channel)
next!(p)
end

t2 = @async begin
global result =
MLJBase._evaluate!(func, machines, accel, nfolds, channel, 1)
end
end
result = MLJBase._evaluate!(func, machines, accel, nfolds, 1)

@test result ==
[1:1, 1:1, 1:2, 1:2, 1:3, 1:3, 1:4, 1:4, 1:5, 1:5, 1:6, 1:6]

close(channel)

end


Expand Down

0 comments on commit a0a15b2

Please sign in to comment.