Skip to content

Commit

Permalink
experimenting with neat NN visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
samyhaff committed Jun 9, 2024
1 parent 93b5623 commit 93b914e
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 3 deletions.
100 changes: 98 additions & 2 deletions src/gui.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::f64::consts::PI;
use ggez::*;
use crate::neural_network::NeuralNetwork;
use crate::neuroevolution_algorithm::*;
use crate::benchmarks::Benchmark;

Expand Down Expand Up @@ -267,6 +268,92 @@ impl State {

Ok(())
}

fn draw_neural_network(netork: &NeuralNetwork, mesh: &mut graphics::MeshBuilder) -> GameResult {
let input_ids = netork.get_input_ids();
let output_ids = netork.get_output_ids();
let hidden_ids = netork.get_hidden_ids();
let bias_id = netork.get_bias_id();

let connections = netork.get_connection_ids();

for (from, to) in connections {
let (x1, y1) = match from {
id if input_ids.contains(&id) => (100.0, 200.0 + 50.0 * input_ids.iter().position(|&x| x == id).unwrap() as f32),
id if hidden_ids.contains(&id) => (300.0, 200.0 + 50.0 * hidden_ids.iter().position(|&x| x == id).unwrap() as f32),
id if output_ids.contains(&id) => (500.0, 200.0 + 50.0 * output_ids.iter().position(|&x| x == id).unwrap() as f32),
id if bias_id == Some(id) => (100.0, 150.0),
_ => panic!("Invalid neuron id"),
};

let (x2, y2) = match to {
id if input_ids.contains(&id) => (100.0, 200.0 + 50.0 * input_ids.iter().position(|&x| x == id).unwrap() as f32),
id if hidden_ids.contains(&id) => (300.0, 200.0 + 50.0 * hidden_ids.iter().position(|&x| x == id).unwrap() as f32),
id if output_ids.contains(&id) => (500.0, 200.0 + 50.0 * output_ids.iter().position(|&x| x == id).unwrap() as f32),
id if bias_id == Some(id) => (100.0, 150.0),
_ => panic!("Invalid neuron id"),
};

mesh.line(
&[
mint::Point2{x: x1, y: y1},
mint::Point2{x: x2, y: y2},
],
2.0,
graphics::Color::BLACK,
)?;
}

if let Some(_) = bias_id {
let x = 100.0;
let y = 150.0;
mesh.circle(
graphics::DrawMode::fill(),
mint::Point2{x, y},
20.0,
0.1,
graphics::Color::GREEN,
)?;
}

for (i, _) in input_ids.iter().enumerate() {
let x = 100.0;
let y = 200.0 + 50.0 * i as f32;
mesh.circle(
graphics::DrawMode::fill(),
mint::Point2{x, y},
20.0,
0.1,
graphics::Color::BLUE,
)?;
}

for (i, _) in hidden_ids.iter().enumerate() {
let x = 300.0;
let y = 200.0 + 50.0 * i as f32;
mesh.circle(
graphics::DrawMode::fill(),
mint::Point2{x, y},
20.0,
0.1,
graphics::Color::RED,
)?;
}

for (i, _) in output_ids.iter().enumerate() {
let x = 500.0;
let y = 200.0 + 50.0 * i as f32;
mesh.circle(
graphics::DrawMode::fill(),
mint::Point2{x, y},
20.0,
0.1,
graphics::Color::BLACK,
)?;
}

Ok(())
}
}

impl ggez::event::EventHandler<GameError> for State {
Expand All @@ -293,8 +380,17 @@ impl ggez::event::EventHandler<GameError> for State {
let mut canvas = graphics::Canvas::from_frame(ctx, graphics::Color::WHITE);
let mesh = &mut graphics::MeshBuilder::new();

self.get_problem_points_mesh(mesh)?;
self.get_algorithm_mesh(mesh)?;
match &self.alg {
Algorithm::Neat(neat) => {
let best_individual = neat.get_best_individual();
let network = best_individual.to_neural_network();
State::draw_neural_network(&network, mesh)?;
}
_ => {
self.get_problem_points_mesh(mesh)?;
self.get_algorithm_mesh(mesh)?;
}
}

let mut text = graphics::Text::new(format!("Iteration: {}\nFitness: {:.2}", self.iteration, self.problem.evaluate(&self.alg)));

Expand Down
2 changes: 1 addition & 1 deletion src/neat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ impl Individual {
Individual::new(Genome { nodes, connections })
}

fn to_neural_network(&self) -> NeuralNetwork {
pub fn to_neural_network(&self) -> NeuralNetwork {
let (input_ids, input_activations): (Vec<u32>, Vec<ActivationFunction>) = self.genome.nodes.iter().filter(|n| n.layer == NodeType::Input).map(|n| (n.id, n.activation)).unzip();
let (output_ids, output_activations): (Vec<u32>, Vec<ActivationFunction>) = self.genome.nodes.iter().filter(|n| n.layer == NodeType::Output).map(|n| (n.id, n.activation)).unzip();
let mut input_activations = input_activations.iter();
Expand Down
34 changes: 34 additions & 0 deletions src/neural_network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,40 @@ impl NeuralNetwork {

network
}

pub fn get_hidden_ids(&self) -> Vec<u32> {
let mut hidden_neurons_ids = Vec::new();
for neuron in self.neurons.iter() {
if !self.input_ids.contains(&neuron.id) && !self.output_ids.contains(&neuron.id) && self.bias_id != Some(neuron.id) {
hidden_neurons_ids.push(neuron.id);
}
}

hidden_neurons_ids
}

pub fn get_input_ids(&self) -> &Vec<u32> {
&self.input_ids
}

pub fn get_output_ids(&self) -> &Vec<u32> {
&self.output_ids
}

pub fn get_bias_id(&self) -> Option<u32> {
self.bias_id
}

pub fn get_connection_ids(&self) -> Vec<(u32, u32)> {
let mut connection_ids = Vec::new();
for neuron in self.neurons.iter() {
for connection in &neuron.inputs {
connection_ids.push((neuron.id, connection.input_id));
}
}

connection_ids
}
}

impl NeuroevolutionAlgorithm for NeuralNetwork {
Expand Down

0 comments on commit 93b914e

Please sign in to comment.