From 93b914eb5301a985a5f2068578da8566dae96095 Mon Sep 17 00:00:00 2001 From: samyhaff Date: Sun, 9 Jun 2024 14:07:39 +0200 Subject: [PATCH] experimenting with neat NN visualization --- src/gui.rs | 100 +++++++++++++++++++++++++++++++++++++++++- src/neat.rs | 2 +- src/neural_network.rs | 34 ++++++++++++++ 3 files changed, 133 insertions(+), 3 deletions(-) diff --git a/src/gui.rs b/src/gui.rs index f07dadc..8e434cc 100644 --- a/src/gui.rs +++ b/src/gui.rs @@ -1,5 +1,6 @@ use std::f64::consts::PI; use ggez::*; +use crate::neural_network::NeuralNetwork; use crate::neuroevolution_algorithm::*; use crate::benchmarks::Benchmark; @@ -267,6 +268,92 @@ impl State { Ok(()) } + + fn draw_neural_network(netork: &NeuralNetwork, mesh: &mut graphics::MeshBuilder) -> GameResult { + let input_ids = netork.get_input_ids(); + let output_ids = netork.get_output_ids(); + let hidden_ids = netork.get_hidden_ids(); + let bias_id = netork.get_bias_id(); + + let connections = netork.get_connection_ids(); + + for (from, to) in connections { + let (x1, y1) = match from { + id if input_ids.contains(&id) => (100.0, 200.0 + 50.0 * input_ids.iter().position(|&x| x == id).unwrap() as f32), + id if hidden_ids.contains(&id) => (300.0, 200.0 + 50.0 * hidden_ids.iter().position(|&x| x == id).unwrap() as f32), + id if output_ids.contains(&id) => (500.0, 200.0 + 50.0 * output_ids.iter().position(|&x| x == id).unwrap() as f32), + id if bias_id == Some(id) => (100.0, 150.0), + _ => panic!("Invalid neuron id"), + }; + + let (x2, y2) = match to { + id if input_ids.contains(&id) => (100.0, 200.0 + 50.0 * input_ids.iter().position(|&x| x == id).unwrap() as f32), + id if hidden_ids.contains(&id) => (300.0, 200.0 + 50.0 * hidden_ids.iter().position(|&x| x == id).unwrap() as f32), + id if output_ids.contains(&id) => (500.0, 200.0 + 50.0 * output_ids.iter().position(|&x| x == id).unwrap() as f32), + id if bias_id == Some(id) => (100.0, 150.0), + _ => panic!("Invalid neuron id"), + }; + + mesh.line( + &[ + mint::Point2{x: x1, y: y1}, + mint::Point2{x: x2, y: y2}, + ], + 2.0, + graphics::Color::BLACK, + )?; + } + + if let Some(_) = bias_id { + let x = 100.0; + let y = 150.0; + mesh.circle( + graphics::DrawMode::fill(), + mint::Point2{x, y}, + 20.0, + 0.1, + graphics::Color::GREEN, + )?; + } + + for (i, _) in input_ids.iter().enumerate() { + let x = 100.0; + let y = 200.0 + 50.0 * i as f32; + mesh.circle( + graphics::DrawMode::fill(), + mint::Point2{x, y}, + 20.0, + 0.1, + graphics::Color::BLUE, + )?; + } + + for (i, _) in hidden_ids.iter().enumerate() { + let x = 300.0; + let y = 200.0 + 50.0 * i as f32; + mesh.circle( + graphics::DrawMode::fill(), + mint::Point2{x, y}, + 20.0, + 0.1, + graphics::Color::RED, + )?; + } + + for (i, _) in output_ids.iter().enumerate() { + let x = 500.0; + let y = 200.0 + 50.0 * i as f32; + mesh.circle( + graphics::DrawMode::fill(), + mint::Point2{x, y}, + 20.0, + 0.1, + graphics::Color::BLACK, + )?; + } + + Ok(()) + } } impl ggez::event::EventHandler for State { @@ -293,8 +380,17 @@ impl ggez::event::EventHandler for State { let mut canvas = graphics::Canvas::from_frame(ctx, graphics::Color::WHITE); let mesh = &mut graphics::MeshBuilder::new(); - self.get_problem_points_mesh(mesh)?; - self.get_algorithm_mesh(mesh)?; + match &self.alg { + Algorithm::Neat(neat) => { + let best_individual = neat.get_best_individual(); + let network = best_individual.to_neural_network(); + State::draw_neural_network(&network, mesh)?; + } + _ => { + self.get_problem_points_mesh(mesh)?; + self.get_algorithm_mesh(mesh)?; + } + } let mut text = graphics::Text::new(format!("Iteration: {}\nFitness: {:.2}", self.iteration, self.problem.evaluate(&self.alg))); diff --git a/src/neat.rs b/src/neat.rs index 31af9ee..a8e1102 100644 --- a/src/neat.rs +++ b/src/neat.rs @@ -323,7 +323,7 @@ impl Individual { Individual::new(Genome { nodes, connections }) } - fn to_neural_network(&self) -> NeuralNetwork { + pub fn to_neural_network(&self) -> NeuralNetwork { let (input_ids, input_activations): (Vec, Vec) = self.genome.nodes.iter().filter(|n| n.layer == NodeType::Input).map(|n| (n.id, n.activation)).unzip(); let (output_ids, output_activations): (Vec, Vec) = self.genome.nodes.iter().filter(|n| n.layer == NodeType::Output).map(|n| (n.id, n.activation)).unzip(); let mut input_activations = input_activations.iter(); diff --git a/src/neural_network.rs b/src/neural_network.rs index 05f6a0b..0103685 100644 --- a/src/neural_network.rs +++ b/src/neural_network.rs @@ -169,6 +169,40 @@ impl NeuralNetwork { network } + + pub fn get_hidden_ids(&self) -> Vec { + let mut hidden_neurons_ids = Vec::new(); + for neuron in self.neurons.iter() { + if !self.input_ids.contains(&neuron.id) && !self.output_ids.contains(&neuron.id) && self.bias_id != Some(neuron.id) { + hidden_neurons_ids.push(neuron.id); + } + } + + hidden_neurons_ids + } + + pub fn get_input_ids(&self) -> &Vec { + &self.input_ids + } + + pub fn get_output_ids(&self) -> &Vec { + &self.output_ids + } + + pub fn get_bias_id(&self) -> Option { + self.bias_id + } + + pub fn get_connection_ids(&self) -> Vec<(u32, u32)> { + let mut connection_ids = Vec::new(); + for neuron in self.neurons.iter() { + for connection in &neuron.inputs { + connection_ids.push((neuron.id, connection.input_id)); + } + } + + connection_ids + } } impl NeuroevolutionAlgorithm for NeuralNetwork {