Skip to content

Commit

Permalink
added support for testing cmaes
Browse files Browse the repository at this point in the history
  • Loading branch information
samyhaff committed May 25, 2024
1 parent 7bbd9a8 commit 1ed2e97
Show file tree
Hide file tree
Showing 11 changed files with 79 additions and 8 deletions.
File renamed without changes.
File renamed without changes.
18 changes: 18 additions & 0 deletions configs/neat_conf_proben1.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
population_size = 300
n_inputs = 9
n_outputs = 1
weights_mean = 0.0
weights_stddev = 0.8
perturbation_stddev = 0.2
new_weight_probability = 0.1
enable_probability = 0.25
survival_threshold = 0.25
connection_mutation_rate = 0.3
node_mutation_rate = 0.03
weight_mutation_rate = 0.8
similarity_threshold = 7.0
excess_weight = 1.0
disjoint_weight = 1.0
matching_weight = 0.3
champion_copy_threshold = 5
stagnation_threshold = 1500
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion src/bin/gui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ fn main() {

let state = State::new(alg, quarter, N_ITERATIONS);

let mut conf_file = File::open("gui_conf.toml").unwrap();
let mut conf_file = File::open("configs/gui_conf.toml").unwrap();
let conf = conf::Conf::from_toml_file(&mut conf_file).unwrap();

let cb = ContextBuilder::new("Neuroevolution", "Samy Haffoudhi") .default_conf(conf);
Expand Down
14 changes: 11 additions & 3 deletions src/bin/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ fn main() {

match cli.gui {
true => {
let mut conf_file = File::open("gui_conf.toml").unwrap();
let mut conf_file = File::open("configs/gui_conf.toml").unwrap();
let conf = conf::Conf::from_toml_file(&mut conf_file).unwrap();
let cb = ContextBuilder::new("Neuroevolution", "Samy Haffoudhi") .default_conf(conf);
let (ctx, event_loop) = cb.build().unwrap();
Expand All @@ -122,8 +122,16 @@ fn main() {
}

false => {
let n_iters = alg.optimize_with_early_stopping(&problem, cli.iterations, cli.error_tol, None);
println!("Iterations: {}\nFitness: {:.2}", n_iters, problem.test(&alg));
match alg {
Algorithm::NeuralNetwork(_) => {
alg.optimize(&problem, iterations);
println!("Fitness: {:.2}", problem.test(&alg));
},
_ => {
let n_iters = alg.optimize_with_early_stopping(&problem, cli.iterations, cli.error_tol, None);
println!("Iterations: {}\nFitness: {:.2}", n_iters, problem.test(&alg));
}
}
}
}
}
2 changes: 1 addition & 1 deletion src/bin/pole_balancing_gui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ fn main() {

let state = State::new(pole_balancing_state, alg);

let mut conf_file = File::open("gui_conf.toml").unwrap();
let mut conf_file = File::open("configs/gui_conf.toml").unwrap();
let conf = conf::Conf::from_toml_file(&mut conf_file).unwrap();

let cb = ContextBuilder::new("Neuroevolution", "Samy Haffoudhi") .default_conf(conf);
Expand Down
39 changes: 36 additions & 3 deletions src/neural_network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use rand::prelude::*;
use serde_derive::Deserialize;
use rand_distr::Normal;
use std::collections::HashMap;
use cmaes::{DVector, fmax};
use cmaes::{DVector, CMAESOptions, Mode};
use crate::neuroevolution_algorithm::{NeuroevolutionAlgorithm, Algorithm};

pub type ActivationFunction = fn(f64) -> f64;
Expand Down Expand Up @@ -176,15 +176,48 @@ impl NeuroevolutionAlgorithm for NeuralNetwork {
unimplemented!("Optimization step not implemented for NeuralNetwork");
}

fn optimize_with_early_stopping(&mut self, problem: &crate::benchmarks::Benchmark, _max_iters: u32, _fitness_tol: f64, _max_stagnation: Option<u32>) -> u32 where Self: Sized {
let eval_fn = |x: &DVector<f64>| {
let network = self.to_network(x);
problem.evaluate(&Algorithm::NeuralNetwork(network))
};

let initial_connection_weights = self.to_vector();

let mut cmaes_state = CMAESOptions::new(initial_connection_weights, 0.4)
.mode(Mode::Maximize)
.build(eval_fn)
.unwrap();

let _ = cmaes_state.run();
let Some(best_individual) = cmaes_state.overall_best_individual() else {
panic!("No best individual found");
};
let generation = cmaes_state.generation() as u32;

*self = self.to_network(&best_individual.point);
generation
}

fn optimize_cmaes(&mut self, problem: &crate::benchmarks::Benchmark) {
let eval_fn = |x: &DVector<f64>| {
let network = self.to_network(x);
problem.evaluate(&Algorithm::NeuralNetwork(network))
};

let initial_connection_weights = self.to_vector();
let solution = fmax(eval_fn, initial_connection_weights, 0.4);
*self = self.to_network(&solution.point);

let mut cmaes_state = CMAESOptions::new(initial_connection_weights, 0.4)
.mode(Mode::Maximize)
.build(eval_fn)
.unwrap();

let _ = cmaes_state.run();
let Some(best_individual) = cmaes_state.overall_best_individual() else {
panic!("No best individual found");
};

*self = self.to_network(&best_individual.point);
}

fn evaluate(&self, input: &Vec<f64>) -> f64 {
Expand Down
12 changes: 12 additions & 0 deletions src/neuroevolution_algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,16 @@ impl NeuroevolutionAlgorithm for Algorithm {
Algorithm::NeatIndividual(individual) => individual.optimization_step(problem),
}
}

fn optimize_with_early_stopping(&mut self, problem: &Benchmark, max_iters: u32, fitness_tol: f64, max_stagnation: Option<u32>) -> u32 where Self: Sized {
match self {
Algorithm::DiscreteOneplusoneNA(network) => network.optimize_with_early_stopping(problem, max_iters, fitness_tol, max_stagnation),
Algorithm::ContinuousOneplusoneNA(network) => network.optimize_with_early_stopping(problem, max_iters, fitness_tol, max_stagnation),
Algorithm::DiscreteBNA(vnetwork) => vnetwork.optimize_with_early_stopping(problem, max_iters, fitness_tol, max_stagnation),
Algorithm::ContinuousBNA(vneuron) => vneuron.optimize_with_early_stopping(problem, max_iters, fitness_tol, max_stagnation),
Algorithm::Neat(neat) => neat.optimize_with_early_stopping(problem, max_iters, fitness_tol, max_stagnation),
Algorithm::NeuralNetwork(network) => network.optimize_with_early_stopping(problem, max_iters, fitness_tol, max_stagnation),
Algorithm::NeatIndividual(individual) => individual.optimize_with_early_stopping(problem, max_iters, fitness_tol, max_stagnation),
}
}
}

0 comments on commit 1ed2e97

Please sign in to comment.