From b68e9d372e3335fa9e6fdfde98c9aa393c54836a Mon Sep 17 00:00:00 2001 From: Sharp Hall Date: Fri, 30 Aug 2024 16:59:51 -0400 Subject: [PATCH] Fix topological sort --- Cargo.lock | 1 + Cargo.toml | 1 + src/synth.rs | 198 +++++++++++++++++++++++++++++++++++++++++++-------- 3 files changed, 172 insertions(+), 28 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8f9a156..55ce7e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2782,6 +2782,7 @@ dependencies = [ "hound", "itertools 0.13.0", "log", + "rand", "rfd", "uuid", "wasm-bindgen-futures", diff --git a/Cargo.toml b/Cargo.toml index ee59122..621c1b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ itertools = "0.13.0" by_address = "1.2.1" rfd = "0.14.1" hound = "3.5.1" +rand = "0.8.5" # native: diff --git a/src/synth.rs b/src/synth.rs index f9df572..d7cdbc1 100644 --- a/src/synth.rs +++ b/src/synth.rs @@ -98,6 +98,27 @@ pub fn ui_dirty(plan: &Vec) -> bool { plan.iter().any(|module| module.read().unwrap().ui_dirty()) } +fn is_loop( + module: &SharedSynthModule, + edges: &HashMap, Vec>>, +) -> Option> { + let mut to_search: Vec> = vec![ByAddress(module.clone())]; + let mut to_add: Vec> = vec![]; + let mut visited: HashSet> = HashSet::new(); + while let Some(current_module) = to_search.iter().filter(|m| visited.get(m).is_none()).next() { + visited.insert(current_module.clone()); + for dependency in edges.get(¤t_module.clone()).unwrap() { + if dependency.clone() == ByAddress(module.clone()) { + println!("cycle detected"); + return Some(current_module.clone()); + } + to_add.push(dependency.clone()); + } + to_search.append(&mut to_add); + } + None +} + pub fn plan_execution( output: SharedSynthModule, all_modules: &Vec, @@ -105,11 +126,12 @@ pub fn plan_execution( ) -> () { // topological sort of a graph with cycles -- first we need to break cycles let mut edges: HashMap, Vec>> = - HashMap::new(); + HashMap::new(); // K: sink, V: sources let mut visited: HashSet> = HashSet::new(); let mut to_search = all_modules.clone(); + to_search.push(output.clone()); loop { - // depth first search to break cycles + // create all edges let module = to_search.pop(); if module.is_none() { break; @@ -120,14 +142,12 @@ pub fn plan_execution( } let unlocked = module.read().unwrap(); edges.insert( - // store edges, but only to nodes which have not been visited + // store edges ByAddress(module.clone()), get_inputs(&*unlocked) .into_iter() .filter(|i| i.is_some()) .map(|i| i.unwrap().0) - .filter(|m| visited.get(&ByAddress(m.clone())).is_none()) // don't create edges to - // nodes visited .map(|m| { to_search.push(m.clone()); ByAddress(m) @@ -135,32 +155,55 @@ pub fn plan_execution( .collect(), ); } - let to_search = all_modules.clone(); + let mut to_search = all_modules.clone(); plan.clear(); visited.clear(); - loop { - // find leaves first, then search for nodes for which children have already been visited - if let Some(node) = to_search - .iter() - .map(|m| m.clone()) - .filter(|m| { - edges - .get(&ByAddress(m.clone())) - .unwrap() - .into_iter() - .filter(|d| !visited.contains(d)) - .collect::>() - .len() - == 0 - }) - .filter(|m| !visited.contains(&ByAddress(m.clone()))) - .next() - { - visited.insert(ByAddress(node.clone())); - plan.push(node); - } else { - break; + to_search.push(output.clone()); + // remove cycles + while let Some(module) = to_search.pop() { + if !visited.insert(ByAddress(module.clone())) { + continue; + } + for dependency in edges.get(&ByAddress(module.clone())).unwrap().iter() { + to_search.push((**dependency).clone()); } + while let Some(from) = is_loop(&module, &edges) { + let unlocked = module.read().unwrap(); + println!("{}", unlocked.get_name()); + while let Some(idx) = edges + .get(&from) + .unwrap() + .iter() + .enumerate() + .filter(|(_, m)| ByAddress(module.clone()) == **m) + .map(|(idx, _)| idx) + .next() + { + let dependencies = edges.get_mut(&from).unwrap(); + dependencies.remove(idx); + } + } + } + visited.clear(); + let to_search = all_modules.clone(); + // find leaves first, then search for nodes for which children have already been visited + // find next node with no dependencies that haven't been visited + while let Some(node) = to_search + .iter() + .map(|m| m.clone()) + .filter(|m| !visited.contains(&ByAddress(m.clone()))) + .filter(|m| { + !edges + .get(&ByAddress(m.clone())) + .unwrap() + .iter() + .any(|d| !visited.contains(d)) // any will return true when there are unvisited + // dependencies + }) + .next() + { + visited.insert(ByAddress(node.clone())); + plan.push(node); } } @@ -283,3 +326,102 @@ pub fn get_catalog() -> Vec<(String, Box SharedSynthModu ), ] } + +#[cfg(test)] +mod tests { + use super::*; + use rand::seq::SliceRandom; + use std::collections::HashMap; + + fn connect(src: SharedSynthModule, sink: SharedSynthModule) { + let mut unlocked_sink = sink.write().unwrap(); + let unconnected_idx = get_inputs(&*unlocked_sink) + .iter() + .enumerate() + .filter(|(_idx, input)| input.is_none()) + .map(|(idx, _)| idx) + .next() + .unwrap(); + unlocked_sink + .set_input(unconnected_idx as u8, src, 0) + .unwrap(); + } + + #[test] + fn topographical_sort() { + // 0 -> 1 -> 2 -> 3 -> o + // \----> 4 -----^ + // 5<->6^ + let ac = AudioConfig { + buffer_size: 64, + sample_rate: 44100, + channels: 2, + }; + let mut rng = rand::thread_rng(); + let create_mod = || Arc::new(RwLock::new(mixer::MonoMixerModule::new(&ac))); + let out = Arc::new(RwLock::new(output::OutputModule::new(&ac))); + let modules: Vec = + (0..7).map(|_| create_mod() as SharedSynthModule).collect(); + connect(modules[0].clone(), modules[1].clone()); + connect(modules[1].clone(), modules[2].clone()); + connect(modules[2].clone(), modules[3].clone()); + connect(modules[3].clone(), out.clone()); + connect(modules[0].clone(), modules[4].clone()); + connect(modules[4].clone(), modules[3].clone()); + connect(modules[6].clone(), modules[4].clone()); + connect(modules[5].clone(), modules[6].clone()); + connect(modules[6].clone(), modules[5].clone()); + for _ in 0..1000 { + let mut indexes: HashMap, usize> = HashMap::new(); + let mut list: Vec = Vec::new(); + let mut plan: Vec = Vec::new(); + list.append(&mut modules.clone()); + list.push(out.clone()); + list.shuffle(&mut rng); + plan_execution(out.clone(), &list, &mut plan); + println!("---"); + for (idx, module) in plan.iter().enumerate() { + indexes.insert(ByAddress(module.clone()), idx); + } + for (idx, mapping) in modules + .iter() + .map(|m| indexes.get(&ByAddress(m.clone())).unwrap()) + .enumerate() + { + println!("{} -> {}", idx, mapping); + } + println!("o -> {}", indexes.get(&ByAddress(out.clone())).unwrap()); + assert!( + indexes.get(&ByAddress(modules[0].clone())) + < indexes.get(&ByAddress(modules[1].clone())) + ); + assert!( + indexes.get(&ByAddress(modules[1].clone())) + < indexes.get(&ByAddress(modules[2].clone())) + ); + assert!( + indexes.get(&ByAddress(modules[2].clone())) + < indexes.get(&ByAddress(modules[3].clone())) + ); + assert!( + indexes.get(&ByAddress(modules[3].clone())) < indexes.get(&ByAddress(out.clone())) + ); + assert!( + indexes.get(&ByAddress(modules[0].clone())) + < indexes.get(&ByAddress(modules[4].clone())) + ); + assert!( + indexes.get(&ByAddress(modules[4].clone())) + < indexes.get(&ByAddress(modules[3].clone())) + ); + assert!( + indexes.get(&ByAddress(modules[6].clone())) + < indexes.get(&ByAddress(modules[4].clone())) + ); + assert!( + indexes.get(&ByAddress(modules[5].clone())) + < indexes.get(&ByAddress(modules[6].clone())) + ); + } + } +}