Skip to content

Commit

Permalink
exp tdim caching
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Dec 9, 2024
1 parent 47076fe commit cd982d8
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 30 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ prost-types = "0.11.0"
py_literal = "0.4.0"
rand = { version = "0.8.4", features = ["small_rng"] }
rand_distr = "0.4"
rapidhash = "1.2"
rayon = "1.10"
readings-probe = "0.1.3"
regex = "1.5.4"
Expand Down
3 changes: 2 additions & 1 deletion data/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ keywords = [ "TensorFlow", "NeuralNetworks" ]
categories = [ "science" ]
autobenches = false
edition = "2021"
rust-version = "1.75"
rust-version = "1.77"

[badges]
maintenance = { status = "actively-developed" }
Expand All @@ -27,6 +27,7 @@ nom.workspace = true
num-complex = { workspace = true, optional = true }
num-integer.workspace = true
num-traits.workspace = true
rapidhash.workspace = true
smallvec.workspace = true
lazy_static.workspace = true
scan_fmt.workspace = true
Expand Down
62 changes: 33 additions & 29 deletions data/src/dim/sym.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub struct SymbolScopeData {
table: DefaultStringInterner,
assertions: Vec<Assertion>,
scenarios: Vec<(String, Vec<Assertion>)>,
cache_proven_positive_or_zero: RefCell<rapidhash::RapidHashMap<TDim, bool>>,
}

impl SymbolScope {
Expand Down Expand Up @@ -147,7 +148,7 @@ impl SymbolScope {
let locked = self.0.lock();
let locked = locked.borrow();
if locked.scenarios.len() == 0 {
return Ok(None)
return Ok(None);
}
let mut maybe = None;
for (ix, (_name, assertions)) in locked.scenarios.iter().enumerate() {
Expand All @@ -158,7 +159,7 @@ impl SymbolScope {
} else if maybe.is_none() {
maybe = Some(ix);
} else {
return Ok(None)
return Ok(None);
}
}
if maybe.is_some() {
Expand Down Expand Up @@ -201,38 +202,41 @@ impl SymbolScopeData {
if let TDim::Val(v) = t {
return *v >= 0;
}
let positives = self.assertions.iter().filter_map(|i| i.as_known_positive()).collect_vec();
let mut visited = vec![];
let mut todo = vec![t.clone()];
while let Some(t) = todo.pop() {
if t.to_i64().is_ok_and(|i| i >= 0) {
return true;
}
if t.inclusive_bound(self, false).is_some_and(|l| l >= 0) {
return true;
}
let syms = t.symbols();
for s in syms {
let me = t.guess_slope(&s);
for pos in &positives {
if pos.symbols().contains(&s) {
let other = pos.guess_slope(&s);
if me.0.signum() == other.0.signum() {
let new = t.clone() * me.1 * other.0.abs()
- pos.clone() * me.0.abs() * other.1;
if !visited.contains(&new) {
todo.push(new);
*self.cache_proven_positive_or_zero.borrow_mut().entry(t.clone()).or_insert_with(|| {
let positives =
self.assertions.iter().filter_map(|i| i.as_known_positive()).collect_vec();
let mut visited = vec![];
let mut todo = vec![t.clone()];
while let Some(t) = todo.pop() {
if t.to_i64().is_ok_and(|i| i >= 0) {
return true;
}
if t.inclusive_bound(self, false).is_some_and(|l| l >= 0) {
return true;
}
let syms = t.symbols();
for s in syms {
let me = t.guess_slope(&s);
for pos in &positives {
if pos.symbols().contains(&s) {
let other = pos.guess_slope(&s);
if me.0.signum() == other.0.signum() {
let new = t.clone() * me.1 * other.0.abs()
- pos.clone() * me.0.abs() * other.1;
if !visited.contains(&new) {
todo.push(new);
}
}
}
}
}
visited.push(t);
if visited.len() > 10 {
break;
}
}
visited.push(t);
if visited.len() > 10 {
break;
}
}
false
false
})
}
}

Expand Down

0 comments on commit cd982d8

Please sign in to comment.