Skip to content

Commit

Permalink
Make raw_union more flexible and add a fallible try_raw_rebuild
Browse files Browse the repository at this point in the history
  • Loading branch information
dewert99 committed Mar 21, 2024
1 parent c18f6d4 commit fb07f3b
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 28 deletions.
25 changes: 13 additions & 12 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -804,20 +804,21 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {

self.clean = false;
let mut new_root = None;
self.inner
.raw_union(enode_id1, enode_id2, |class1, id1, p1, class2, _, p2| {
new_root = Some(id1);
self.inner.raw_union(enode_id1, enode_id2, |info| {
new_root = Some(info.id1);

let did_merge = self.analysis.merge(&mut class1.data, class2.data);
if did_merge.0 {
self.analysis_pending.extend(p1);
}
if did_merge.1 {
self.analysis_pending.extend(p2);
}
let did_merge = self.analysis.merge(&mut info.data1.data, info.data2.data);
if did_merge.0 {
self.analysis_pending
.extend(info.parents1.into_iter().copied());
}
if did_merge.1 {
self.analysis_pending
.extend(info.parents2.into_iter().copied());
}

concat_vecs(&mut class1.nodes, class2.nodes);
});
concat_vecs(&mut info.data1.nodes, info.data2.nodes);
});
if let Some(id) = new_root {
if let Some(explain) = &mut self.explain {
explain.union(enode_id1, enode_id2, rule.unwrap(), any_new_rhs);
Expand Down
78 changes: 62 additions & 16 deletions src/raw/egraph.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{raw::RawEClass, Dot, HashMap, Id, Language, RecExpr, UnionFind};
use std::convert::Infallible;
use std::ops::{Deref, DerefMut};
use std::{
borrow::BorrowMut,
Expand Down Expand Up @@ -426,6 +427,26 @@ impl<L: Language, D> RawEGraph<L, D> {
}
}

/// Information for [`RawEGraph::raw_union`] callback
#[non_exhaustive]
pub struct MergeInfo<'a, D: 'a> {
/// id that will be the root for the newly merged eclass
pub id1: Id,
/// data associated with `id1` that can be modified to reflect `data2` being merged into it
pub data1: &'a mut D,
/// parents of `id1` before the merge
pub parents1: &'a [Id],
/// id that used to be a root but will now be in `id1` eclass
pub id2: Id,
/// data associated with `id2`
pub data2: D,
/// parents of `id2` before the merge
pub parents2: &'a [Id],
/// true if `id1` was the root of the second id passed to [`RawEGraph::raw_union`]
/// false if `id1` was the root of the first id passed to [`RawEGraph::raw_union`]
pub swapped_ids: bool,
}

impl<L: Language, D> RawEGraph<L, D> {
/// Adds `enode` to a [`RawEGraph`] contained within a wrapper type `T`
///
Expand Down Expand Up @@ -524,7 +545,7 @@ impl<L: Language, D> RawEGraph<L, D> {
&mut self,
enode_id1: Id,
enode_id2: Id,
merge: impl FnOnce(&mut D, Id, Parents<'_>, D, Id, Parents<'_>),
merge: impl FnOnce(MergeInfo<'_, D>),
) {
let mut id1 = self.find_mut(enode_id1);
let mut id2 = self.find_mut(enode_id2);
Expand All @@ -534,7 +555,9 @@ impl<L: Language, D> RawEGraph<L, D> {
// make sure class2 has fewer parents
let class1_parents = self.classes[&id1].parents.len();
let class2_parents = self.classes[&id2].parents.len();
let mut swapped = false;
if class1_parents < class2_parents {
swapped = true;
std::mem::swap(&mut id1, &mut id2);
}

Expand All @@ -545,22 +568,22 @@ impl<L: Language, D> RawEGraph<L, D> {
let class2 = self.classes.remove(&id2).unwrap();
let class1 = self.classes.get_mut(&id1).unwrap();
assert_eq!(id1, class1.id);
let (p1, p2) = (Parents(&class1.parents), Parents(&class2.parents));
merge(
&mut class1.raw_data,
class1.id,
p1,
class2.raw_data,
class2.id,
p2,
);
let info = MergeInfo {
id1: class1.id,
data1: &mut class1.raw_data,
parents1: &class1.parents,
id2: class2.id,
data2: class2.raw_data,
parents2: &class2.parents,
swapped_ids: swapped,
};
merge(info);

self.pending.extend(&class2.parents);

class1.parents.extend(class2.parents);
}

#[inline]
/// Rebuild to [`RawEGraph`] to restore congruence closure
///
/// ## Parameters
Expand All @@ -576,25 +599,48 @@ impl<L: Language, D> RawEGraph<L, D> {
/// In order to be correct `perform_union` should call [`raw_union`](RawEGraph::raw_union)
///
/// ### `handle_pending`
/// Called with the uncanonical id of each enode whose canonical children have changned, along with a canonical
/// Called with the uncanonical id of each enode whose canonical children have changed, along with a canonical
/// version of it
#[inline]
pub fn raw_rebuild<T>(
outer: &mut T,
get_self: impl Fn(&mut T) -> &mut Self,
mut perform_union: impl FnMut(&mut T, Id, Id),
mut handle_pending: impl FnMut(&mut T, Id, &L),
handle_pending: impl FnMut(&mut T, Id, &L),
) {
let _: Result<(), Infallible> = RawEGraph::try_raw_rebuild(
outer,
get_self,
|this, id1, id2| Ok(perform_union(this, id1, id2)),
handle_pending,
);
}

/// Similar to [`raw_rebuild`] but allows for the union operation to fail and abort the rebuild
#[inline]
pub fn try_raw_rebuild<T, E>(
outer: &mut T,
get_self: impl Fn(&mut T) -> &mut Self,
mut perform_union: impl FnMut(&mut T, Id, Id) -> Result<(), E>,
mut handle_pending: impl FnMut(&mut T, Id, &L),
) -> Result<(), E> {
loop {
let this = get_self(outer);
if let Some(class) = this.pending.pop() {
let mut node = this.id_to_node(class).clone();
node.update_children(|id| this.find_mut(id));
handle_pending(outer, class, &node);
if let Some(memo_class) = get_self(outer).residual.memo.insert(node, class) {
perform_union(outer, memo_class, class);
match perform_union(outer, memo_class, class) {
Ok(()) => {}
Err(e) => {
get_self(outer).pending.push(class);
return Err(e);
}
}
}
} else {
break;
break Ok(());
}
}
}
Expand Down Expand Up @@ -638,7 +684,7 @@ impl<L: Language> RawEGraph<L, ()> {
/// Simplified version of [`raw_union`](RawEGraph::raw_union) for egraphs without eclass data
pub fn union(&mut self, id1: Id, id2: Id) -> bool {
let mut unioned = false;
self.raw_union(id1, id2, |_, _, _, _, _, _| {
self.raw_union(id1, id2, |_| {
unioned = true;
});
unioned
Expand Down

0 comments on commit fb07f3b

Please sign in to comment.