Skip to content

Commit

Permalink
added activation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
samyhaff committed Mar 29, 2024
1 parent 983185e commit 38642af
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 39 deletions.
79 changes: 50 additions & 29 deletions src/neat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ enum NodeType {
struct NodeGene {
id: u32,
layer: NodeType,
activation: ActivationFunction,
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -56,8 +57,8 @@ struct Neat {
}

impl NodeGene {
fn new(id: u32, layer: NodeType) -> NodeGene {
NodeGene { id, layer }
fn new(id: u32, layer: NodeType, activation: ActivationFunction) -> NodeGene {
NodeGene { id, layer, activation }
}
}

Expand Down Expand Up @@ -119,7 +120,7 @@ impl Individual {
}

fn mutate_add_node(&mut self, history: &mut History) {
let new_node = NodeGene::new(history.nodes_nb + 1, NodeType::Hidden);
let new_node = NodeGene::new(history.nodes_nb + 1, NodeType::Hidden, SIGMOID ); // TODO get from config
history.nodes_nb += 1;
self.genome.add_node(new_node.clone());

Expand Down Expand Up @@ -233,28 +234,30 @@ impl Individual {
}

fn to_neural_network(&self) -> NeuralNetwork {
let input_ids = self.genome.nodes.iter().filter(|n| n.layer == NodeType::Input).map(|n| n.id).collect::<Vec<_>>();
let output_ids = self.genome.nodes.iter().filter(|n| n.layer == NodeType::Output).map(|n| n.id).collect::<Vec<_>>();
let (input_ids, input_activations): (Vec<u32>, Vec<ActivationFunction>) = self.genome.nodes.iter().filter(|n| n.layer == NodeType::Input).map(|n| (n.id, n.activation)).unzip();
let (output_ids, output_activations): (Vec<u32>, Vec<ActivationFunction>) = self.genome.nodes.iter().filter(|n| n.layer == NodeType::Output).map(|n| (n.id, n.activation)).unzip();
let mut input_activations = input_activations.iter();
let mut output_activations = output_activations.iter();

let mut neurons = Vec::new();

// Add input neurons
for input_id in input_ids.iter() {
let neuron = Neuron::new(*input_id, Vec::new());
let neuron = Neuron::new(*input_id, Vec::new(), *input_activations.next().unwrap());
neurons.push(neuron);
}

// Add hidden neurons
for node in self.genome.nodes.iter().filter(|n| n.layer == NodeType::Hidden) {
let inputs = self.genome.connections.iter().filter(|c| c.out_node == node.id && c.enabled).map(|c| NeuronInput::new(c.in_node, c.weight)).collect::<Vec<_>>();
let neuron = Neuron::new(node.id, inputs);
let neuron = Neuron::new(node.id, inputs, node.activation);
neurons.push(neuron);
}

// Add output neurons
for output_id in output_ids.iter() {
let inputs = self.genome.connections.iter().filter(|c| c.out_node == *output_id && c.enabled).map(|c| NeuronInput::new(c.in_node, c.weight)).collect::<Vec<_>>();
let neuron = Neuron::new(*output_id, inputs);
let neuron = Neuron::new(*output_id, inputs, *output_activations.next().unwrap());
neurons.push(neuron);
}

Expand All @@ -266,15 +269,14 @@ impl Individual {

// Add input and output nodes
for i in 1..=n_inputs {
let node = NodeGene::new(i, NodeType::Input);
let node = NodeGene::new(i, NodeType::Input, IDENTITY);
genome.add_node(node);
}
for i in 1..=n_outputs {
let node = NodeGene::new(i + n_inputs, NodeType::Output);
let node = NodeGene::new(i + n_inputs, NodeType::Output, SIGMOID); // TODO get from config
genome.add_node(node);
}


// Fully connect input nodes to output nodes
let normal = Normal::new(0.0, 1.0).unwrap();
for i in 1..=n_inputs {
Expand Down Expand Up @@ -317,12 +319,12 @@ mod tests {

#[test]
fn test_crossover() {
let node1 = NodeGene::new(1, NodeType::Input);
let node2 = NodeGene::new(2, NodeType::Input);
let node3 = NodeGene::new(3, NodeType::Input);
let node4 = NodeGene::new(4, NodeType::Output);
let node5 = NodeGene::new(5, NodeType::Hidden);
let node6 = NodeGene::new(6, NodeType::Hidden);
let node1 = NodeGene::new(1, NodeType::Input, SIGMOID);
let node2 = NodeGene::new(2, NodeType::Input, SIGMOID);
let node3 = NodeGene::new(3, NodeType::Input, SIGMOID);
let node4 = NodeGene::new(4, NodeType::Output, SIGMOID);
let node5 = NodeGene::new(5, NodeType::Hidden, SIGMOID);
let node6 = NodeGene::new(6, NodeType::Hidden, SIGMOID);

let conn_1_4 = ConnectionGene::new(1, 4, 1., true, 1);
let conn_2_4 = ConnectionGene::new(2, 4, 1., false, 2);
Expand Down Expand Up @@ -384,8 +386,8 @@ mod tests {

#[test]
fn test_mutate_add_node() {
let node1 = NodeGene::new(1, NodeType::Input);
let node2 = NodeGene::new(2, NodeType::Output);
let node1 = NodeGene::new(1, NodeType::Input, SIGMOID);
let node2 = NodeGene::new(2, NodeType::Output, SIGMOID);
let connection = ConnectionGene::new(1, 2, 0.5, true, 1);
let mut genome = Genome::new();
genome.add_node(node1);
Expand Down Expand Up @@ -414,8 +416,8 @@ mod tests {

#[test]
fn test_mutate_add_connection() {
let node1 = NodeGene::new(1, NodeType::Input);
let node2 = NodeGene::new(2, NodeType::Output);
let node1 = NodeGene::new(1, NodeType::Input, SIGMOID);
let node2 = NodeGene::new(2, NodeType::Output, SIGMOID);
let mut genome = Genome::new();
genome.add_node(node1);
genome.add_node(node2);
Expand All @@ -435,8 +437,8 @@ mod tests {

#[test]
fn test_mutate_add_connection_already_connected_nodes() {
let node1 = NodeGene::new(1, NodeType::Input);
let node2 = NodeGene::new(2, NodeType::Output);
let node1 = NodeGene::new(1, NodeType::Input, SIGMOID);
let node2 = NodeGene::new(2, NodeType::Output, SIGMOID);
let connection = ConnectionGene::new(1, 2, 0., true, 1);
let mut genome = Genome::new();
genome.add_node(node1);
Expand All @@ -454,9 +456,9 @@ mod tests {

#[test]
fn test_network_conversion() {
let node1 = NodeGene::new(1, NodeType::Input);
let node2 = NodeGene::new(2, NodeType::Input);
let node3 = NodeGene::new(3, NodeType::Output);
let node1 = NodeGene::new(1, NodeType::Input, IDENTITY);
let node2 = NodeGene::new(2, NodeType::Input, IDENTITY);
let node3 = NodeGene::new(3, NodeType::Output, IDENTITY);

let conn_1_3 = ConnectionGene::new(1, 3, 0.5, true, 1);
let conn_2_3 = ConnectionGene::new(2, 3, 0.5, true, 2);
Expand All @@ -479,9 +481,9 @@ mod tests {

#[test]
fn test_network_conversion_with_disabled_connection() {
let node1 = NodeGene::new(1, NodeType::Input);
let node2 = NodeGene::new(2, NodeType::Input);
let node3 = NodeGene::new(3, NodeType::Output);
let node1 = NodeGene::new(1, NodeType::Input, IDENTITY);
let node2 = NodeGene::new(2, NodeType::Input, IDENTITY);
let node3 = NodeGene::new(3, NodeType::Output, IDENTITY);

let conn_1_3 = ConnectionGene::new(1, 3, 0.5, true, 1);
let conn_2_3 = ConnectionGene::new(2, 3, 0.5, false, 2);
Expand Down Expand Up @@ -525,4 +527,23 @@ mod tests {
assert_eq!(node_ids, vec![1, 2, 3, 4, 5]);
assert_eq!(innovation_ids, vec![1, 2, 3, 4, 5, 6]);
}

#[test]
fn test_activation() {
let input_node = NodeGene::new(1, NodeType::Input, IDENTITY);
let output_node = NodeGene::new(2, NodeType::Output, SIGMOID);
let connection = ConnectionGene::new(1, 2, 1., true, 1);

let mut genome = Genome::new();
genome.add_node(input_node);
genome.add_node(output_node);
genome.add_connection(connection);

let individual = Individual::new(genome);
let network = individual.to_neural_network();

let inputs = vec![0.];
let outputs = network.feed_forward(inputs);
assert_eq!(outputs, vec![0.5]);
}
}
31 changes: 21 additions & 10 deletions src/neural_network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
//
use std::collections::HashMap;

pub type ActivationFunction = fn(f32) -> f32;

pub const SIGMOID: ActivationFunction = |x| 1. / (1. + (-x).exp());
pub const IDENTITY: ActivationFunction = |x| x;

pub struct NeuronInput {
input_id: u32,
weight: f32,
Expand All @@ -10,6 +15,7 @@ pub struct NeuronInput {
pub struct Neuron {
id: u32,
inputs: Vec<NeuronInput>,
activation: ActivationFunction,
}

pub struct NeuralNetwork {
Expand All @@ -28,10 +34,11 @@ impl NeuronInput {
}

impl Neuron {
pub fn new(id: u32, inputs: Vec<NeuronInput>) -> Neuron {
pub fn new(id: u32, inputs: Vec<NeuronInput>, activation: ActivationFunction) -> Neuron {
Neuron {
id,
inputs,
activation,
}
}
}
Expand All @@ -57,11 +64,16 @@ impl NeuralNetwork {
continue;
}

if neuron.inputs.is_empty() {
values.insert(neuron.id, 0.);
continue;
}

let mut sum = 0.;
for input in neuron.inputs.iter() {
sum += values.get(&input.input_id).unwrap() * input.weight;
}
values.insert(neuron.id, sum);
values.insert(neuron.id, (neuron.activation)(sum));
}

let mut outputs = Vec::<f32>::new();
Expand All @@ -74,7 +86,6 @@ impl NeuralNetwork {
}

#[cfg(test)]

mod tests {
use super::*;

Expand All @@ -83,9 +94,9 @@ mod tests {
let input_ids = vec![1, 2];
let output_ids = vec![3];
let neurons = vec![
Neuron::new(1, vec![]),
Neuron::new(2, vec![]),
Neuron::new(3, vec![NeuronInput::new(1, 0.5), NeuronInput::new(2, 0.5)]),
Neuron::new(1, vec![], IDENTITY),
Neuron::new(2, vec![], IDENTITY),
Neuron::new(3, vec![NeuronInput::new(1, 0.5), NeuronInput::new(2, 0.5)], IDENTITY),
];
let network = NeuralNetwork::new(input_ids, output_ids, neurons);

Expand All @@ -100,10 +111,10 @@ mod tests {
let input_ids = vec![1, 2];
let output_ids = vec![4];
let neurons = vec![
Neuron::new(1, vec![]),
Neuron::new(2, vec![]),
Neuron::new(3, vec![NeuronInput::new(1, 0.5), NeuronInput::new(2, 0.5)]),
Neuron::new(4, vec![NeuronInput::new(3, 0.5)]),
Neuron::new(1, vec![] , IDENTITY),
Neuron::new(2, vec![], IDENTITY),
Neuron::new(3, vec![NeuronInput::new(1, 0.5), NeuronInput::new(2, 0.5)], IDENTITY),
Neuron::new(4, vec![NeuronInput::new(3, 0.5)], IDENTITY),
];
let network = NeuralNetwork::new(input_ids, output_ids, neurons);

Expand Down

0 comments on commit 38642af

Please sign in to comment.