Skip to content

Commit

Permalink
implemented speciation
Browse files Browse the repository at this point in the history
  • Loading branch information
samyhaff committed Apr 6, 2024
1 parent 3a1fb15 commit 509067e
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 42 deletions.
12 changes: 8 additions & 4 deletions src/bin/neat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,19 @@ fn main() {
population_size: 150,
n_inputs: 2,
n_outputs: 1,
n_generations: 100,
n_generations: 3000,
evaluation_function: xor,
weights_mean: 0.,
weights_stddev: 4.,
perturbation_stddev: 1.,
weights_stddev: 6.,
perturbation_stddev: 3.,
survival_threshold: 0.25,
connection_mutation_rate: 0.3,
node_mutation_rate: 0.03,
node_mutation_rate: 0.1,
weight_mutation_rate: 0.8,
similarity_threshold: 3.,
excess_weight: 1.,
disjoint_weight: 1.,
matching_weight: 0.4,
};

let mut neat = Neat::new(config);
Expand Down
217 changes: 179 additions & 38 deletions src/neat.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#![allow(dead_code)] // TODO remove

use rand::prelude::*;
use rand_distr::Normal;
use crate::neural_network::*;
Expand Down Expand Up @@ -71,12 +69,21 @@ pub struct Config {
pub connection_mutation_rate: f32,
pub node_mutation_rate: f32,
pub weight_mutation_rate: f32,
pub similarity_threshold: f32,
pub excess_weight: f32,
pub disjoint_weight: f32,
pub matching_weight: f32,
}

struct Species {
representative: Individual,
members: Population,
}

pub struct Neat {
population: Population,
history: History,
config: Config,
species: Vec<Species>,
}

impl NodeGene {
Expand Down Expand Up @@ -339,6 +346,69 @@ impl Individual {
let network = self.to_neural_network();
network.feed_forward(input)
}

fn similarity(&self, other_individual: &Self, excess_weight: f32, disjoint_weight: f32, matching_weight: f32) -> f32 {
let disjoint_excess_count = |p1: &[ConnectionGene], p2: &[ConnectionGene]| -> (u32, u32) {
let mut disjoint = 0;
let mut excess = 0;
let mut iter1 = p1.iter();
let mut iter2 = p2.iter();

let mut conn1 = iter1.next();
let mut conn2 = iter2.next();

while let (Some(c1), Some(c2)) = (conn1, conn2) {
if c1.innovation == c2.innovation {
conn1 = iter1.next();
conn2 = iter2.next();
} else if c1.innovation < c2.innovation {
disjoint += 1;
conn1 = iter1.next();
} else {
disjoint += 1;
conn2 = iter2.next();
}
}

if let Some(c1) = conn1 {
excess = iter1.count() as u32 + 1;
}
else if let Some(c2) = conn2 {
excess = iter2.count() as u32 + 1;
}

(disjoint, excess)
};

let matching_weight_difference = |p1: &[ConnectionGene], p2: &[ConnectionGene]| -> f32 {
let mut sum = 0.;
let mut iter1 = p1.iter();
let mut iter2 = p2.iter();

let mut conn1 = iter1.next();
let mut conn2 = iter2.next();

while let (Some(c1), Some(c2)) = (conn1, conn2) {
if c1.innovation == c2.innovation {
sum += (c1.weight - c2.weight).abs();
conn1 = iter1.next();
conn2 = iter2.next();
} else if c1.innovation < c2.innovation {
conn1 = iter1.next();
} else {
conn2 = iter2.next();
}
}

sum
};

let (disjoint, excess) = disjoint_excess_count(&self.genome.connections, &other_individual.genome.connections);
let matching = matching_weight_difference(&self.genome.connections, &other_individual.genome.connections);
let max_genes_number = self.genome.connections.len().max(other_individual.genome.connections.len());

excess_weight * excess as f32 / max_genes_number as f32 + disjoint_weight * disjoint as f32 / max_genes_number as f32 + matching_weight * matching
}
}

impl History {
Expand All @@ -352,19 +422,40 @@ impl History {
}
}

impl Species {
fn new(representative: Individual) -> Species {
Species { representative, members: Vec::new(), }
}

fn add_member(&mut self, individual: Individual) {
self.members.push(individual);
}

fn set_representative(&mut self, representative: Individual) {
self.representative = representative;
}

fn get_random_member(&self) -> Option<&Individual> {
let mut rng = thread_rng();
self.members.choose(&mut rng)
}
}

