From 3f677bfe397c2e503962d96fc3276db6e308cef5 Mon Sep 17 00:00:00 2001 From: samyhaff Date: Sun, 21 Apr 2024 14:53:09 +0200 Subject: [PATCH] added neat config file --- Cargo.lock | 63 ++++++++++++++++++++++++++++---- Cargo.toml | 3 ++ neat_conf.toml | 18 ++++++++++ neat_conf_2.toml | 18 ++++++++++ src/bin/main.rs | 68 +++++++++++++++++++++-------------- src/bin/pole_balancing_gui.rs | 2 +- src/neat.rs | 3 +- 7 files changed, 139 insertions(+), 36 deletions(-) create mode 100644 neat_conf.toml create mode 100644 neat_conf_2.toml diff --git a/Cargo.lock b/Cargo.lock index c89a06f..c4272cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -885,7 +885,7 @@ dependencies = [ "serde", "skeptic", "smart-default", - "toml", + "toml 0.5.11", "typed-arena", "wgpu", "winit", @@ -1660,6 +1660,9 @@ dependencies = [ "ggez", "rand", "rand_distr", + "serde", + "serde_derive", + "toml 0.8.12", ] [[package]] @@ -2345,18 +2348,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.197" +version = "1.0.198" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" +checksum = "9846a40c979031340571da2545a4e5b7c4163bdae79b301d5f86d03979451fcc" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.197" +version = "1.0.198" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" +checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9" dependencies = [ "proc-macro2", "quote", @@ -2374,6 +2377,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_spanned" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb3622f419d1296904700073ea6cc23ad690adbd66f13ea683df73298736f0c1" +dependencies = [ + "serde", +] + [[package]] name = "shlex" version = "1.3.0" @@ -2684,11 +2696,26 @@ dependencies = [ "serde", ] +[[package]] +name = "toml" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9dd1545e8208b4a5af1aa9bbd0b4cf7e9ea08fabc5d0a5c67fcaafa17433aa3" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit 0.22.12", +] + [[package]] name = "toml_datetime" version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1" +dependencies = [ + "serde", +] [[package]] name = "toml_edit" @@ -2698,7 +2725,7 @@ checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ "indexmap 2.2.5", "toml_datetime", - "winnow", + "winnow 0.5.40", ] [[package]] @@ -2709,7 +2736,20 @@ checksum = "6a8534fd7f78b5405e860340ad6575217ce99f38d4d5c8f2442cb5ecb50090e1" dependencies = [ "indexmap 2.2.5", "toml_datetime", - "winnow", + "winnow 0.5.40", +] + +[[package]] +name = "toml_edit" +version = "0.22.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3328d4f68a705b2a4498da1d580585d39a6510f98318a2cec3018a7ec61ddef" +dependencies = [ + "indexmap 2.2.5", + "serde", + "serde_spanned", + "toml_datetime", + "winnow 0.6.6", ] [[package]] @@ -3391,6 +3431,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "winnow" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c976aaaa0e1f90dbb21e9587cdaf1d9679a1cde8875c0d6bd83ab96a208352" +dependencies = [ + "memchr", +] + [[package]] name = "x11-dl" version = "2.21.0" diff --git a/Cargo.toml b/Cargo.toml index 78c75fc..39bbad6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,3 +14,6 @@ cmaes = "0.2.1" ggez = "0.9.3" rand = "0.8.5" rand_distr = "0.4.3" +serde = "1.0.198" +serde_derive = "1.0.198" +toml = "0.8.12" diff --git a/neat_conf.toml b/neat_conf.toml new file mode 100644 index 0000000..2ce0473 --- /dev/null +++ b/neat_conf.toml @@ -0,0 +1,18 @@ +population_size = 1000 +n_inputs = 4 +n_outputs = 1 +weights_mean = 0.0 +weights_stddev = 0.8 +perturbation_stddev = 0.2 +new_weight_probability = 0.1 +enable_probability = 0.25 +survival_threshold = 0.25 +connection_mutation_rate = 0.3 +node_mutation_rate = 0.03 +weight_mutation_rate = 0.8 +similarity_threshold = 15.0 +excess_weight = 1.0 +disjoint_weight = 1.0 +matching_weight = 3.0 +champion_copy_threshold = 4 +stagnation_threshold = 1500 diff --git a/neat_conf_2.toml b/neat_conf_2.toml new file mode 100644 index 0000000..c3124d3 --- /dev/null +++ b/neat_conf_2.toml @@ -0,0 +1,18 @@ +population_size = 150 +n_inputs = 2 +n_outputs = 1 +weights_mean = 0.0 +weights_stddev = 0.8 +perturbation_stddev = 0.2 +new_weight_probability = 0.1 +enable_probability = 0.25 +survival_threshold = 0.25 +connection_mutation_rate = 0.3 +node_mutation_rate = 0.03 +weight_mutation_rate = 0.8 +similarity_threshold = 15.0 +excess_weight = 1.0 +disjoint_weight = 1.0 +matching_weight = 0.3 +champion_copy_threshold = 5 +stagnation_threshold = 1500 diff --git a/src/bin/main.rs b/src/bin/main.rs index 7ba4acd..b8c4ae0 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -1,4 +1,6 @@ use std::fs::File; +use std::io::Read; +use toml; use ggez::*; use clap::Parser; use neuroevolution::cli::*; @@ -9,7 +11,6 @@ use neuroevolution::discrete_network::DiscreteNetwork; use neuroevolution::neuroevolution_algorithm::{NeuroevolutionAlgorithm, Algorithm}; use neuroevolution::benchmarks::*; use neuroevolution::constants::*; -use neuroevolution::gui::*; use neuroevolution::neat::*; fn main() { @@ -51,26 +52,31 @@ fn main() { } } AlgorithmType::Neat => { - let config = Config { - population_size: 150, - n_inputs: 2, - n_outputs: 1, - weights_mean: 0., - weights_stddev: 0.8, - perturbation_stddev: 0.2, - new_weight_probability: 0.1, - enable_probability: 0.25, - survival_threshold: 0.25, - connection_mutation_rate: 0.3, - node_mutation_rate: 0.03, - weight_mutation_rate: 0.8, - similarity_threshold: 15.0, - excess_weight: 1., - disjoint_weight: 1., - matching_weight: 0.3, - champion_copy_threshold: 5, - stagnation_threshold: 1500, - }; + // let config = Config { + // population_size: 150, + // n_inputs: 2, + // n_outputs: 1, + // weights_mean: 0., + // weights_stddev: 0.8, + // perturbation_stddev: 0.2, + // new_weight_probability: 0.1, + // enable_probability: 0.25, + // survival_threshold: 0.25, + // connection_mutation_rate: 0.3, + // node_mutation_rate: 0.03, + // weight_mutation_rate: 0.8, + // similarity_threshold: 15.0, + // excess_weight: 1., + // disjoint_weight: 1., + // matching_weight: 0.3, + // champion_copy_threshold: 5, + // stagnation_threshold: 1500, + // }; + + let mut neat_config_file = File::open("neat_conf.toml").unwrap(); + let mut toml_config = String::new(); + neat_config_file.read_to_string(&mut toml_config).unwrap(); + let config: Config = toml::from_str(&toml_config).unwrap(); let neat = Neat::new(config); alg = Algorithm::Neat(neat); @@ -79,21 +85,29 @@ fn main() { match cli.gui { true => { + let mut conf_file = File::open("gui_conf.toml").unwrap(); + let conf = conf::Conf::from_toml_file(&mut conf_file).unwrap(); + let cb = ContextBuilder::new("Neuroevolution", "Samy Haffoudhi") .default_conf(conf); + let (ctx, event_loop) = cb.build().unwrap(); + match problem { Benchmark::PoleBalancing => { - panic!("Not implemented yet!"); + println!("Evolving algorithm..."); + alg.optimize(&problem, cli.iterations); + println!("Fitness: {}", problem.evaluate(&alg)); + + let pole_balancing_state = neuroevolution::pole_balancing::State::default(); + let state = neuroevolution::pole_balancing_gui::State::new(pole_balancing_state, alg); + event::run(ctx, event_loop, state); } _ => { - let state = State::new(alg, problem, N_ITERATIONS); - let mut conf_file = File::open("gui_conf.toml").unwrap(); - let conf = conf::Conf::from_toml_file(&mut conf_file).unwrap(); - let cb = ContextBuilder::new("Neuroevolution", "Samy Haffoudhi") .default_conf(conf); - let (ctx, event_loop) = cb.build().unwrap(); + let state = neuroevolution::gui::State::new(alg, problem, N_ITERATIONS); event::run(ctx, event_loop, state); } } } + false => { alg.optimize(&problem, cli.iterations); diff --git a/src/bin/pole_balancing_gui.rs b/src/bin/pole_balancing_gui.rs index debd3cf..f9151ab 100644 --- a/src/bin/pole_balancing_gui.rs +++ b/src/bin/pole_balancing_gui.rs @@ -32,7 +32,7 @@ fn main() { let neat = Neat::new(config); let mut alg = Algorithm::Neat(neat); let problem = Benchmark::PoleBalancing; - println!("Optimizing algorithm..."); + println!("Evolving algorithm..."); alg.optimize(&problem, 20); println!("Fitness: {}", problem.evaluate(&alg)); diff --git a/src/neat.rs b/src/neat.rs index 46b0420..1349712 100644 --- a/src/neat.rs +++ b/src/neat.rs @@ -1,5 +1,6 @@ use rand::prelude::*; use rand_distr::Normal; +use serde_derive::Deserialize; use crate::neural_network::*; use crate::neuroevolution_algorithm::*; use crate::benchmarks::Benchmark; @@ -57,7 +58,7 @@ struct History { generation: u32, } -#[derive(Debug)] +#[derive(Debug, Deserialize)] pub struct Config { pub population_size: u32, pub n_inputs: u32,