Skip to content

Commit

Permalink
Fix topological sort
Browse files Browse the repository at this point in the history
  • Loading branch information
sharph committed Aug 30, 2024
1 parent aa4b079 commit b68e9d3
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 28 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ itertools = "0.13.0"
by_address = "1.2.1"
rfd = "0.14.1"
hound = "3.5.1"
rand = "0.8.5"


# native:
Expand Down
198 changes: 170 additions & 28 deletions src/synth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,40 @@ pub fn ui_dirty(plan: &Vec<SharedSynthModule>) -> bool {
plan.iter().any(|module| module.read().unwrap().ui_dirty())
}

fn is_loop(
module: &SharedSynthModule,
edges: &HashMap<ByAddress<SharedSynthModule>, Vec<ByAddress<SharedSynthModule>>>,
) -> Option<ByAddress<SharedSynthModule>> {
let mut to_search: Vec<ByAddress<SharedSynthModule>> = vec![ByAddress(module.clone())];
let mut to_add: Vec<ByAddress<SharedSynthModule>> = vec![];
let mut visited: HashSet<ByAddress<SharedSynthModule>> = HashSet::new();
while let Some(current_module) = to_search.iter().filter(|m| visited.get(m).is_none()).next() {
visited.insert(current_module.clone());
for dependency in edges.get(&current_module.clone()).unwrap() {
if dependency.clone() == ByAddress(module.clone()) {
println!("cycle detected");
return Some(current_module.clone());
}
to_add.push(dependency.clone());
}
to_search.append(&mut to_add);
}
None
}

