From 1430251e28de0306bfdc35dca44545a05143417d Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 4 Sep 2022 18:50:57 +1000 Subject: [PATCH 1/2] fix multi-target MLJFlux example --- src/utilities.jl | 2 +- test/models.jl | 1 + test/models/Flux.jl | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 1 deletion(-) create mode 100644 test/models/Flux.jl diff --git a/src/utilities.jl b/src/utilities.jl index 53fa49f..e4030f2 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -6,7 +6,7 @@ function plotting_report(fields, scales, history) n_parameters = length(fields) A = Array{Any}(undef, (n_models, n_parameters)) - measurements = Vector{Float64}(undef, n_models) + measurements = Vector{eltype(first(history).measurement)}(undef, n_models) for j in eachindex(history) entry = history[j] diff --git a/test/models.jl b/test/models.jl index ee48538..f6f3028 100644 --- a/test/models.jl +++ b/test/models.jl @@ -13,6 +13,7 @@ import MLJModelInterface include("models/Constant.jl") include("models/DecisionTree.jl") include("models/NearestNeighbors.jl") +include("models/Flux.jl") include("models/MultivariateStats.jl") include("models/Transformers.jl") include("models/foobarmodel.jl") diff --git a/test/models/Flux.jl b/test/models/Flux.jl new file mode 100644 index 0000000..2640ae5 --- /dev/null +++ b/test/models/Flux.jl @@ -0,0 +1,35 @@ +using MLJFlux, MLJ, Flux + +X = randn(100, 2) +Y = X * rand(2, 2) .+ 0.1 * randn.() +XT = MLJ.table(X, names = [:x1, :x2]) +YT = MLJ.table(Y, names = [:y1, :y2]) +act = tanh +nn = Chain( + Dense(2, 5, act), + Dense(5, 5, act), + Dense(5, 5, act), + Dense(5, 2, identity), +) +builder = MLJFlux.@builder nn + +function multi_target(loss) + (x1, x2) -> sum(map(x1, x2) do _x1, _x2 + loss(_x1, _x2) + end) +end +loss = multi_target(l2) +model = MLJFlux.MultitargetNeuralNetworkRegressor(builder = builder; epochs = 10, loss) +r = (MLJ.range(model, :lambda, lower=1e-6, upper=1.0, scale=:linear), 10) +tuning = MLJ.Grid(shuffle = true) +tuned_model = MLJ.TunedModel( + model; + tuning, + resampling = MLJ.CV(nfolds = 5), + range = [r], + measure = loss, + n = 10, + check_measure = false, +) +mach = MLJ.machine(tuned_model, XT, YT) +MLJ.fit!(mach, verbosity=1) From 935a7a3a8cb026ba16143e3f9824b2d8000f27ba Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 4 Sep 2022 18:55:19 +1000 Subject: [PATCH 2/2] update toml and gitignore --- .gitignore | 4 +++- test/Project.toml | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 7be6e0b..74369b5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,11 @@ /Manifest.toml /docs/Manifest.toml +/test/Manifest.toml .ipynb_checkpoints *~ #* *.bu .DS_Store sandbox/ -docs/build \ No newline at end of file +docs/build +.vscode/ diff --git a/test/Project.toml b/test/Project.toml index 5435657..e0cf91a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,9 +5,11 @@ DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LatinHypercubeSampling = "a5e1c1ea-c99a-51d3-a14d-a9a37257b02d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" @@ -26,7 +28,9 @@ ComputationalResources = "0.3" DecisionTree = "0.10" Distances = "0.10" Distributions = "0.25" +Flux = "0.13" MLJBase = "0.20" +MLJFlux = "0.2" MLJModelInterface = "1.3" MultivariateStats = "0.9" NearestNeighbors = "0.4"