Skip to content

Commit

Permalink
added neat config file
Browse files Browse the repository at this point in the history
  • Loading branch information
samyhaff committed Apr 21, 2024
1 parent 07e6734 commit 3f677bf
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 36 deletions.
63 changes: 56 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ cmaes = "0.2.1"
ggez = "0.9.3"
rand = "0.8.5"
rand_distr = "0.4.3"
serde = "1.0.198"
serde_derive = "1.0.198"
toml = "0.8.12"
18 changes: 18 additions & 0 deletions neat_conf.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
population_size = 1000
n_inputs = 4
n_outputs = 1
weights_mean = 0.0
weights_stddev = 0.8
perturbation_stddev = 0.2
new_weight_probability = 0.1
enable_probability = 0.25
survival_threshold = 0.25
connection_mutation_rate = 0.3
node_mutation_rate = 0.03
weight_mutation_rate = 0.8
similarity_threshold = 15.0
excess_weight = 1.0
disjoint_weight = 1.0
matching_weight = 3.0
champion_copy_threshold = 4
stagnation_threshold = 1500
18 changes: 18 additions & 0 deletions neat_conf_2.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
population_size = 150
n_inputs = 2
n_outputs = 1
weights_mean = 0.0
weights_stddev = 0.8
perturbation_stddev = 0.2
new_weight_probability = 0.1
enable_probability = 0.25
survival_threshold = 0.25
connection_mutation_rate = 0.3
node_mutation_rate = 0.03
weight_mutation_rate = 0.8
similarity_threshold = 15.0
excess_weight = 1.0
disjoint_weight = 1.0
matching_weight = 0.3
champion_copy_threshold = 5
stagnation_threshold = 1500
68 changes: 41 additions & 27 deletions src/bin/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use std::fs::File;
use std::io::Read;
use toml;
use ggez::*;
use clap::Parser;
use neuroevolution::cli::*;
Expand All @@ -9,7 +11,6 @@ use neuroevolution::discrete_network::DiscreteNetwork;
use neuroevolution::neuroevolution_algorithm::{NeuroevolutionAlgorithm, Algorithm};
use neuroevolution::benchmarks::*;
use neuroevolution::constants::*;
use neuroevolution::gui::*;
use neuroevolution::neat::*;

fn main() {
Expand Down Expand Up @@ -51,26 +52,31 @@ fn main() {
}
}
AlgorithmType::Neat => {
let config = Config {
population_size: 150,
n_inputs: 2,
n_outputs: 1,
weights_mean: 0.,
weights_stddev: 0.8,
perturbation_stddev: 0.2,
new_weight_probability: 0.1,
enable_probability: 0.25,
survival_threshold: 0.25,
connection_mutation_rate: 0.3,
node_mutation_rate: 0.03,
weight_mutation_rate: 0.8,
similarity_threshold: 15.0,
excess_weight: 1.,
disjoint_weight: 1.,
matching_weight: 0.3,
champion_copy_threshold: 5,
stagnation_threshold: 1500,
};
// let config = Config {
// population_size: 150,
// n_inputs: 2,
// n_outputs: 1,
// weights_mean: 0.,
// weights_stddev: 0.8,
// perturbation_stddev: 0.2,
// new_weight_probability: 0.1,
// enable_probability: 0.25,
// survival_threshold: 0.25,
// connection_mutation_rate: 0.3,
// node_mutation_rate: 0.03,
// weight_mutation_rate: 0.8,
// similarity_threshold: 15.0,
// excess_weight: 1.,
// disjoint_weight: 1.,
// matching_weight: 0.3,
// champion_copy_threshold: 5,
// stagnation_threshold: 1500,
// };

let mut neat_config_file = File::open("neat_conf.toml").unwrap();
let mut toml_config = String::new();
neat_config_file.read_to_string(&mut toml_config).unwrap();
let config: Config = toml::from_str(&toml_config).unwrap();

let neat = Neat::new(config);
alg = Algorithm::Neat(neat);
Expand All @@ -79,21 +85,29 @@ fn main() {

match cli.gui {
true => {
let mut conf_file = File::open("gui_conf.toml").unwrap();
let conf = conf::Conf::from_toml_file(&mut conf_file).unwrap();
let cb = ContextBuilder::new("Neuroevolution", "Samy Haffoudhi") .default_conf(conf);
let (ctx, event_loop) = cb.build().unwrap();

match problem {
Benchmark::PoleBalancing => {
panic!("Not implemented yet!");
println!("Evolving algorithm...");
alg.optimize(&problem, cli.iterations);
println!("Fitness: {}", problem.evaluate(&alg));

let pole_balancing_state = neuroevolution::pole_balancing::State::default();
let state = neuroevolution::pole_balancing_gui::State::new(pole_balancing_state, alg);
event::run(ctx, event_loop, state);
}

_ => {
let state = State::new(alg, problem, N_ITERATIONS);
let mut conf_file = File::open("gui_conf.toml").unwrap();
let conf = conf::Conf::from_toml_file(&mut conf_file).unwrap();
let cb = ContextBuilder::new("Neuroevolution", "Samy Haffoudhi") .default_conf(conf);
let (ctx, event_loop) = cb.build().unwrap();
let state = neuroevolution::gui::State::new(alg, problem, N_ITERATIONS);
event::run(ctx, event_loop, state);
}
}
}

false => {
alg.optimize(&problem, cli.iterations);

Expand Down
2 changes: 1 addition & 1 deletion src/bin/pole_balancing_gui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ fn main() {
let neat = Neat::new(config);
let mut alg = Algorithm::Neat(neat);
let problem = Benchmark::PoleBalancing;
println!("Optimizing algorithm...");
println!("Evolving algorithm...");
alg.optimize(&problem, 20);
println!("Fitness: {}", problem.evaluate(&alg));

Expand Down
3 changes: 2 additions & 1 deletion src/neat.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use rand::prelude::*;
use rand_distr::Normal;
use serde_derive::Deserialize;
use crate::neural_network::*;
use crate::neuroevolution_algorithm::*;
use crate::benchmarks::Benchmark;
Expand Down Expand Up @@ -57,7 +58,7 @@ struct History {
generation: u32,
}

#[derive(Debug)]
#[derive(Debug, Deserialize)]
pub struct Config {
pub population_size: u32,
pub n_inputs: u32,
Expand Down

0 comments on commit 3f677bf

Please sign in to comment.