From 93b5623bacb275f661e0e03b1e0ed53c50c8ca90 Mon Sep 17 00:00:00 2001 From: samyhaff Date: Sat, 8 Jun 2024 18:07:43 +0200 Subject: [PATCH] added support for visualizing cmaes output --- configs/networks/pole_balancing_5.toml | 33 ++++++++++++++++++++++ configs/networks/pole_balancing_6.toml | 38 ++++++++++++++++++++++++++ src/gui.rs | 16 +++++++++-- src/neural_network.rs | 2 +- 4 files changed, 85 insertions(+), 4 deletions(-) create mode 100644 configs/networks/pole_balancing_5.toml create mode 100644 configs/networks/pole_balancing_6.toml diff --git a/configs/networks/pole_balancing_5.toml b/configs/networks/pole_balancing_5.toml new file mode 100644 index 0000000..cf60b27 --- /dev/null +++ b/configs/networks/pole_balancing_5.toml @@ -0,0 +1,33 @@ +input_ids = [1, 2, 3, 4, 5, 6] +output_ids = [13] +bias_id = 7 + +[[neurons]] +id = 8 +inputs = [1, 2, 3, 4, 5, 6, 7] +activation = "sigmoid" + +[[neurons]] +id = 9 +inputs = [1, 2, 3, 4, 5, 6, 7] +activation = "sigmoid" + +[[neurons]] +id = 10 +inputs = [1, 2, 3, 4, 5, 6, 7] +activation = "sigmoid" + +[[neurons]] +id = 11 +inputs = [1, 2, 3, 4, 5, 6, 7] +activation = "sigmoid" + +[[neurons]] +id = 12 +inputs = [1, 2, 3, 4, 5, 6, 7] +activation = "sigmoid" + +[[neurons]] +id = 13 +inputs = [8, 9, 10, 11, 12] +activation = "sigmoid" diff --git a/configs/networks/pole_balancing_6.toml b/configs/networks/pole_balancing_6.toml new file mode 100644 index 0000000..e9d132f --- /dev/null +++ b/configs/networks/pole_balancing_6.toml @@ -0,0 +1,38 @@ +input_ids = [1, 2, 3, 4, 5, 6] +output_ids = [14] +bias_id = 7 + +[[neurons]] +id = 8 +inputs = [1, 2, 3, 4, 5, 6, 7] +activation = "sigmoid" + +[[neurons]] +id = 9 +inputs = [1, 2, 3, 4, 5, 6, 7] +activation = "sigmoid" + +[[neurons]] +id = 10 +inputs = [1, 2, 3, 4, 5, 6, 7] +activation = "sigmoid" + +[[neurons]] +id = 11 +inputs = [1, 2, 3, 4, 5, 6, 7] +activation = "sigmoid" + +[[neurons]] +id = 12 +inputs = [1, 2, 3, 4, 5, 6, 7] +activation = "sigmoid" + +[[neurons]] +id = 13 +inputs = [1, 2, 3, 4, 5, 6, 7] +activation = "sigmoid" + +[[neurons]] +id = 14 +inputs = [8, 9, 10, 11, 12, 13] +activation = "sigmoid" diff --git a/src/gui.rs b/src/gui.rs index 813df52..f07dadc 100644 --- a/src/gui.rs +++ b/src/gui.rs @@ -271,9 +271,19 @@ impl State { impl ggez::event::EventHandler for State { fn update(&mut self, _ctx: &mut Context) -> GameResult { - if self.iteration < self.n_iters { - self.alg.optimization_step(&self.problem); - self.iteration += 1; + match &self.alg { + Algorithm::NeuralNetwork(_) => { + if self.iteration == 0 { + let generation = self.alg.optimize_with_early_stopping(&self.problem, self.n_iters, 2e-2, None); + self.iteration = generation; + } else {} + } + _ => { + if self.iteration < self.n_iters { + self.alg.optimization_step(&self.problem); + self.iteration += 1; + } + } } Ok(()) diff --git a/src/neural_network.rs b/src/neural_network.rs index f2688b5..05f6a0b 100644 --- a/src/neural_network.rs +++ b/src/neural_network.rs @@ -2,7 +2,7 @@ use rand::prelude::*; use serde_derive::Deserialize; use rand_distr::Normal; use std::collections::HashMap; -use cmaes::{DVector, CMAESOptions, Mode}; +use cmaes::{DVector, CMAESOptions, Mode, CMAES}; use crate::neuroevolution_algorithm::{NeuroevolutionAlgorithm, Algorithm}; pub type ActivationFunction = fn(f64) -> f64;