pub fn plan_execution(
output: SharedSynthModule,
all_modules: &Vec<SharedSynthModule>,
plan: &mut Vec<SharedSynthModule>,
) -> () {
// topological sort of a graph with cycles -- first we need to break cycles
let mut edges: HashMap<ByAddress<SharedSynthModule>, Vec<ByAddress<SharedSynthModule>>> =
HashMap::new();
HashMap::new(); // K: sink, V: sources
let mut visited: HashSet<ByAddress<SharedSynthModule>> = HashSet::new();
let mut to_search = all_modules.clone();
to_search.push(output.clone());
loop {
// depth first search to break cycles
// create all edges
let module = to_search.pop();
if module.is_none() {
break;
Expand All @@ -120,47 +142,68 @@ pub fn plan_execution(
}
let unlocked = module.read().unwrap();
edges.insert(
// store edges, but only to nodes which have not been visited
// store edges
ByAddress(module.clone()),
get_inputs(&*unlocked)
.into_iter()
.filter(|i| i.is_some())
.map(|i| i.unwrap().0)
.filter(|m| visited.get(&ByAddress(m.clone())).is_none()) // don't create edges to
// nodes visited
.map(|m| {
to_search.push(m.clone());
ByAddress(m)
})
.collect(),
);
}
let to_search = all_modules.clone();
let mut to_search = all_modules.clone();
plan.clear();
visited.clear();
loop {
// find leaves first, then search for nodes for which children have already been visited
if let Some(node) = to_search
.iter()
.map(|m| m.clone())
.filter(|m| {
edges
.get(&ByAddress(m.clone()))
.unwrap()
.into_iter()
.filter(|d| !visited.contains(d))
.collect::<Vec<_>>()
.len()
== 0
})
.filter(|m| !visited.contains(&ByAddress(m.clone())))
.next()
{
visited.insert(ByAddress(node.clone()));
plan.push(node);
} else {
break;
to_search.push(output.clone());
// remove cycles
while let Some(module) = to_search.pop() {
if !visited.insert(ByAddress(module.clone())) {
continue;
}
for dependency in edges.get(&ByAddress(module.clone())).unwrap().iter() {
to_search.push((**dependency).clone());
}
while let Some(from) = is_loop(&module, &edges) {
let unlocked = module.read().unwrap();
println!("{}", unlocked.get_name());
while let Some(idx) = edges
.get(&from)
.unwrap()
.iter()
.enumerate()
.filter(|(_, m)| ByAddress(module.clone()) == **m)
.map(|(idx, _)| idx)
.next()
{
let dependencies = edges.get_mut(&from).unwrap();
dependencies.remove(idx);
}
}
}
visited.clear();
let to_search = all_modules.clone();
// find leaves first, then search for nodes for which children have already been visited
// find next node with no dependencies that haven't been visited
while let Some(node) = to_search
.iter()
.map(|m| m.clone())
.filter(|m| !visited.contains(&ByAddress(m.clone())))
.filter(|m| {
!edges
.get(&ByAddress(m.clone()))
.unwrap()
.iter()
.any(|d| !visited.contains(d)) // any will return true when there are unvisited
// dependencies
})
.next()
{
visited.insert(ByAddress(node.clone()));
plan.push(node);
}
}

Expand Down Expand Up @@ -283,3 +326,102 @@ pub fn get_catalog() -> Vec<(String, Box<dyn Fn(&AudioConfig) -> SharedSynthModu
),
]
}

#[cfg(test)]
mod tests {
use super::*;
use rand::seq::SliceRandom;
use std::collections::HashMap;

fn connect(src: SharedSynthModule, sink: SharedSynthModule) {
let mut unlocked_sink = sink.write().unwrap();
let unconnected_idx = get_inputs(&*unlocked_sink)
.iter()
.enumerate()
.filter(|(_idx, input)| input.is_none())
.map(|(idx, _)| idx)
.next()
.unwrap();
unlocked_sink
.set_input(unconnected_idx as u8, src, 0)
.unwrap();
}

#[test]
fn topographical_sort() {
// 0 -> 1 -> 2 -> 3 -> o
// \----> 4 -----^
// 5<->6^
let ac = AudioConfig {
buffer_size: 64,
sample_rate: 44100,
channels: 2,
};
let mut rng = rand::thread_rng();
let create_mod = || Arc::new(RwLock::new(mixer::MonoMixerModule::new(&ac)));
let out = Arc::new(RwLock::new(output::OutputModule::new(&ac)));
let modules: Vec<SharedSynthModule> =
(0..7).map(|_| create_mod() as SharedSynthModule).collect();
connect(modules[0].clone(), modules[1].clone());
connect(modules[1].clone(), modules[2].clone());
connect(modules[2].clone(), modules[3].clone());
connect(modules[3].clone(), out.clone());
connect(modules[0].clone(), modules[4].clone());
connect(modules[4].clone(), modules[3].clone());
connect(modules[6].clone(), modules[4].clone());
connect(modules[5].clone(), modules[6].clone());
connect(modules[6].clone(), modules[5].clone());
for _ in 0..1000 {
let mut indexes: HashMap<ByAddress<SharedSynthModule>, usize> = HashMap::new();
let mut list: Vec<SharedSynthModule> = Vec::new();
let mut plan: Vec<SharedSynthModule> = Vec::new();
list.append(&mut modules.clone());
list.push(out.clone());
list.shuffle(&mut rng);
plan_execution(out.clone(), &list, &mut plan);
println!("---");
for (idx, module) in plan.iter().enumerate() {
indexes.insert(ByAddress(module.clone()), idx);
}
for (idx, mapping) in modules
.iter()
.map(|m| indexes.get(&ByAddress(m.clone())).unwrap())
.enumerate()
{
println!("{} -> {}", idx, mapping);
}
println!("o -> {}", indexes.get(&ByAddress(out.clone())).unwrap());
assert!(
indexes.get(&ByAddress(modules[0].clone()))
< indexes.get(&ByAddress(modules[1].clone()))
);
assert!(
indexes.get(&ByAddress(modules[1].clone()))
< indexes.get(&ByAddress(modules[2].clone()))
);
assert!(
indexes.get(&ByAddress(modules[2].clone()))
< indexes.get(&ByAddress(modules[3].clone()))
);
assert!(
indexes.get(&ByAddress(modules[3].clone())) < indexes.get(&ByAddress(out.clone()))
);
assert!(
indexes.get(&ByAddress(modules[0].clone()))
< indexes.get(&ByAddress(modules[4].clone()))
);
assert!(
indexes.get(&ByAddress(modules[4].clone()))
< indexes.get(&ByAddress(modules[3].clone()))
);
assert!(
indexes.get(&ByAddress(modules[6].clone()))
< indexes.get(&ByAddress(modules[4].clone()))
);
assert!(
indexes.get(&ByAddress(modules[5].clone()))
< indexes.get(&ByAddress(modules[6].clone()))
);
}
}
}

0 comments on commit b68e9d3

Please sign in to comment.