
Commit

added nn config
samyhaff committed Apr 27, 2024
1 parent b1f0c53 commit 600bc13
Showing 6 changed files with 42 additions and 14 deletions.
4 changes: 2 additions & 2 deletions nn_cmaes_xor.toml
@@ -5,9 +5,9 @@ bias_id = 3
[[neurons]]
id = 4
inputs = [1, 2, 3]
-activation = "SIGMOID"
+activation = "sigmoid"

[[neurons]]
id = 5
inputs = [1, 2, 3, 4]
-activation = "SIGMOID"
+activation = "sigmoid"
5 changes: 3 additions & 2 deletions src/bin/main.rs
@@ -71,8 +71,9 @@ fn main() {
network_config_file.read_to_string(&mut toml_config).unwrap();

let network_config: NeuralNetworkConfig = toml::from_str(&toml_config).unwrap();
-println!("{:?}", network_config);
-panic!("Not implemented");
+
+let network = network_config.to_neural_network();
+alg = Algorithm::NeuralNetwork(network);
}
}

2 changes: 1 addition & 1 deletion src/bin/neural_network_cmaes.rs
@@ -15,7 +15,7 @@ fn main() {
);

let problem = Benchmark::new(Problem::Xor);
-let mut alg = Algorithm::NeuralNetworek(network);
+let mut alg = Algorithm::NeuralNetwork(network);
alg.optimize_cmaes(&problem);
println!("Fitness: {:.2}", problem.evaluate(&alg));
}
2 changes: 1 addition & 1 deletion src/gui.rs
@@ -202,7 +202,7 @@ impl State {
self.get_bend_decision_mesh(mesh, bias, angle, 0.1, 1., bend)?;
}

-Algorithm::Neat(_) | Algorithm::NeuralNetworek(_) => {
+Algorithm::Neat(_) | Algorithm::NeuralNetwork(_) => {
match &self.problem {
Benchmark::Classification(points) | Benchmark::SphereClassification(points) => {
// for now, draw outputs
31 changes: 29 additions & 2 deletions src/neural_network.rs
@@ -73,6 +73,33 @@ impl Neuron {
}
}

+impl NeuralNetworkConfig {
+pub fn to_neural_network(&self) -> NeuralNetwork {
+let mut neurons = Vec::new();
+for neuron_config in self.neurons.iter() {
+let activation = match neuron_config.activation.as_str() {
+"sigmoid" => SIGMOID,
+"identity" => IDENTITY,
+_ => panic!("Unknown activation function"),
+};
+
+let mut inputs = Vec::new();
+for input in neuron_config.inputs.iter() {
+inputs.push(NeuronInput::new(*input, None));
+}
+
+neurons.push(Neuron::new(neuron_config.id, inputs, activation));
+}
+
+NeuralNetwork {
+input_ids: self.input_ids.clone(),
+output_ids: self.output_ids.clone(),
+bias_id: self.bias_id,
+neurons,
+}
+}
+}
+
impl NeuralNetwork {
pub fn new(input_ids: Vec<u32>, output_ids: Vec<u32>, bias_id: Option<u32>, neurons: Vec<Neuron>) -> NeuralNetwork {
NeuralNetwork {
@@ -146,13 +173,13 @@ impl NeuralNetwork {

impl NeuroevolutionAlgorithm for NeuralNetwork {
fn optimization_step(&mut self, _problem: &crate::benchmarks::Benchmark) {
-unimplemented!()
+unimplemented!("Optimization step not implemented for NeuralNetwork");
}

fn optimize_cmaes(&mut self, problem: &crate::benchmarks::Benchmark) {
let eval_fn = |x: &DVector<f64>| {
let network = self.to_network(x);
-problem.evaluate(&Algorithm::NeuralNetworek(network))
+problem.evaluate(&Algorithm::NeuralNetwork(network))
};

let initial_connection_weights = self.to_vector();
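
For context, here is a minimal sketch of the config-side structs that toml::from_str::<NeuralNetworkConfig> in main.rs and the new to_neural_network method appear to rely on. The actual definitions are not part of this diff, so the NeuronConfig name, the serde derives, and the field types below are assumptions inferred from nn_cmaes_xor.toml and the fields used above, not the repository's real code.

// Hypothetical sketch only: names and types are inferred, not taken from this commit.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
pub struct NeuronConfig {
    pub id: u32,
    pub inputs: Vec<u32>,
    pub activation: String, // matched as "sigmoid" / "identity" in to_neural_network
}

#[derive(Debug, Deserialize)]
pub struct NeuralNetworkConfig {
    pub input_ids: Vec<u32>,
    pub output_ids: Vec<u32>,
    pub bias_id: Option<u32>,
    pub neurons: Vec<NeuronConfig>,
}

With structs shaped like this, the TOML file above deserializes directly, and to_neural_network only has to map the activation strings and wrap each input id in a NeuronInput before building the NeuralNetwork.
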
12 changes: 6 additions & 6 deletions src/neuroevolution_algorithm.rs
@@ -24,7 +24,7 @@ pub enum Algorithm {
ContinuousBNA(VNeuron),
Neat(Neat),
NeatIndividual(Individual),
-NeuralNetworek(NeuralNetwork),
+NeuralNetwork(NeuralNetwork),
}

impl std::fmt::Display for Algorithm {
@@ -35,7 +35,7 @@ impl std::fmt::Display for Algorithm {
Algorithm::DiscreteBNA(vneuron) => write!(f, "{}", vneuron),
Algorithm::ContinuousBNA(vneuron) => write!(f, "{}", vneuron),
Algorithm::Neat(neat) => write!(f, "{:?}", neat), // TODO: Implement Display for Neat
-Algorithm::NeuralNetworek(network) => write!(f, "{:?}", network),
+Algorithm::NeuralNetwork(network) => write!(f, "{:?}", network),
Algorithm::NeatIndividual(individual) => write!(f, "{:?}", individual),
}
}
@@ -49,7 +49,7 @@ impl NeuroevolutionAlgorithm for Algorithm {
Algorithm::DiscreteBNA(vneuron) => vneuron.optimize(problem, n_iters),
Algorithm::ContinuousBNA(vneuron) => vneuron.optimize(problem, n_iters),
Algorithm::Neat(neat) => neat.optimize(problem, n_iters),
-Algorithm::NeuralNetworek(network) => network.optimize(problem, n_iters),
+Algorithm::NeuralNetwork(network) => network.optimize(problem, n_iters),
Algorithm::NeatIndividual(individual) => individual.optimize(problem, n_iters),
}
}
@@ -61,7 +61,7 @@ impl NeuroevolutionAlgorithm for Algorithm {
Algorithm::DiscreteBNA(vneuron) => vneuron.optimize_cmaes(problem),
Algorithm::ContinuousBNA(vneuron) => vneuron.optimize_cmaes(problem),
Algorithm::Neat(neat) => neat.optimize_cmaes(problem),
-Algorithm::NeuralNetworek(network) => network.optimize_cmaes(problem),
+Algorithm::NeuralNetwork(network) => network.optimize_cmaes(problem),
Algorithm::NeatIndividual(individual) => individual.optimize_cmaes(problem),
}
}
@@ -73,7 +73,7 @@ impl NeuroevolutionAlgorithm for Algorithm {
Algorithm::DiscreteBNA(vneuron) => vneuron.evaluate(input),
Algorithm::ContinuousBNA(vneuron) => vneuron.evaluate(input),
Algorithm::Neat(neat) => neat.evaluate(input),
-Algorithm::NeuralNetworek(network) => network.evaluate(input),
+Algorithm::NeuralNetwork(network) => network.evaluate(input),
Algorithm::NeatIndividual(individual) => individual.evaluate(input),
}
}
@@ -85,7 +85,7 @@ impl NeuroevolutionAlgorithm for Algorithm {
Algorithm::DiscreteBNA(vneuron) => vneuron.optimization_step(problem),
Algorithm::ContinuousBNA(vneuron) => vneuron.optimization_step(problem),
Algorithm::Neat(neat) => neat.optimization_step(problem),
-Algorithm::NeuralNetworek(network) => network.optimization_step(problem),
+Algorithm::NeuralNetwork(network) => network.optimization_step(problem),
Algorithm::NeatIndividual(individual) => individual.optimization_step(problem),
}
}
