Skip to content

Commit

Permalink
Merge pull request #271 from alan-turing-institute/no-data-machine-pr…
Browse files Browse the repository at this point in the history
…edict

Better errors for operations with no data argument
  • Loading branch information
ablaom authored Apr 28, 2020
2 parents a2afd94 + b1d82d0 commit 9e98055
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 5 deletions.
22 changes: 17 additions & 5 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,25 @@ for operation in (:predict, :predict_mean, :predict_mode, :predict_median,
if isdefined(machine, :fitresult) || M <: Static
return $(operation)(machine.model, machine.fitresult, args...)
else
throw(error("$machine has not been trained."))
error("$machine has not been trained.")
end
end
$(operation)(machine::Machine; rows=:) =
$(operation)(machine, selectrows(machine.args[1], rows))
$(operation)(machine::NodalMachine, args::AbstractNode...) =
node($(operation), machine, args...)
function $(operation)(machine::Machine; rows=:)
isempty(machine.args) &&
throw(ArgumentError("Attempt to accesss non-existent data "*
"bound to a machine, "*
"probably because machine was "*
"deserialized. Specify data `X` "*
"with `$($operation)(mach, X)`. "))
return $(operation)(machine, selectrows(machine.args[1], rows))
end
function $(operation)(machine::NodalMachine, args::AbstractNode...)
length(args) > 0 ||
throw(ArgumentError("`args` in `$($operation)(mach, args...)`"*
" cannot be empty if `mach` is a "*
"`NodalMachine`. "))
return node($(operation), machine, args...)
end
end
eval(ex)
end
Expand Down
21 changes: 21 additions & 0 deletions test/operations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module TestOperations

using Test
using MLJBase
using ..Models

@testset "errors for deserialized machines" begin
filename = joinpath(@__DIR__, "machine.jlso")
m = machine(filename)
@test_throws ArgumentError predict(m)
end

@testset "error for operations on nodes" begin
X = source()
m = machine(OneHotEncoder(), X)
@test_throws ArgumentError transform(m)
end

end

true
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ end
VERSION v"1.3.0-" && @test include("composition/arrows.jl")
end

@testset "operations.jl" begin
@test include("operations.jl")
end

@testset "hyperparam" begin
@test include("hyperparam/one_dimensional_ranges.jl")
@test include("hyperparam/one_dimensional_range_methods.jl")
Expand Down

0 comments on commit 9e98055

Please sign in to comment.