impl Neat {
pub fn new(config: Config) -> Neat {
let weights_distribution = Normal::new(config.weights_mean, config.weights_stddev).unwrap();
let (history, population) = Self::get_initial_state(config.n_inputs, config.n_outputs, config.population_size as usize, &weights_distribution);

Neat {
population,
history,
species: Vec::new(),
history: History::new(config.n_inputs + config.n_outputs + 1, 0, Vec::new()),
config,
}
}

fn get_initial_state(n_inputs: u32, n_outputs: u32, population_size: usize, weights_distribution: &Normal<f32>) -> (History, Population) {
pub fn initialize(&mut self) {
let weights_distribution = Normal::new(self.config.weights_mean, self.config.weights_stddev).unwrap();
let n_inputs = self.config.n_inputs;
let n_outputs = self.config.n_outputs;
let population_size = self.config.population_size;

let get_initial_individual = |n_inputs, n_outputs, distributions: &Normal<f32>| {
let mut genome = Genome::new();

Expand Down Expand Up @@ -417,64 +508,110 @@ impl Neat {
}

let history = History { innovation, nodes_nb, mutations, generation: 0, };
self.history = history;

let population = (0..population_size).map(|_| get_initial_individual(n_inputs, n_outputs, &weights_distribution)).collect::<Vec<_>>();
self.update_species(population);
}

fn assign_to_species(&mut self, individual: Individual) {
for species in self.species.iter_mut() {
if individual.similarity(&species.representative, self.config.excess_weight, self.config.disjoint_weight, self.config.matching_weight) <= self.config.similarity_threshold {
species.add_member(individual);
return;
}
}

(history, population)
// create a new species and set the individual as the representative
let new_species = Species::new(individual);
self.species.push(new_species);
}

fn next_generation(&mut self) {
for individual in self.population.iter_mut() {
individual.fitness = (self.config.evaluation_function)(individual);
fn update_species(&mut self, population: Population) {
// clear species members
for species in self.species.iter_mut() {
species.members = Vec::new();
}

self.history.generation += 1;
// assign iindividuals to species
for individual in population {
self.assign_to_species(individual)
}

let mut new_population = Vec::new();
// remove empty species
self.species.retain(|species| !species.members.is_empty());

let mut sorted_population = self.population.clone();
sorted_population.sort_by(|a, b| b.fitness.partial_cmp(&a.fitness).unwrap());
// set species representatives
for species in self.species.iter_mut() {
let representative = species.get_random_member().unwrap().clone();
species.set_representative(representative);
}
}

let survival_cutoff = (self.config.population_size as f32 * self.config.survival_threshold) as usize;
let survivors = &sorted_population[..survival_cutoff];
new_population.extend_from_slice(survivors);
fn next_generation(&mut self) {
// iterate over individuals in all species
for species in self.species.iter_mut() {
for individual in species.members.iter_mut() {
individual.fitness = (self.config.evaluation_function)(individual);
}
}

self.history.generation += 1;

let mut new_population = Vec::new();
let mut rng = thread_rng();
let weights_distribution = Normal::new(self.config.weights_mean, self.config.weights_stddev).unwrap();
let perturbation_distribution = Normal::new(0., self.config.perturbation_stddev).unwrap();

while new_population.len() < self.config.population_size as usize {
let parent1 = survivors.choose(&mut rng).unwrap();
let parent2 = survivors.choose(&mut rng).unwrap();
for species in self.species.iter() {
let offsprings_nb = self.config.population_size / self.species.len() as u32;
let mut offsprings = Vec::new();
let mut sorted_members = species.members.clone();
sorted_members.sort_by(|a, b| b.fitness.partial_cmp(&a.fitness).unwrap());
let survival_cutoff = (species.members.len() as f32 * self.config.survival_threshold) as usize;
let survivors = &sorted_members[..=survival_cutoff];
offsprings.push(survivors[0].clone());

let mut child = Individual::crossover(parent1, parent2);
while offsprings.len() <= offsprings_nb as usize {
let parent1 = survivors.choose(&mut rng).unwrap();
let parent2 = survivors.choose(&mut rng).unwrap();

if rng.gen::<f32>() < self.config.connection_mutation_rate {
child.mutate_add_connection(&mut self.history, &weights_distribution);
}
let mut child = Individual::crossover(parent1, parent2);

if rng.gen::<f32>() < self.config.node_mutation_rate {
child.mutate_add_node(&mut self.history);
}
if rng.gen::<f32>() < self.config.connection_mutation_rate {
child.mutate_add_connection(&mut self.history, &weights_distribution);
}

if rng.gen::<f32>() < self.config.node_mutation_rate {
child.mutate_add_node(&mut self.history);
}

if rng.gen::<f32>() < self.config.weight_mutation_rate {
child.mutate_weights(&weights_distribution, &perturbation_distribution);
}

if rng.gen::<f32>() < self.config.weight_mutation_rate {
child.mutate_weights(&weights_distribution, &perturbation_distribution);
offsprings.push(child);
}

new_population.push(child);
new_population.extend_from_slice(&offsprings);
}

self.population = new_population;
self.update_species(new_population);
}

pub fn run(&mut self) {
self.initialize();

for i in 1..=self.config.n_generations {
self.next_generation();
}

for individual in self.population.iter_mut() {
individual.fitness = (self.config.evaluation_function)(individual);
println!("Individual fitness: {}", individual.fitness);
}
// for individual in self.population.iter_mut() {
// individual.fitness = (self.config.evaluation_function)(individual);
// println!("Individual fitness: {}", individual.fitness);
// println!("Number of nodes: {}", individual.genome.nodes.len());
// // println!("Number of connections: {}", individual.genome.connections.len());
// }
}
}

Expand Down Expand Up @@ -715,6 +852,10 @@ mod tests {
connection_mutation_rate: 0.1,
node_mutation_rate: 0.1,
weight_mutation_rate: 0.1,
similarity_threshold: 1.,
excess_weight: 1.,
disjoint_weight: 1.,
matching_weight: 1.,
};

let neat = Neat::new(config);
Expand Down

0 comments on commit 509067e

Please sign in to comment.