updated alg init logic in main
samyhaff committed May 9, 2024
1 parent 9068d81 commit 40ad3e3
Showing 2 changed files with 40 additions and 61 deletions.
99 changes: 39 additions & 60 deletions src/bin/main.rs
@@ -17,78 +17,57 @@ use neuroevolution::constants::*;
 use neuroevolution::neat::*;
 use neuroevolution::neural_network::NeuralNetworkConfig;
 
-fn main() {
-    let cli = Cli::parse();
-    let mut alg: Algorithm;
-    let dim = 2;
-
-    let problem = Benchmark::new(cli.problem);
-
-    match cli.algorithm {
-        AlgorithmType::Oneplusonena => {
-            // match cli.continuous {
-            //     true => {
-            //         let network = Network::new(cli.neurons, dim);
-            //         alg = Algorithm::ContinuousOneplusoneNA(network);
-            //     }
-            //     false => {
-            //         let network = DiscreteNetwork::new(cli.resolution, cli.neurons, dim);
-            //         alg = Algorithm::DiscreteOneplusoneNA(network);
-            //     }
-            // }
-
-            let network = DiscreteNetwork::new(cli.resolution, cli.neurons, dim);
-            alg = Algorithm::DiscreteOneplusoneNA(network);
-        },
-        AlgorithmType::Bna => {
-            // match cli.continuous {
-            //     true => {
-            //         let vneuron = VNeuron::new(dim);
-            //         alg = Algorithm::ContinuousBNA(vneuron);
-            //     },
-            //     false => {
-            //         let vneuron = DiscreteVNeuron::new(cli.resolution, dim);
-            //         alg = Algorithm::DiscreteBNA(vneuron);
-            //     }
-            // }
-
-            let vneuron = DiscreteVNeuron::new(cli.resolution, dim);
-            alg = Algorithm::DiscreteBNA(vneuron);
-        }
+fn get_algorithm(algorithm_type: AlgorithmType, resolution: usize, neurons: usize, dim: usize, toml_config: &Option<String>) -> Algorithm {
+    match algorithm_type {
         AlgorithmType::Neat => {
-            let config_file_path = match cli.file {
-                Some(file) => file,
-                None => panic!("No configuration file provided"),
-            };
-
-            let mut neat_config_file = File::open(config_file_path).unwrap();
-            let mut toml_config = String::new();
-            neat_config_file.read_to_string(&mut toml_config).unwrap();
-            let config: Config = toml::from_str(&toml_config).unwrap();
+            let toml_config: &str = toml_config.as_deref().unwrap();
+            let config: Config = toml::from_str(toml_config).unwrap();
 
             let neat = Neat::new(config);
-            alg = Algorithm::Neat(neat);
+            Algorithm::Neat(neat)
         }
         AlgorithmType::NeuralNetwork => {
-            let config_file_path = match cli.file {
-                Some(file) => file,
-                None => panic!("No configuration file provided"),
-            };
-
-            let mut network_config_file = File::open(config_file_path).unwrap();
-            let mut toml_config = String::new();
-            network_config_file.read_to_string(&mut toml_config).unwrap();
-
+            let toml_config: &str = toml_config.as_deref().unwrap();
             let network_config: NeuralNetworkConfig = toml::from_str(&toml_config).unwrap();
 
             let network = network_config.to_neural_network();
-            alg = Algorithm::NeuralNetwork(network);
+            Algorithm::NeuralNetwork(network)
         }
+        AlgorithmType::Oneplusonena => {
+            let network = DiscreteNetwork::new(resolution, neurons, dim);
+            Algorithm::DiscreteOneplusoneNA(network)
+        }
+        AlgorithmType::Bna => {
+            let vneuron = DiscreteVNeuron::new(resolution, dim);
+            Algorithm::DiscreteBNA(vneuron)
+        }
     }
+}
+
+fn main() {
+    let cli = Cli::parse();
+
+    let dim = 2;
+    let problem = Benchmark::new(cli.problem);
+    let resolution = cli.resolution;
+    let neurons = cli.neurons;
+
+    let mut alg: Algorithm;
+
+    let toml_config = if let Some(file) = cli.file {
+        let mut config_file = File::open(file).unwrap();
+        let mut toml_config = String::new();
+        config_file.read_to_string(&mut toml_config).unwrap();
+        Some(toml_config)
+    } else {
+        None
+    };
+
+    alg = get_algorithm(cli.algorithm, resolution, neurons, dim, &toml_config);
 
     if let Some(n_runs) = cli.test_runs {
         let results = (0..n_runs).into_par_iter().map(|_| {
-            let mut algorithm = alg.clone(); // TODO initialize with different initial values
+            let mut algorithm = get_algorithm(cli.algorithm, resolution, neurons, dim, &toml_config);
             let problem = Benchmark::new(cli.problem);
 
             let start = Instant::now();
@@ -100,7 +79,7 @@ fn main() {
 
         if let Some(output_path) = cli.output {
            let mut output_file = File::create(output_path).unwrap();
-            writeln!(output_file, "Fitness,Iterations,Elapsed time").unwrap();
+            writeln!(output_file, "fitness,iterations,cpu").unwrap();
            for (fitness, n_iters, elapsed) in results {
                writeln!(output_file, "{:.2},{},{:.3}", fitness, n_iters, elapsed).unwrap();
            }
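
Note on the renamed CSV header: a results file written through the -o option would now look roughly like the sample below. The header comes from the new writeln! line and the row format from the "{:.2},{},{:.3}" writeln! that follows it; the values themselves are illustrative, not from an actual run.

fitness,iterations,cpu
1.00,500,0.412
0.98,500,0.367
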
2 changes: 1 addition & 1 deletion testing_job.sh
@@ -34,5 +34,5 @@ cd ~/code-master
 
 for resolution in $(seq 100 100 1000)
 do
-    ./target/release/main oneplusonena half -i 500 -n 1 -r $resolution -t $n_runs -o output/oneplusone_na_half_$resolution.csv
+    ./target/release/main oneplusonena half -i 500 -n 1 -r $resolution -t $n_runs -o ~/output/oneplusone_na_half_$resolution.csv
 done
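
A note on the changed output path: File::create in main.rs returns an error (and the unwrap panics) if the parent directory of the output path does not exist, so the job assumes ~/output is already present. If that is not guaranteed in the environment, a one-time setup step along these lines (not part of this commit) would be needed before the loop:

mkdir -p ~/output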
