Skip to content

Commit

Permalink
added neural network cmaes algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
samyhaff committed Apr 25, 2024
1 parent 2497c08 commit a8de79a
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 16 deletions.
3 changes: 3 additions & 0 deletions src/bin/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ fn main() {
let neat = Neat::new(config);
alg = Algorithm::Neat(neat);
}
AlgorithmType::NeuralNetwork => {
panic!("Not implemented yet!");
}
}

match cli.gui {
Expand Down
21 changes: 21 additions & 0 deletions src/bin/neural_network_cmaes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use neuroevolution::neural_network::*;
use neuroevolution::benchmarks::*;
use neuroevolution::neuroevolution_algorithm::Algorithm;
use neuroevolution::neuroevolution_algorithm::NeuroevolutionAlgorithm;

/// Demonstrates CMA-ES weight optimization of a small hand-built
/// feed-forward network on the XOR benchmark.
fn main() {
    // Hidden neuron 4: fed by inputs 1, 2 and bias 3; `None` weights are
    // initialized randomly by `NeuronInput::new`.
    let hidden = Neuron::new(
        4,
        vec![
            NeuronInput::new(1, None),
            NeuronInput::new(2, None),
            NeuronInput::new(3, None),
        ],
        SIGMOID,
    );
    // Output neuron 5.
    // NOTE(review): input 3 appears twice in this list — confirm the
    // duplicate connection is intended.
    let output = Neuron::new(
        5,
        vec![
            NeuronInput::new(4, None),
            NeuronInput::new(3, None),
            NeuronInput::new(2, None),
            NeuronInput::new(3, None),
        ],
        SIGMOID,
    );

    // Inputs are nodes 1 and 2, the single output is node 5, bias node is 3.
    let network = NeuralNetwork::new(vec![1, 2], vec![5], Some(3), vec![hidden, output]);

    let problem = Benchmark::new(Problem::Xor);
    let mut alg = Algorithm::NeuralNetworek(network);
    alg.optimize_cmaes(&problem);
    println!("Fitness: {:.2}", problem.evaluate(&alg));
}
2 changes: 1 addition & 1 deletion src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ pub enum AlgorithmType {
Oneplusonena,
Bna,
Neat,
NeuralNetwork,
}

4 changes: 2 additions & 2 deletions src/gui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,12 @@ impl State {
self.get_bend_decision_mesh(mesh, bias, angle, 0.1, 1., bend)?;
}

Algorithm::Neat(neat) => {
Algorithm::Neat(_) | Algorithm::NeuralNetworek(_) => {
match &self.problem {
Benchmark::Classification(points) | Benchmark::SphereClassification(points) => {
// for now, draw outputs
for (point, _) in points {
let output = neat.evaluate(&point);
let output = self.alg.evaluate(&point);
// gradient from red to green
let color = graphics::Color::new(
1.0 - output as f32,
Expand Down
4 changes: 2 additions & 2 deletions src/neat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,14 +344,14 @@ impl Individual {

// Add hidden neurons
for node in self.genome.nodes.iter().filter(|n| n.layer == NodeType::Hidden) {
let inputs = self.genome.connections.iter().filter(|c| c.out_node == node.id && c.enabled).map(|c| NeuronInput::new(c.in_node, c.weight)).collect::<Vec<_>>();
let inputs = self.genome.connections.iter().filter(|c| c.out_node == node.id && c.enabled).map(|c| NeuronInput::new(c.in_node, Some(c.weight))).collect::<Vec<_>>();
let neuron = Neuron::new(node.id, inputs, node.activation);
neurons.push(neuron);
}

// Add output neurons
for output_id in output_ids.iter() {
let inputs = self.genome.connections.iter().filter(|c| c.out_node == *output_id && c.enabled).map(|c| NeuronInput::new(c.in_node, c.weight)).collect::<Vec<_>>();
let inputs = self.genome.connections.iter().filter(|c| c.out_node == *output_id && c.enabled).map(|c| NeuronInput::new(c.in_node, Some(c.weight))).collect::<Vec<_>>();
let neuron = Neuron::new(*output_id, inputs, *output_activations.next().unwrap());
neurons.push(neuron);
}
Expand Down
84 changes: 73 additions & 11 deletions src/neural_network.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@
use rand::prelude::*;
use rand_distr::Normal;
use std::collections::HashMap;
use cmaes::{DVector, fmax};
use crate::neuroevolution_algorithm::{NeuroevolutionAlgorithm, Algorithm};

/// Signature shared by all neuron activation functions.
pub type ActivationFunction = fn(f64) -> f64;

/// Steepened logistic sigmoid; the 4.9 slope coefficient matches the one
/// used in the NEAT literature — TODO confirm the intended source.
pub const SIGMOID: ActivationFunction = |x| 1.0 / (1.0 + f64::exp(-4.9 * x));
/// Pass-through activation, typically used for input neurons.
pub const IDENTITY: ActivationFunction = |x| x;

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct NeuronInput {
input_id: u32,
weight: f64,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Neuron {
id: u32,
inputs: Vec<NeuronInput>,
activation: ActivationFunction,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct NeuralNetwork {
input_ids: Vec<u32>,
output_ids: Vec<u32>,
Expand All @@ -27,10 +31,18 @@ pub struct NeuralNetwork {
}

impl NeuronInput {
pub fn new(input_id: u32, weight: f64) -> NeuronInput {
NeuronInput {
input_id,
weight,
pub fn new(input_id: u32, weight: Option<f64>) -> NeuronInput {
if let Some(weight) = weight {
NeuronInput {
input_id,
weight,
}
} else {
let weights_distribution = Normal::new(0., 0.8).unwrap();
NeuronInput {
input_id,
weight: weights_distribution.sample(&mut thread_rng()),
}
}
}
}
Expand Down Expand Up @@ -90,6 +102,56 @@ impl NeuralNetwork {

outputs
}

/// Flattens every connection weight into a single vector — neuron order,
/// then input order within each neuron. This is the genome layout that
/// CMA-ES optimizes over, and the order `to_network` reads back.
fn to_vector(&self) -> Vec<f64> {
    self.neurons
        .iter()
        .flat_map(|neuron| neuron.inputs.iter().map(|input| input.weight))
        .collect()
}

/// Returns a copy of this network whose connection weights are replaced by
/// the entries of `connection_weights`, consumed in the same order that
/// `to_vector` produces them.
///
/// # Panics
/// Panics with a descriptive message if `connection_weights` has fewer
/// entries than the network has connections (previously this surfaced as a
/// bare out-of-bounds index panic).
fn to_network(&self, connection_weights: &DVector<f64>) -> NeuralNetwork {
    let mut network = self.clone();
    let mut next = 0usize;
    for neuron in network.neurons.iter_mut() {
        for connection in neuron.inputs.iter_mut() {
            assert!(
                next < connection_weights.len(),
                "weight vector too short: network has more connections than {} weights",
                connection_weights.len()
            );
            connection.weight = connection_weights[next];
            next += 1;
        }
    }

    network
}
}

// CMA-ES-based neuroevolution for a fixed-topology `NeuralNetwork`: only the
// connection weights are optimized; the structure never changes.
impl NeuroevolutionAlgorithm for NeuralNetwork {
    // Single-step iteration is not supported for this algorithm; callers must
    // use `optimize` / `optimize_cmaes` instead.
    fn optimization_step(&mut self, _problem: &crate::benchmarks::Benchmark) {
        unimplemented!()
    }

    // Optimizes the flattened weight vector with CMA-ES and writes the best
    // candidate back into `self`.
    fn optimize_cmaes(&mut self, problem: &crate::benchmarks::Benchmark) {
        // Fitness of one candidate weight vector: rebuild the network with
        // those weights and score it on the benchmark (`fmax` maximizes).
        let eval_fn = |x: &DVector<f64>| {
            let network = self.to_network(x);
            problem.evaluate(&Algorithm::NeuralNetworek(network))
        };

        // Current weights seed the search; 0.4 is presumably the initial step
        // size (sigma) argument of `cmaes::fmax` — TODO confirm with crate docs.
        let initial_connection_weights = self.to_vector();
        let solution = fmax(eval_fn, initial_connection_weights, 0.4);
        *self = self.to_network(&solution.point);
    }

    // Runs a forward pass and returns the first output value.
    // NOTE(review): indexes `output[0]` — assumes at least one output neuron
    // (single-output benchmarks); confirm against callers.
    fn evaluate(&self, input: &Vec<f64>) -> f64 {
        let output = self.feed_forward(input);
        output[0]
    }

    // Generic entry point: the iteration count is ignored because CMA-ES runs
    // to its own internal termination criteria.
    fn optimize(&mut self, problem: &crate::benchmarks::Benchmark, _n_iters: u32) {
        self.optimize_cmaes(problem);
    }
}

#[cfg(test)]
Expand All @@ -103,7 +165,7 @@ mod tests {
let neurons = vec![
Neuron::new(1, vec![], IDENTITY),
Neuron::new(2, vec![], IDENTITY),
Neuron::new(3, vec![NeuronInput::new(1, 0.5), NeuronInput::new(2, 0.5)], IDENTITY),
Neuron::new(3, vec![NeuronInput::new(1, Some(0.5)), NeuronInput::new(2, Some(0.5))], IDENTITY),
];
let network = NeuralNetwork::new(input_ids, output_ids, None, neurons);

Expand All @@ -120,8 +182,8 @@ mod tests {
let neurons = vec![
Neuron::new(1, vec![] , IDENTITY),
Neuron::new(2, vec![], IDENTITY),
Neuron::new(3, vec![NeuronInput::new(1, 0.5), NeuronInput::new(2, 0.5)], IDENTITY),
Neuron::new(4, vec![NeuronInput::new(3, 0.5)], IDENTITY),
Neuron::new(3, vec![NeuronInput::new(1, Some(0.5)), NeuronInput::new(2, Some(0.5))], IDENTITY),
Neuron::new(4, vec![NeuronInput::new(3, Some(0.5))], IDENTITY),
];
let network = NeuralNetwork::new(input_ids, output_ids, None, neurons);

Expand All @@ -138,7 +200,7 @@ mod tests {
let neurons = vec![
Neuron::new(1, vec![], IDENTITY),
Neuron::new(2, vec![], IDENTITY),
Neuron::new(3, vec![NeuronInput::new(1, 0.5), NeuronInput::new(2, 0.5), NeuronInput::new(4, 1.)], IDENTITY),
Neuron::new(3, vec![NeuronInput::new(1, Some(0.5)), NeuronInput::new(2, Some(0.5)), NeuronInput::new(4, Some(1.))], IDENTITY),
Neuron::new(4, vec![], IDENTITY),
];
let network = NeuralNetwork::new(input_ids, output_ids, Some(4), neurons);
Expand Down
7 changes: 7 additions & 0 deletions src/neuroevolution_algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::vneuron::VNeuron;
use crate::discrete_vneuron::DiscreteVNeuron;
use crate::benchmarks::Benchmark;
use crate::neat::{Neat, Individual};
use crate::neural_network::NeuralNetwork;

pub trait NeuroevolutionAlgorithm {
fn optimization_step(&mut self, problem: &Benchmark);
Expand All @@ -23,6 +24,7 @@ pub enum Algorithm {
ContinuousBNA(VNeuron),
Neat(Neat),
NeatIndividual(Individual),
NeuralNetworek(NeuralNetwork),
}

impl std::fmt::Display for Algorithm {
Expand All @@ -33,6 +35,7 @@ impl std::fmt::Display for Algorithm {
Algorithm::DiscreteBNA(vneuron) => write!(f, "{}", vneuron),
Algorithm::ContinuousBNA(vneuron) => write!(f, "{}", vneuron),
Algorithm::Neat(neat) => write!(f, "{:?}", neat), // TODO: Implement Display for Neat
Algorithm::NeuralNetworek(network) => write!(f, "{:?}", network),
Algorithm::NeatIndividual(individual) => write!(f, "{:?}", individual),
}
}
Expand All @@ -46,6 +49,7 @@ impl NeuroevolutionAlgorithm for Algorithm {
Algorithm::DiscreteBNA(vneuron) => vneuron.optimize(problem, n_iters),
Algorithm::ContinuousBNA(vneuron) => vneuron.optimize(problem, n_iters),
Algorithm::Neat(neat) => neat.optimize(problem, n_iters),
Algorithm::NeuralNetworek(network) => network.optimize(problem, n_iters),
Algorithm::NeatIndividual(individual) => individual.optimize(problem, n_iters),
}
}
Expand All @@ -57,6 +61,7 @@ impl NeuroevolutionAlgorithm for Algorithm {
Algorithm::DiscreteBNA(vneuron) => vneuron.optimize_cmaes(problem),
Algorithm::ContinuousBNA(vneuron) => vneuron.optimize_cmaes(problem),
Algorithm::Neat(neat) => neat.optimize_cmaes(problem),
Algorithm::NeuralNetworek(network) => network.optimize_cmaes(problem),
Algorithm::NeatIndividual(individual) => individual.optimize_cmaes(problem),
}
}
Expand All @@ -68,6 +73,7 @@ impl NeuroevolutionAlgorithm for Algorithm {
Algorithm::DiscreteBNA(vneuron) => vneuron.evaluate(input),
Algorithm::ContinuousBNA(vneuron) => vneuron.evaluate(input),
Algorithm::Neat(neat) => neat.evaluate(input),
Algorithm::NeuralNetworek(network) => network.evaluate(input),
Algorithm::NeatIndividual(individual) => individual.evaluate(input),
}
}
Expand All @@ -79,6 +85,7 @@ impl NeuroevolutionAlgorithm for Algorithm {
Algorithm::DiscreteBNA(vneuron) => vneuron.optimization_step(problem),
Algorithm::ContinuousBNA(vneuron) => vneuron.optimization_step(problem),
Algorithm::Neat(neat) => neat.optimization_step(problem),
Algorithm::NeuralNetworek(network) => network.optimization_step(problem),
Algorithm::NeatIndividual(individual) => individual.optimization_step(problem),
}
}
Expand Down

0 comments on commit a8de79a

Please sign in to comment.