From 9742d44241974dc30a9e42ed3b59dd85fd80e27b Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 23 Jan 2024 17:05:36 +1300 Subject: [PATCH] fix Resampler update bug --- src/resampling.jl | 11 +++++++++-- test/resampling.jl | 20 +++++++++++++++++--- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/resampling.jl b/src/resampling.jl index 85fe9494..a4afc2fa 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -1591,8 +1591,15 @@ function MLJModelInterface.update( mach, e = fitresult train_test_rows = e.train_test_rows - measures = e.measure - operations = e.operation + # since `resampler.model` could have changed, so might the actual measures and + # operations that should be passed to the (low level) `evaluate!`: + measures = _actual_measures(resampler.measure, resampler.model) + operations = _actual_operations( + resampler.operation, + measures, + resampler.model, + verbosity + ) # update the model: mach2 = _update!(mach, resampler.model) diff --git a/test/resampling.jl b/test/resampling.jl index 27850375..e2169ec0 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -1,5 +1,3 @@ -#module TestResampling - using Distributed import ComputationalResources: CPU1, CPUProcesses, CPUThreads using .TestUtilities @@ -876,5 +874,21 @@ end @test contains(printed_evaluations, "N/A") end -#end +@testset_accelerated "issue with Resampler #954" acceleration begin + knn = KNNClassifier() + cnst =DeterministicConstantClassifier() + X, y = make_blobs(10) + + resampler = MLJBase.Resampler( + ;model=knn, + measure=accuracy, + operation=nothing, + acceleration, + ) + mach = machine(resampler, X, y) |> fit! + + resampler.model = cnst + fit!(mach) +end + true