Skip to content

Commit

Permalink
added support for visualizing cmaes output
Browse files Browse the repository at this point in the history
  • Loading branch information
samyhaff committed Jun 8, 2024
1 parent 4f16704 commit 93b5623
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 4 deletions.
33 changes: 33 additions & 0 deletions configs/networks/pole_balancing_5.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
input_ids = [1, 2, 3, 4, 5, 6]
output_ids = [13]
bias_id = 7

[[neurons]]
id = 8
inputs = [1, 2, 3, 4, 5, 6, 7]
activation = "sigmoid"

[[neurons]]
id = 9
inputs = [1, 2, 3, 4, 5, 6, 7]
activation = "sigmoid"

[[neurons]]
id = 10
inputs = [1, 2, 3, 4, 5, 6, 7]
activation = "sigmoid"

[[neurons]]
id = 11
inputs = [1, 2, 3, 4, 5, 6, 7]
activation = "sigmoid"

[[neurons]]
id = 12
inputs = [1, 2, 3, 4, 5, 6, 7]
activation = "sigmoid"

[[neurons]]
id = 13
inputs = [8, 9, 10, 11, 12]
activation = "sigmoid"
38 changes: 38 additions & 0 deletions configs/networks/pole_balancing_6.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
input_ids = [1, 2, 3, 4, 5, 6]
output_ids = [14]
bias_id = 7

[[neurons]]
id = 8
inputs = [1, 2, 3, 4, 5, 6, 7]
activation = "sigmoid"

[[neurons]]
id = 9
inputs = [1, 2, 3, 4, 5, 6, 7]
activation = "sigmoid"

[[neurons]]
id = 10
inputs = [1, 2, 3, 4, 5, 6, 7]
activation = "sigmoid"

[[neurons]]
id = 11
inputs = [1, 2, 3, 4, 5, 6, 7]
activation = "sigmoid"

[[neurons]]
id = 12
inputs = [1, 2, 3, 4, 5, 6, 7]
activation = "sigmoid"

[[neurons]]
id = 13
inputs = [1, 2, 3, 4, 5, 6, 7]
activation = "sigmoid"

[[neurons]]
id = 14
inputs = [8, 9, 10, 11, 12, 13]
activation = "sigmoid"
16 changes: 13 additions & 3 deletions src/gui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,19 @@ impl State {

impl ggez::event::EventHandler<GameError> for State {
fn update(&mut self, _ctx: &mut Context) -> GameResult {
if self.iteration < self.n_iters {
self.alg.optimization_step(&self.problem);
self.iteration += 1;
match &self.alg {
Algorithm::NeuralNetwork(_) => {
if self.iteration == 0 {
let generation = self.alg.optimize_with_early_stopping(&self.problem, self.n_iters, 2e-2, None);
self.iteration = generation;
} else {}
}
_ => {
if self.iteration < self.n_iters {
self.alg.optimization_step(&self.problem);
self.iteration += 1;
}
}
}

Ok(())
Expand Down
2 changes: 1 addition & 1 deletion src/neural_network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use rand::prelude::*;
use serde_derive::Deserialize;
use rand_distr::Normal;
use std::collections::HashMap;
use cmaes::{DVector, CMAESOptions, Mode};
use cmaes::{DVector, CMAESOptions, Mode, CMAES};
use crate::neuroevolution_algorithm::{NeuroevolutionAlgorithm, Algorithm};

pub type ActivationFunction = fn(f64) -> f64;
Expand Down

0 comments on commit 93b5623

Please sign in to comment.