From 78bcdb5fb0dab2046d86c5d2604c01498e5de8b6 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 29 Nov 2023 10:31:39 -0500 Subject: [PATCH] Move unextractable to table * Make row a struct so we can add attribute easier * make function and table debug friendly --- src/function/mod.rs | 77 ++++++++++++++++++----------- src/function/table.rs | 111 +++++++++++++++++++++++++++++++++--------- src/serialize.rs | 9 ++-- 3 files changed, 143 insertions(+), 54 deletions(-) diff --git a/src/function/mod.rs b/src/function/mod.rs index c8c6eff73..bc11ff656 100644 --- a/src/function/mod.rs +++ b/src/function/mod.rs @@ -22,7 +22,6 @@ pub struct Function { index_updated_through: usize, updates: usize, scratch: IndexSet, - unextractable: HashSet>, } #[derive(Clone)] @@ -62,6 +61,21 @@ impl ResolvedSchema { } } +impl Debug for Function { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Function") + .field("decl", &self.decl) + .field("schema", &self.schema) + .field("nodes", &self.nodes) + .field("indexes", &self.indexes) + .field("rebuild_indexes", &self.rebuild_indexes) + .field("index_updated_through", &self.index_updated_through) + .field("updates", &self.updates) + .field("scratch", &self.scratch) + .finish() + } +} + /// A non-Union merge discovered during rebuilding that has to be applied before /// resuming execution. pub(crate) type DeferredMerge = (ValueVec, Value, Value); @@ -143,7 +157,6 @@ impl Function { on_merge, merge_vals, }, - unextractable: Default::default(), // TODO figure out merge and default here }) } @@ -197,12 +210,12 @@ impl Function { if !self.schema.output.is_eq_sort() { panic!("Only eq sorts can be marked unextractable") } - self.unextractable.insert(inputs.to_vec()); + self.nodes.mark_unextractable(inputs); } /// Check if the given inputs are unextractable. pub fn check_unextractable(&self, inputs: &[Value]) -> bool { - self.unextractable.contains(inputs) + self.nodes.get_row(inputs).unwrap().unextractable } /// Return a column index that contains (a superset of) the offsets for the @@ -390,6 +403,7 @@ impl Function { // Entry is stale return result; }; + let unextractable = self.nodes.get_row(args).unwrap().unextractable; let mut out_val = out.value; scratch.clear(); @@ -405,36 +419,41 @@ impl Function { return result; } let out_ty = &self.schema.output; - self.nodes.insert_and_merge(scratch, timestamp, |prev| { - if let Some(mut prev) = prev { - out_ty.canonicalize(&mut prev, uf); - let mut appended = false; - if self.merge.on_merge.is_some() && prev != out_val { - deferred_merges.push((scratch.clone(), prev, out_val)); - appended = true; - } - match &self.merge.merge_vals { - MergeFn::Union => { - debug_assert!(self.schema.output.is_eq_sort()); - uf.union_values(prev, out_val, self.schema.output.name()) + self.nodes.insert_and_merge( + scratch, + timestamp, + |prev| { + if let Some(mut prev) = prev { + out_ty.canonicalize(&mut prev, uf); + let mut appended = false; + if self.merge.on_merge.is_some() && prev != out_val { + deferred_merges.push((scratch.clone(), prev, out_val)); + appended = true; } - MergeFn::AssertEq => { - if prev != out_val { - result = Err(Error::MergeError(self.decl.name, prev, out_val)); + match &self.merge.merge_vals { + MergeFn::Union => { + debug_assert!(self.schema.output.is_eq_sort()); + uf.union_values(prev, out_val, self.schema.output.name()) } - prev - } - MergeFn::Expr(_) => { - if !appended && prev != out_val { - deferred_merges.push((scratch.clone(), prev, out_val)); + MergeFn::AssertEq => { + if prev != out_val { + result = Err(Error::MergeError(self.decl.name, prev, out_val)); + } + prev + } + MergeFn::Expr(_) => { + if !appended && prev != out_val { + deferred_merges.push((scratch.clone(), prev, out_val)); + } + prev } - prev } + } else { + out_val } - } else { - out_val - } - }); + }, + unextractable, + ); if let Some((inputs, _)) = self.nodes.get_index(i) { if inputs != &scratch[..] { scratch.clear(); diff --git a/src/function/table.rs b/src/function/table.rs index 05774eb6d..a0877334f 100644 --- a/src/function/table.rs +++ b/src/function/table.rs @@ -26,6 +26,7 @@ //! It's likely that we will have to store these "on the side" or use some sort //! of persistent data-structure for the entire table. use std::{ + fmt::{Debug, Formatter}, hash::{BuildHasher, Hash, Hasher}, mem, ops::Range, @@ -46,12 +47,29 @@ struct TableOffset { off: Offset, } +#[derive(Debug, Clone)] +pub(crate) struct Row { + pub(crate) input: Input, + pub(crate) output: TupleOutput, + pub(crate) unextractable: bool, +} + +impl Row { + fn new(input: Input, output: TupleOutput, unextractable: bool) -> Row { + Row { + input, + output, + unextractable, + } + } +} + #[derive(Default, Clone)] pub(crate) struct Table { max_ts: u32, n_stale: usize, table: RawTable, - pub(crate) vals: Vec<(Input, TupleOutput)>, + pub(crate) vals: Vec, } /// Used for the RawTable probe sequence. @@ -64,12 +82,22 @@ macro_rules! search_for { } // If the hash matches, the value should not be stale, and the data // should match. - let inp = &$slf.vals[to.off as usize].0; + let inp = &$slf.vals[to.off as usize].input; inp.live() && inp.data() == $inp } }; } +impl Debug for Table { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + f.debug_struct("Table") + .field("max_ts", &self.max_ts) + .field("n_stale", &self.n_stale) + .field("vals", &self.vals) + .finish() + } +} + impl Table { /// Clear the contents of the table. pub(crate) fn clear(&mut self) { @@ -89,7 +117,7 @@ impl Table { let mut src = 0usize; let mut dst = 0usize; self.table.clear(); - self.vals.retain(|(inp, _)| { + self.vals.retain(|Row { input: inp, .. }| { if inp.live() { let hash = hash_values(inp.data()); self.table @@ -108,20 +136,38 @@ impl Table { /// Get the entry in the table for the given values, if they are in the /// table. pub(crate) fn get(&self, inputs: &[Value]) -> Option<&TupleOutput> { + self.get_row(inputs).map(|row| &row.output) + } + + pub(crate) fn get_row(&self, inputs: &[Value]) -> Option<&Row> { let hash = hash_values(inputs); let TableOffset { off, .. } = self.table.get(hash, search_for!(self, hash, inputs))?; - debug_assert!(self.vals[*off].0.live()); - Some(&self.vals[*off].1) + debug_assert!(self.vals[*off].input.live()); + Some(&self.vals[*off]) + } + + pub(crate) fn mark_unextractable(&mut self, inputs: &[Value]) { + let hash = hash_values(inputs); + let TableOffset { off, .. } = self + .table + .get(hash, search_for!(self, hash, inputs)) + .unwrap(); + self.vals[*off].unextractable = true; } /// Insert the given data into the table at the given timestamp. Return the /// previous value, if there was one. pub(crate) fn insert(&mut self, inputs: &[Value], out: Value, ts: u32) -> Option { let mut res = None; - self.insert_and_merge(inputs, ts, |prev| { - res = prev; - out - }); + self.insert_and_merge( + inputs, + ts, + |prev| { + res = prev; + out + }, + false, + ); res } @@ -137,6 +183,7 @@ impl Table { inputs: &[Value], ts: u32, on_merge: impl FnOnce(Option) -> Value, + unextractable: bool, ) { assert!(ts >= self.max_ts); self.max_ts = ts; @@ -144,7 +191,11 @@ impl Table { if let Some(TableOffset { off, .. }) = self.table.get_mut(hash, search_for!(self, hash, inputs)) { - let (inp, prev) = &mut self.vals[*off]; + let Row { + input: inp, + output: prev, + .. + } = &mut self.vals[*off]; let next = on_merge(Some(prev.value)); if next == prev.value { return; @@ -153,23 +204,25 @@ impl Table { self.n_stale += 1; let k = mem::take(&mut inp.data); let new_offset = self.vals.len(); - self.vals.push(( + self.vals.push(Row::new( Input::new(k), TupleOutput { value: next, timestamp: ts, }, + unextractable, )); *off = new_offset; return; } let new_offset = self.vals.len(); - self.vals.push(( + self.vals.push(Row::new( Input::new(inputs.into()), TupleOutput { value: on_merge(None), timestamp: ts, }, + unextractable, )); self.table.insert( hash, @@ -198,7 +251,7 @@ impl Table { /// The minimum timestamp stored by the table, if there is one. pub(crate) fn min_ts(&self) -> Option { - Some(self.vals.first()?.1.timestamp) + Some(self.vals.first()?.output.timestamp) } /// An upper bound for all timestamps stored in the table. @@ -208,7 +261,7 @@ impl Table { /// Get the timestamp for the entry at index `i`. pub(crate) fn get_timestamp(&self, i: usize) -> Option { - Some(self.vals.get(i)?.1.timestamp) + Some(self.vals.get(i)?.output.timestamp) } /// Remove the given mapping from the table, returns whether an entry was @@ -221,20 +274,28 @@ impl Table { } else { return false; }; - self.vals[entry.off].0.stale_at = ts; + self.vals[entry.off].input.stale_at = ts; self.n_stale += 1; true } /// Returns the entries at the given index if the entry is live and the index in bounds. pub(crate) fn get_index(&self, i: usize) -> Option<(&[Value], &TupleOutput)> { - let (inp, out) = self.vals.get(i)?; + let Row { + input: inp, + output: out, + .. + } = self.get_index_row(i)?; if !inp.live() { return None; } Some((inp.data(), out)) } + pub(crate) fn get_index_row(&self, i: usize) -> Option<&Row> { + self.vals.get(i) + } + /// Iterate over the live entries in the table, in insertion order. pub(crate) fn iter(&self) -> impl Iterator + '_ { self.iter_range(0..self.num_offsets()) @@ -247,16 +308,22 @@ impl Table { &self, range: Range, ) -> impl Iterator + '_ { - self.vals[range.clone()] - .iter() - .zip(range) - .filter_map(|((inp, out), i)| { + self.vals[range.clone()].iter().zip(range).filter_map( + |( + Row { + input: inp, + output: out, + .. + }, + i, + )| { if inp.live() { Some((i, inp.data(), out)) } else { None } - }) + }, + ) } #[cfg(debug_assertions)] @@ -264,7 +331,7 @@ impl Table { assert!(self .vals .windows(2) - .all(|xs| xs[0].1.timestamp <= xs[1].1.timestamp)) + .all(|xs| xs[0].output.timestamp <= xs[1].output.timestamp)) } /// Iterate over the live entries in the timestamp range, passing back their diff --git a/src/serialize.rs b/src/serialize.rs index f5b8890a6..2f4c69961 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -3,7 +3,10 @@ use std::collections::VecDeque; use crate::{ ast::{FunctionDecl, Id}, - function::{table::hash_values, ValueVec}, + function::{ + table::{hash_values, Row}, + ValueVec, + }, util::HashMap, EGraph, Value, }; @@ -73,9 +76,9 @@ impl EGraph { .nodes .vals .iter() - .filter(|(i, _)| i.live()) + .filter(|Row { input, .. }| input.live()) .take(config.max_calls_per_function.unwrap_or(usize::MAX)) - .map(|(input, output)| { + .map(|Row { input, output, .. }| { ( &function.decl, &input.data,