From 509067e0cd4ba4d48fac36dfe2b368347c9b30c2 Mon Sep 17 00:00:00 2001 From: samyhaff Date: Sat, 6 Apr 2024 14:12:27 +0200 Subject: [PATCH] implemented speciation --- src/bin/neat.rs | 12 ++- src/neat.rs | 217 +++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 187 insertions(+), 42 deletions(-) diff --git a/src/bin/neat.rs b/src/bin/neat.rs index b346041..692be18 100644 --- a/src/bin/neat.rs +++ b/src/bin/neat.rs @@ -18,15 +18,19 @@ fn main() { population_size: 150, n_inputs: 2, n_outputs: 1, - n_generations: 100, + n_generations: 3000, evaluation_function: xor, weights_mean: 0., - weights_stddev: 4., - perturbation_stddev: 1., + weights_stddev: 6., + perturbation_stddev: 3., survival_threshold: 0.25, connection_mutation_rate: 0.3, - node_mutation_rate: 0.03, + node_mutation_rate: 0.1, weight_mutation_rate: 0.8, + similarity_threshold: 3., + excess_weight: 1., + disjoint_weight: 1., + matching_weight: 0.4, }; let mut neat = Neat::new(config); diff --git a/src/neat.rs b/src/neat.rs index 8b1c9f8..d49d4bf 100644 --- a/src/neat.rs +++ b/src/neat.rs @@ -1,5 +1,3 @@ -#![allow(dead_code)] // TODO remove - use rand::prelude::*; use rand_distr::Normal; use crate::neural_network::*; @@ -71,12 +69,21 @@ pub struct Config { pub connection_mutation_rate: f32, pub node_mutation_rate: f32, pub weight_mutation_rate: f32, + pub similarity_threshold: f32, + pub excess_weight: f32, + pub disjoint_weight: f32, + pub matching_weight: f32, +} + +struct Species { + representative: Individual, + members: Population, } pub struct Neat { - population: Population, history: History, config: Config, + species: Vec, } impl NodeGene { @@ -339,6 +346,69 @@ impl Individual { let network = self.to_neural_network(); network.feed_forward(input) } + + fn similarity(&self, other_individual: &Self, excess_weight: f32, disjoint_weight: f32, matching_weight: f32) -> f32 { + let disjoint_excess_count = |p1: &[ConnectionGene], p2: &[ConnectionGene]| -> (u32, u32) { + let mut disjoint = 0; + let mut excess = 0; + let mut iter1 = p1.iter(); + let mut iter2 = p2.iter(); + + let mut conn1 = iter1.next(); + let mut conn2 = iter2.next(); + + while let (Some(c1), Some(c2)) = (conn1, conn2) { + if c1.innovation == c2.innovation { + conn1 = iter1.next(); + conn2 = iter2.next(); + } else if c1.innovation < c2.innovation { + disjoint += 1; + conn1 = iter1.next(); + } else { + disjoint += 1; + conn2 = iter2.next(); + } + } + + if let Some(c1) = conn1 { + excess = iter1.count() as u32 + 1; + } + else if let Some(c2) = conn2 { + excess = iter2.count() as u32 + 1; + } + + (disjoint, excess) + }; + + let matching_weight_difference = |p1: &[ConnectionGene], p2: &[ConnectionGene]| -> f32 { + let mut sum = 0.; + let mut iter1 = p1.iter(); + let mut iter2 = p2.iter(); + + let mut conn1 = iter1.next(); + let mut conn2 = iter2.next(); + + while let (Some(c1), Some(c2)) = (conn1, conn2) { + if c1.innovation == c2.innovation { + sum += (c1.weight - c2.weight).abs(); + conn1 = iter1.next(); + conn2 = iter2.next(); + } else if c1.innovation < c2.innovation { + conn1 = iter1.next(); + } else { + conn2 = iter2.next(); + } + } + + sum + }; + + let (disjoint, excess) = disjoint_excess_count(&self.genome.connections, &other_individual.genome.connections); + let matching = matching_weight_difference(&self.genome.connections, &other_individual.genome.connections); + let max_genes_number = self.genome.connections.len().max(other_individual.genome.connections.len()); + + excess_weight * excess as f32 / max_genes_number as f32 + disjoint_weight * disjoint as f32 / max_genes_number as f32 + matching_weight * matching + } } impl History { @@ -352,19 +422,40 @@ impl History { } } +impl Species { + fn new(representative: Individual) -> Species { + Species { representative, members: Vec::new(), } + } + + fn add_member(&mut self, individual: Individual) { + self.members.push(individual); + } + + fn set_representative(&mut self, representative: Individual) { + self.representative = representative; + } + + fn get_random_member(&self) -> Option<&Individual> { + let mut rng = thread_rng(); + self.members.choose(&mut rng) + } +} + impl Neat { pub fn new(config: Config) -> Neat { - let weights_distribution = Normal::new(config.weights_mean, config.weights_stddev).unwrap(); - let (history, population) = Self::get_initial_state(config.n_inputs, config.n_outputs, config.population_size as usize, &weights_distribution); - Neat { - population, - history, + species: Vec::new(), + history: History::new(config.n_inputs + config.n_outputs + 1, 0, Vec::new()), config, } } - fn get_initial_state(n_inputs: u32, n_outputs: u32, population_size: usize, weights_distribution: &Normal) -> (History, Population) { + pub fn initialize(&mut self) { + let weights_distribution = Normal::new(self.config.weights_mean, self.config.weights_stddev).unwrap(); + let n_inputs = self.config.n_inputs; + let n_outputs = self.config.n_outputs; + let population_size = self.config.population_size; + let get_initial_individual = |n_inputs, n_outputs, distributions: &Normal| { let mut genome = Genome::new(); @@ -417,64 +508,110 @@ impl Neat { } let history = History { innovation, nodes_nb, mutations, generation: 0, }; + self.history = history; + let population = (0..population_size).map(|_| get_initial_individual(n_inputs, n_outputs, &weights_distribution)).collect::>(); + self.update_species(population); + } + + fn assign_to_species(&mut self, individual: Individual) { + for species in self.species.iter_mut() { + if individual.similarity(&species.representative, self.config.excess_weight, self.config.disjoint_weight, self.config.matching_weight) <= self.config.similarity_threshold { + species.add_member(individual); + return; + } + } - (history, population) + // create a new species and set the individual as the representative + let new_species = Species::new(individual); + self.species.push(new_species); } - fn next_generation(&mut self) { - for individual in self.population.iter_mut() { - individual.fitness = (self.config.evaluation_function)(individual); + fn update_species(&mut self, population: Population) { + // clear species members + for species in self.species.iter_mut() { + species.members = Vec::new(); } - self.history.generation += 1; + // assign iindividuals to species + for individual in population { + self.assign_to_species(individual) + } - let mut new_population = Vec::new(); + // remove empty species + self.species.retain(|species| !species.members.is_empty()); - let mut sorted_population = self.population.clone(); - sorted_population.sort_by(|a, b| b.fitness.partial_cmp(&a.fitness).unwrap()); + // set species representatives + for species in self.species.iter_mut() { + let representative = species.get_random_member().unwrap().clone(); + species.set_representative(representative); + } + } - let survival_cutoff = (self.config.population_size as f32 * self.config.survival_threshold) as usize; - let survivors = &sorted_population[..survival_cutoff]; - new_population.extend_from_slice(survivors); + fn next_generation(&mut self) { + // iterate over individuals in all species + for species in self.species.iter_mut() { + for individual in species.members.iter_mut() { + individual.fitness = (self.config.evaluation_function)(individual); + } + } + + self.history.generation += 1; + let mut new_population = Vec::new(); let mut rng = thread_rng(); let weights_distribution = Normal::new(self.config.weights_mean, self.config.weights_stddev).unwrap(); let perturbation_distribution = Normal::new(0., self.config.perturbation_stddev).unwrap(); - while new_population.len() < self.config.population_size as usize { - let parent1 = survivors.choose(&mut rng).unwrap(); - let parent2 = survivors.choose(&mut rng).unwrap(); + for species in self.species.iter() { + let offsprings_nb = self.config.population_size / self.species.len() as u32; + let mut offsprings = Vec::new(); + let mut sorted_members = species.members.clone(); + sorted_members.sort_by(|a, b| b.fitness.partial_cmp(&a.fitness).unwrap()); + let survival_cutoff = (species.members.len() as f32 * self.config.survival_threshold) as usize; + let survivors = &sorted_members[..=survival_cutoff]; + offsprings.push(survivors[0].clone()); - let mut child = Individual::crossover(parent1, parent2); + while offsprings.len() <= offsprings_nb as usize { + let parent1 = survivors.choose(&mut rng).unwrap(); + let parent2 = survivors.choose(&mut rng).unwrap(); - if rng.gen::() < self.config.connection_mutation_rate { - child.mutate_add_connection(&mut self.history, &weights_distribution); - } + let mut child = Individual::crossover(parent1, parent2); - if rng.gen::() < self.config.node_mutation_rate { - child.mutate_add_node(&mut self.history); - } + if rng.gen::() < self.config.connection_mutation_rate { + child.mutate_add_connection(&mut self.history, &weights_distribution); + } + + if rng.gen::() < self.config.node_mutation_rate { + child.mutate_add_node(&mut self.history); + } + + if rng.gen::() < self.config.weight_mutation_rate { + child.mutate_weights(&weights_distribution, &perturbation_distribution); + } - if rng.gen::() < self.config.weight_mutation_rate { - child.mutate_weights(&weights_distribution, &perturbation_distribution); + offsprings.push(child); } - new_population.push(child); + new_population.extend_from_slice(&offsprings); } - self.population = new_population; + self.update_species(new_population); } pub fn run(&mut self) { + self.initialize(); + for i in 1..=self.config.n_generations { self.next_generation(); } - for individual in self.population.iter_mut() { - individual.fitness = (self.config.evaluation_function)(individual); - println!("Individual fitness: {}", individual.fitness); - } + // for individual in self.population.iter_mut() { + // individual.fitness = (self.config.evaluation_function)(individual); + // println!("Individual fitness: {}", individual.fitness); + // println!("Number of nodes: {}", individual.genome.nodes.len()); + // // println!("Number of connections: {}", individual.genome.connections.len()); + // } } } @@ -715,6 +852,10 @@ mod tests { connection_mutation_rate: 0.1, node_mutation_rate: 0.1, weight_mutation_rate: 0.1, + similarity_threshold: 1., + excess_weight: 1., + disjoint_weight: 1., + matching_weight: 1., }; let neat = Neat::new(config);