diff --git a/src/algorithms.rs b/src/algorithms.rs index 378d9e0..54b8cea 100644 --- a/src/algorithms.rs +++ b/src/algorithms.rs @@ -140,25 +140,21 @@ impl EGraph { } parents }); - for Class { id, nodes } in self.classes().clone().values() { - let mut other_nodes = Vec::new(); - let mut unique_node = None; - for node_id in nodes { - let node = self.nodes[node_id].clone(); - if should_split(node_id, &node) { - if let Some((other_node_id, other_node)) = unique_node { - panic!( - "Multiple nodes in one e-class should be split. E-class: {:} Node 1: {:?} {:?} Node 2: {:?} {:?}", - id, node_id, node, other_node_id, other_node - ); - } - unique_node = Some((node_id, node)); - } else { - other_nodes.push(node_id); - } + + for Class { id, nodes } in self.classes().clone().into_values() { + let (unique_nodes, other_nodes): (Vec<_>, Vec<_>) = nodes + .into_iter() + .partition(|node_id| should_split(node_id, &self.nodes[node_id])); + if unique_nodes.len() > 1 { + panic!( + "Multiple nodes in one e-class should be split. E-class: {:?} Nodes: {:?}", + id, unique_nodes + ); } - let class_data = self.class_data.get(id).cloned(); - if let Some((unique_node_id, unique_node)) = unique_node { + let unique_node = unique_nodes.into_iter().next(); + let class_data = self.class_data.get(&id).cloned(); + if let Some(unique_node_id) = unique_node { + let unique_node = self.nodes[&unique_node_id].clone(); let n_other_nodes = other_nodes.len(); let mut offset = 0; if n_other_nodes == 0 { @@ -178,7 +174,7 @@ impl EGraph { .insert(new_class_id.clone(), class_data.clone()); } // Change the e-class of the other node - self.nodes[other_node_id].eclass = new_class_id.clone(); + self.nodes[&other_node_id].eclass = new_class_id.clone(); // Create a new unique node with the same data let mut new_unique_node = unique_node.clone(); new_unique_node.eclass = new_class_id; @@ -196,7 +192,7 @@ impl EGraph { let mut new_unique_node = unique_node.clone(); new_unique_node.eclass = new_class_id; self.nodes.insert(new_id.clone().into(), new_unique_node); - for (parent_id, position) in parents.get(id).cloned().unwrap_or_default() { + for (parent_id, position) in parents.get(&id).cloned().unwrap_or_default() { changed = true; // Change the child of the parent to the new node self.nodes.get_mut(&parent_id).unwrap().children[position] =