diff --git a/Project.toml b/Project.toml index 5a22535b..937301ec 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJBase" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" authors = ["Anthony D. Blaom "] -version = "0.13.1" +version = "0.13.2" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" diff --git a/src/composition/networks.jl b/src/composition/networks.jl index 3077bc2c..f7312294 100644 --- a/src/composition/networks.jl +++ b/src/composition/networks.jl @@ -555,7 +555,7 @@ function models(W::AbstractNode) end """ - sources(W::AbstractNode; kind=:any) + sources(N::AbstractNode; kind=:any) A vector of all sources referenced by calls `N()` and `fit!(N)`. These are the sources of the directed acyclic graph associated with the diff --git a/src/operations.jl b/src/operations.jl index b13822ff..6a0c66b1 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -28,13 +28,25 @@ for operation in (:predict, :predict_mean, :predict_mode, :predict_median, if isdefined(machine, :fitresult) || M <: Static return $(operation)(machine.model, machine.fitresult, args...) else - throw(error("$machine has not been trained.")) + error("$machine has not been trained.") end end - $(operation)(machine::Machine; rows=:) = - $(operation)(machine, selectrows(machine.args[1], rows)) - $(operation)(machine::NodalMachine, args::AbstractNode...) = - node($(operation), machine, args...) + function $(operation)(machine::Machine; rows=:) + isempty(machine.args) && + throw(ArgumentError("Attempt to accesss non-existent data "* + "bound to a machine, "* + "probably because machine was "* + "deserialized. Specify data `X` "* + "with `$($operation)(mach, X)`. ")) + return $(operation)(machine, selectrows(machine.args[1], rows)) + end + function $(operation)(machine::NodalMachine, args::AbstractNode...) + length(args) > 0 || + throw(ArgumentError("`args` in `$($operation)(mach, args...)`"* + " cannot be empty if `mach` is a "* + "`NodalMachine`. ")) + return node($(operation), machine, args...) + end end eval(ex) end diff --git a/test/composition/arrows.jl b/test/composition/arrows.jl index 93f06924..7c83f399 100644 --- a/test/composition/arrows.jl +++ b/test/composition/arrows.jl @@ -31,12 +31,12 @@ using Random fit!(ŷ, rows=train) - @test isapprox(rms(ŷ(rows=test), ys(rows=test)), 0.627123, rtol=1e-4) + @test isapprox(rms(ŷ(rows=test), ys(rows=test)), 0.627123, atol=0.07) # shortcut to get and set hyperparameters of a node ẑ[:lambda] = 5.0 fit!(ŷ, rows=train) - @test isapprox(rms(ŷ(rows=test), ys(rows=test)), 0.62699, rtol=1e-4) + @test isapprox(rms(ŷ(rows=test), ys(rows=test)), 0.62699, atol=0.07) end @testset "Auto-source" begin diff --git a/test/operations.jl b/test/operations.jl new file mode 100644 index 00000000..4e61dbf8 --- /dev/null +++ b/test/operations.jl @@ -0,0 +1,21 @@ +module TestOperations + +using Test +using MLJBase +using ..Models + +@testset "errors for deserialized machines" begin + filename = joinpath(@__DIR__, "machine.jlso") + m = machine(filename) + @test_throws ArgumentError predict(m) +end + +@testset "error for operations on nodes" begin + X = source() + m = machine(OneHotEncoder(), X) + @test_throws ArgumentError transform(m) +end + +end + +true diff --git a/test/runtests.jl b/test/runtests.jl index 828e56c9..5e49bcc8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -71,6 +71,10 @@ end VERSION ≥ v"1.3.0-" && @test include("composition/arrows.jl") end +@testset "operations.jl" begin + @test include("operations.jl") +end + @testset "hyperparam" begin @test include("hyperparam/one_dimensional_ranges.jl") @test include("hyperparam/one_dimensional_range_methods.jl")