From 40ad3e3950ab954634407c5d97fc27684744af21 Mon Sep 17 00:00:00 2001 From: samyhaff Date: Thu, 9 May 2024 15:45:05 +0200 Subject: [PATCH] updated alg init logic in main --- src/bin/main.rs | 99 +++++++++++++++++++------------------------------ testing_job.sh | 2 +- 2 files changed, 40 insertions(+), 61 deletions(-) diff --git a/src/bin/main.rs b/src/bin/main.rs index 4366618..901aa9f 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -17,78 +17,57 @@ use neuroevolution::constants::*; use neuroevolution::neat::*; use neuroevolution::neural_network::NeuralNetworkConfig; -fn main() { - let cli = Cli::parse(); - let mut alg: Algorithm; - let dim = 2; - - let problem = Benchmark::new(cli.problem); - - match cli.algorithm { - AlgorithmType::Oneplusonena => { - // match cli.continuous { - // true => { - // let network = Network::new(cli.neurons, dim); - // alg = Algorithm::ContinuousOneplusoneNA(network); - // } - // false => { - // let network = DiscreteNetwork::new(cli.resolution, cli.neurons, dim); - // alg = Algorithm::DiscreteOneplusoneNA(network); - // } - // } - - let network = DiscreteNetwork::new(cli.resolution, cli.neurons, dim); - alg = Algorithm::DiscreteOneplusoneNA(network); - }, - AlgorithmType::Bna => { - // match cli.continuous { - // true => { - // let vneuron = VNeuron::new(dim); - // alg = Algorithm::ContinuousBNA(vneuron); - // }, - // false => { - // let vneuron = DiscreteVNeuron::new(cli.resolution, dim); - // alg = Algorithm::DiscreteBNA(vneuron); - // } - // } - - let vneuron = DiscreteVNeuron::new(cli.resolution, dim); - alg = Algorithm::DiscreteBNA(vneuron); - } +fn get_algorithm(algorithm_type: AlgorithmType, resolution: usize, neurons: usize, dim: usize, toml_config: &Option) -> Algorithm { + match algorithm_type { AlgorithmType::Neat => { - let config_file_path = match cli.file { - Some(file) => file, - None => panic!("No configuration file provided"), - }; - - let mut neat_config_file = File::open(config_file_path).unwrap(); - let mut toml_config = String::new(); - neat_config_file.read_to_string(&mut toml_config).unwrap(); - let config: Config = toml::from_str(&toml_config).unwrap(); + let toml_config: &str = toml_config.as_deref().unwrap(); + let config: Config = toml::from_str(toml_config).unwrap(); let neat = Neat::new(config); - alg = Algorithm::Neat(neat); + Algorithm::Neat(neat) } AlgorithmType::NeuralNetwork => { - let config_file_path = match cli.file { - Some(file) => file, - None => panic!("No configuration file provided"), - }; - - let mut network_config_file = File::open(config_file_path).unwrap(); - let mut toml_config = String::new(); - network_config_file.read_to_string(&mut toml_config).unwrap(); - + let toml_config: &str = toml_config.as_deref().unwrap(); let network_config: NeuralNetworkConfig = toml::from_str(&toml_config).unwrap(); let network = network_config.to_neural_network(); - alg = Algorithm::NeuralNetwork(network); + Algorithm::NeuralNetwork(network) + } + AlgorithmType::Oneplusonena => { + let network = DiscreteNetwork::new(resolution, neurons, dim); + Algorithm::DiscreteOneplusoneNA(network) + } + AlgorithmType::Bna => { + let vneuron = DiscreteVNeuron::new(resolution, dim); + Algorithm::DiscreteBNA(vneuron) } } +} + +fn main() { + let cli = Cli::parse(); + + let dim = 2; + let problem = Benchmark::new(cli.problem); + let resolution = cli.resolution; + let neurons = cli.neurons; + + let mut alg: Algorithm; + + let toml_config = if let Some(file) = cli.file { + let mut config_file = File::open(file).unwrap(); + let mut toml_config = String::new(); + config_file.read_to_string(&mut toml_config).unwrap(); + Some(toml_config) + } else { + None + }; + + alg = get_algorithm(cli.algorithm, resolution, neurons, dim, &toml_config); if let Some(n_runs) = cli.test_runs { let results = (0..n_runs).into_par_iter().map(|_| { - let mut algorithm = alg.clone(); // TODO initialize with different initial values + let mut algorithm = get_algorithm(cli.algorithm, resolution, neurons, dim, &toml_config); let problem = Benchmark::new(cli.problem); let start = Instant::now(); @@ -100,7 +79,7 @@ fn main() { if let Some(output_path) = cli.output { let mut output_file = File::create(output_path).unwrap(); - writeln!(output_file, "Fitness,Iterations,Elapsed time").unwrap(); + writeln!(output_file, "fitness,iterations,cpu").unwrap(); for (fitness, n_iters, elapsed) in results { writeln!(output_file, "{:.2},{},{:.3}", fitness, n_iters, elapsed).unwrap(); } diff --git a/testing_job.sh b/testing_job.sh index 455c88b..00bd8e7 100755 --- a/testing_job.sh +++ b/testing_job.sh @@ -34,5 +34,5 @@ cd ~/code-master for resolution in $(seq 100 100 1000) do - ./target/release/main oneplusonena half -i 500 -n 1 -r $resolution -t $n_runs -o output/oneplusone_na_half_$resolution.csv + ./target/release/main oneplusonena half -i 500 -n 1 -r $resolution -t $n_runs -o ~/output/oneplusone_na_half_$resolution.csv done