User Defined Generics #386
Comments
To make this more specific, I came up with a "reasonable" example of a rewrite that might cause this kind of recursion:

# mypy: disable-error-code="empty-body"
from __future__ import annotations
from collections.abc import Callable
from egglog import *
type IntLike = i64Like | Int
class Int(Expr):
    def __init__(self, i: i64Like) -> None: ...
    def __add__(self, other: IntLike) -> Int: ...
    def __gt__(self, other: IntLike) -> Boolean: ...
    def __lt__(self, other: IntLike) -> Boolean: ...
    def max(self, other: IntLike) -> Int: ...

converter(i64, Int, Int)

class Boolean(Expr):
    def if_[T](self, t: T, f: T) -> T: ...

@function
def eq_fn[T](a: T, b: T) -> Boolean: ...

class Product[T, V](Expr):
    def __init__(self, left: T, right: V) -> None: ...
    @property
    def left(self) -> T: ...
    @property
    def right(self) -> V: ...
    def set_left(self, value: T) -> Product[T, V]: ...
    def set_right(self, value: V) -> Product[T, V]: ...
    def eq(self, other: Product[T, V], l_eq: Callable[[T, T], Boolean], r_eq: Callable[[V, V], Boolean]) -> Boolean: ...

class Option[T](Expr):
    @classmethod
    def none(cls) -> Option[T]: ...
    def __init__(self, v: T) -> None: ...
    def match[V](self, some: Callable[[T], V], none: V) -> V: ...
    def map[V](self, fn: Callable[[T], V]) -> Option[V]: ...

class List[T](Expr):
    def __init__(self) -> None: ...
    def __getitem__(self, index: Int) -> T: ...
    def find_index(self, fn: Callable[[T], Boolean]) -> Option[Int]: ...
    def fold[V](self, f: Callable[[T, V], V], v: V) -> V: ...
    def set(self, index: Int, value: T) -> List[T]: ...
    def append(self, next: T) -> List[T]: ...
    def most_common(self) -> Option[T]:
        """
        Returns the most common element in the list.
        If no items are in the list, returns None.
        If multiple items are tied for most common, returns the first one.
        """
        # 1. Build up a list of pairs of elements and their counts
        counts: List[Product[T, Int]] = self.fold(
            lambda x, acc: acc.find_index(lambda p: eq_fn(p.left, x)).match(
                # If we already have a pair for this element, increment the count
                lambda i: acc.set(i, acc[i].set_right(acc[i].right + 1)),
                # Otherwise, add a new pair with a count of 1
                acc.append(Product(x, Int(1))),
            ),
            List[Product[T, Int]](),
        )
        # 2. Find the highest count
        highest_count = counts.fold(
            lambda p, acc: Option(
                acc.match(
                    lambda h: (p.right > h.right).if_(p, h),
                    p,
                )
            ),
            Option[Product[T, Int]].none(),
        )
        return highest_count.map(lambda p: p.left)

def _most_common_definition[T](l: List[T], r: Option[T]):
    res = (
        l.fold(
            lambda x, acc: acc.find_index(lambda p: eq_fn(p.left, x)).match(
                lambda i: acc.set(i, acc[i].set_right(acc[i].right + 1)),
                acc.append(Product(x, Int(1))),
            ),
            List[Product[T, Int]](),
        )
        .fold(
            lambda p, acc: Option(
                acc.match(
                    lambda h: (p.right > h.right).if_(p, h),
                    p,
                )
            ),
            Option[Product[T, Int]].none(),
        )
        .map(lambda p: p.left)
    )
    # the definition of most_common will be turned into this rewrite:
    yield rewrite(l.most_common()).to(res)
    # which will in turn be turned into this rule
    yield rule(eq(r).to(l.most_common())).then(union(r).with_(res))

The way I understand the monomorphization approach, it would make multiple versions of the
So if I started with a type

Generally, if we are trying to analyze what types could be further created by applying rules, we can think of the facts of a rule as input types and the actions as output types. The only way new types can be added to the e-graph is in the actions, not in the facts.

So in this rule, in the facts, we have types
But now, in the actions of this function, we see we add a number of additional types, including

I am only saying all this to give a vaguely realistic example of how this infinite recursion could happen, to see if there is any advice on other ways to look at the problem that would allow accurate monomorphization of user-defined generics.

A workaround here would be to cap the recursion at a number of levels, say 3, so it would only generate up to three nested definitions or something. However, this would be fundamentally incomplete.
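As a rough illustration of the type-level analysis described above, here is a minimal pure-Python sketch (it does not use egglog). Types are nested tuples, type variables are plain strings, and the single rule is read off `most_common`: its facts mention `List[T]` and its actions mention, among others, `List[Product[T, Int]]`. The names `close_types` and `max_depth` are hypothetical, and the depth cap is the incomplete workaround mentioned above, not a real fix.

```python
from itertools import product

# A concrete type is a tuple (name, *args), e.g. ("List", ("Int",)) for List[Int].
# A pattern has the same shape but may contain strings such as "T" as type variables.
INT = ("Int",)


def match(pattern, ty, env):
    """Match a pattern against a concrete type; return extended bindings or None."""
    if isinstance(pattern, str):
        if pattern in env:
            return env if env[pattern] == ty else None
        return {**env, pattern: ty}
    if pattern[0] != ty[0] or len(pattern) != len(ty):
        return None
    for p_arg, t_arg in zip(pattern[1:], ty[1:]):
        env = match(p_arg, t_arg, env)
        if env is None:
            return None
    return env


def subst(pattern, env):
    """Replace the type variables in a pattern with their bindings."""
    if isinstance(pattern, str):
        return env[pattern]
    return (pattern[0], *(subst(arg, env) for arg in pattern[1:]))


def depth(ty):
    """Nesting depth: Int is 1, List[Int] is 2, List[Product[Int, Int]] is 3."""
    return 1 + max((depth(arg) for arg in ty[1:]), default=0)


def close_types(seeds, rules, max_depth):
    """Apply (fact patterns -> action patterns) rules until nothing new fits under the cap."""
    known = set(seeds)
    changed = True
    while changed:
        changed = False
        for facts, actions in rules:
            for combo in product(tuple(known), repeat=len(facts)):
                env = {}
                for pattern, ty in zip(facts, combo):
                    env = match(pattern, ty, env)
                    if env is None:
                        break
                if env is None:
                    continue
                for out in actions:
                    new = subst(out, env)
                    if depth(new) <= max_depth and new not in known:
                        known.add(new)
                        changed = True
    return known


# Types read off the most_common rule (an assumption about which ones matter).
PAIR = ("Product", "T", INT)
RULE = ([("List", "T")], [("Option", "T"), PAIR, ("List", PAIR), ("Option", PAIR)])

capped = close_types({("List", INT)}, [RULE], max_depth=3)
```

Without the `depth(new) <= max_depth` check the loop would never terminate for this rule, since every `List[X]` it finds yields a new `List[Product[X, Int]]`. Raising the cap only grows the set; it never reaches a natural fixpoint.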
Chatting with @ezrosent, I rewrote this example in egglog to show how generic types could work there:

;; changes to egglog
;; 1. Allow unbound types in function definitions, these will be generic params
;; 2. Allow a non zero number when defining a sort to indicate the arity of the sort
;; 3. Require sorts with non zero arity to be parameterized, i.e. (List i64)
;; 4. Allow collection primitives to be parameterized inline, i.e. (UnstableFn (i64) i64) instead of requiring them be named
;; 5. Desugar anonymous functions to named functions with a rewrite, i.e. desugar
;; (lambda (x) (replace x " " "")) to
;; (function __tmp_name (String) String)
;; (rewrite (__tmp_name x) (replace x " " ""))
(sort Boolean)
(function if (Boolean T T) T)
(sort Int)
(function IntInit (i64) Int)
(function IntAdd (Int Int) Int)
(function IntGt (Int Int) Boolean)
(function IntLt (Int Int) Boolean)
(function IntMax (Int Int) Int)
(function eq (T T) Boolean)
(sort Product 2)
(function ProductInit (T V) (Product T V))
(function ProductLeft (Product T V) T)
(function ProductRight (Product T V) V)
(function ProductSetLeft (Product T V T) (Product T V))
(function ProductSetRight (Product T V V) (Product T V))
(sort Option 1)
(function OptionNone () (Option T))
(function OptionSome (T) (Option T))
(function OptionMatch ((Option T) (UnstableFn (T) V) V) V)
(function OptionMap ((Option T) (UnstableFn (T) V)) (Option V))
(sort List 1)
(function ListInit () (List T))
(function ListGet ((List T) Int) T)
(function ListFindIndex ((List T) (UnstableFn (T) Boolean)) (Option Int))
(function ListFold ((List T) (UnstableFn (T V) V) V) V)
(function ListSet ((List T) Int T) (List T))
(function ListAppend ((List T) T) (List T))
(function ListMostCommon ((List T)) (Option T))
(rewrite
  (ListMostCommon l)
  (OptionMap
    ;; Find the highest count
    (ListFold
      ;; Build up a list of pairs of elements and their counts
      (ListFold
        l
        (lambda (x acc)
          (OptionMatch
            (ListFindIndex
              acc
              (lambda (p) (eq (ProductLeft p) x))
            )
            ;; If we already have a pair for this element, increment the count
            (lambda (i)
              (ListSet acc i (ProductSetRight (ListGet acc i) (IntAdd (ProductRight (ListGet acc i)) (IntInit 1))))
            )
            ;; Otherwise, add a new pair with a count of 1
            (ListAppend acc (ProductInit x (IntInit 1)))
          )
        )
        (ListInit)
      )
      (lambda (p acc)
        (OptionSome
          (OptionMatch
            acc
            (lambda (h)
              (if (IntGt (ProductRight p) (ProductRight h))
                p
                h
              )
            )
            p
          )
        )
      )
      (OptionNone)
    )
    (lambda (p) (ProductLeft p))
  )
)

In this case, if I add a
In the actions of the rule we have types (at least):
So if we start with types

My assumption here is that we are just analyzing every rule at the type level, not looking at any of the actual contents of the rules.

EDIT: Shorter Example

Here is a shorter example that isn't "realistic" but is much easier to read:

(sort A 1)
(function CreateA (T) (A T))
(function CreateA2 (T) (A T))
(rule
  ((CreateA x))
  ((CreateA2 (CreateA2 x)))
)
(CreateA 1)
;; If we look at this program, we start with types:
;; (A i64)
;; We see that the rule has (A T) in the facts and (A (A T)) in the body
;; so then we match on (A i64) and create (A (A i64)) and the rule for it:
(sort A_Int)
(function CreateA_Int (Int) (A_Int Int))
(function CreateA_Int2 (Int) (A_Int Int))
(rule
  ((CreateA_Int x))
  ((CreateA_Int2 (CreateA_Int2 x)))
)
;; However, now we recurse, and since we don't analyze the rule at the function level, only the type level,
;; we don't know that we only need to recurse once, so we would make A_Int_Int
(sort A_Int_Int)
(function CreateA_Int_Int (A_Int) (A_Int_Int A_Int))
(function CreateA_Int_Int2 (A_Int) (A_Int_Int A_Int))
(rule
  ((CreateA_Int_Int x))
  ((CreateA_Int_Int2 (CreateA_Int_Int2 x)))
)
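To make the difference between the two levels concrete, here is a small pure-Python sketch of the shorter example (it does not use egglog). `type_level` mimics the analysis that only looks at types, and `function_level` tracks which function instances actually exist. Both helper names and the `RULES` encoding are hypothetical, chosen only for illustration.

```python
I64 = ("i64",)


def type_level(seed, steps):
    """Type-level view: every A[T] in the facts yields A[A[T]] in the actions."""
    known = [seed]
    for _ in range(steps):
        # Nothing here ever says "stop": each new type matches the rule again.
        known.append(("A", known[-1]))
    return known


# Function-level view of the same rule: the facts mention CreateA at some T,
# and the actions create CreateA2 at T (inner call) and at A[T] (outer call).
RULES = {
    "CreateA": [
        lambda t: ("CreateA2", t),
        lambda t: ("CreateA2", ("A", t)),
    ],
}


def function_level(seed_calls, rules):
    """Worklist over (function name, binding of T) pairs that actually occur."""
    known = set(seed_calls)
    worklist = list(seed_calls)
    while worklist:
        fn, binding = worklist.pop()
        # Only rules whose facts mention this function can fire.
        for make_instance in rules.get(fn, []):
            instance = make_instance(binding)
            if instance not in known:
                known.add(instance)
                worklist.append(instance)
    return known


nested = type_level(("A", I64), steps=3)
instances = function_level([("CreateA", I64)], RULES)
```

Here `function_level` stops after one round because no rule has `CreateA2` in its facts, while `type_level` only stops because of the explicit step count. This does not show that a function-level analysis always terminates; a rule that re-creates its own fact function at a larger type would still diverge.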
I like the analogy to monomorphization upthread, but I don't think it makes sense to treat monomorphization for egglog as instantiating a function or rule for every type present in the program. In a language like Rust, you'd only create a function like

What's weird with egglog is that functions aren't explicitly called: they get called "automatically" based on the contents of the database. But I still think the rules for instantiation should be at the 'function-level', not the type-level, as Saul's pointing out here.

We talked offline about having a sort of fixpoint computation of which types are needed. One way to think about doing that is to rewrite the rules so that all the values are erased but the types show up explicitly, e.g.:
I think for any seed values from the database, these rules will saturate and we can read off the contents of the "type-level"

One thing that we'd need to do here is to take steps to guarantee saturation. A few ideas:
I think under this instantiation-as-you-go model, the current design around rule sets and schedules does not play well with constructors like
@yihozhang I would disallow things like
I think if we want to allow the semantics of that rule but with a particular type in mind, we have to allow explicit typing in the frontend. In the Python bindings, this is required (i.e. I wouldn't allow you to do
So we would want to add a way in the syntax to explicitly parameterize the types. I could add this to the examples if it's helpful...

@ezrosent I can see the rough outline of what you mean by doing this at a per-function level, which would allow more precision than a type-level analysis. But as you say, we could still end up in a position where it's not entirely precise. Moreover, it seems like a lot of complexity to add. It makes me more curious about exploring a non-monomorphization approach, where rules that have type vars are preserved as such into the query compiler, and the types are only parameterized when the query is actually matched.
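One way to picture the non-monomorphization idea is that a generic signature stays generic, and its type variables are only bound when concrete argument types show up at match time. The following pure-Python sketch illustrates that binding step in isolation. It says nothing about how the egglog query compiler works; `SIGNATURES`, `bind`, and `result_type` are hypothetical names, and the two signatures are copied from the `ListGet` and `ProductInit` declarations earlier in the thread.

```python
# A concrete type is a tuple (name, *args); a type variable is a string such as "T".
INT = ("Int",)

# Generic signatures kept as-is: (parameter patterns, return pattern).
SIGNATURES = {
    "ListGet": ((("List", "T"), INT), "T"),
    "ProductInit": (("T", "V"), ("Product", "T", "V")),
}


def bind(pattern, ty, env):
    """Bind type variables in `pattern` against concrete `ty`, mutating `env`."""
    if isinstance(pattern, str):
        # First occurrence binds the variable; later ones must agree.
        return env.setdefault(pattern, ty) == ty
    return (
        pattern[0] == ty[0]
        and len(pattern) == len(ty)
        and all(bind(p, t, env) for p, t in zip(pattern[1:], ty[1:]))
    )


def substitute(pattern, env):
    """Replace bound type variables in a pattern."""
    if isinstance(pattern, str):
        return env[pattern]
    return (pattern[0], *(substitute(arg, env) for arg in pattern[1:]))


def result_type(fn, arg_types):
    """Resolve the return type of a generic function for one concrete match."""
    params, ret = SIGNATURES[fn]
    if len(params) != len(arg_types):
        raise TypeError(f"{fn} expects {len(params)} arguments")
    env = {}
    for pattern, ty in zip(params, arg_types):
        if not bind(pattern, ty, env):
            raise TypeError(f"{fn}: {ty} does not match {pattern}")
    return substitute(ret, env)


pair = ("Product", INT, INT)
element = result_type("ListGet", [("List", pair), INT])
made = result_type("ProductInit", [INT, ("List", INT)])
```

Under this scheme nothing is instantiated ahead of time, so the "when do we stop generating sorts" question does not arise; the cost moves to doing this binding work for every match.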
I am opening up this issue to discuss how to support user-defined generics within egglog.
Motivation
Now that first-class functions are working in the egglog Python bindings and I am starting to use them, it's becoming harder to avoid having user-defined generics. For example, here is a somewhat minimal tuple-of-ints class:
This code has to be repeated for every other tuple (tuple of booleans, tuple of ndarrays, etc).
Not only that, you can see I have two repeated fold functions, one for folding a tuple of ints to ints, and another to bools. The burden is high enough that it makes a whole host of functional-programming-like tasks quite a chore instead of being relatively smooth.

Without user-defined generics, it is also much less attractive to provide any shareable stdlib-like module with all of these collections, since they are much less general purpose.
A generic Tuple class could look something like this:
Approach
I understand it would be a rather large lift to add them to egglog itself. I started a rough design in PR #299. A suggestion there was to try doing all of this as desugaring in the Python package instead.
That approach would work something like creating a duplicate actual sort for each instantiation of a generic type and also duplicating all rewrite rules.
I am still a bit unclear on how the details would work though. I assume when we add rules and then run rulesets, we would have to make many instances of each generic type and every rewrite rule that uses it. But it seems hard to know when to stop generating new rules.
For example, if there is a rule with facts that depend on Generic1[A] and Generic1[B], and actions that give Generic2[A, B], do we parametrize A and B with every single sort that is registered? How would we limit the infinite recursion then (i.e. Generic1[Generic1[int]], etc.)?

If anyone has a sound way of thinking about it, that would be much appreciated.
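For a sense of what the desugaring step itself might look like, here is a small pure-Python sketch that stamps out one monomorphic copy of a generic sort as egglog-style declaration text. It is only an assumption about the shape of such a desugaring, not how the Python package does or would do it. `mangle` and `monomorphize` are hypothetical names, and the emitted text follows the declaration style used elsewhere in this thread.

```python
# A concrete type is a tuple (name, *args); a type variable is a string such as "T".
I64 = ("i64",)


def mangle(ty):
    """Flatten a concrete type into one identifier: Option[i64] -> Option_i64."""
    return "_".join([ty[0], *(mangle(arg) for arg in ty[1:])])


def substitute(pattern, env):
    """Replace type variables in a pattern with their bindings."""
    if isinstance(pattern, str):
        return env[pattern]
    return (pattern[0], *(substitute(arg, env) for arg in pattern[1:]))


def monomorphize(sort_name, type_params, functions, args):
    """Emit declarations for one instantiation of a generic sort and its functions."""
    env = dict(zip(type_params, args))
    suffix = "_".join(mangle(arg) for arg in args)
    lines = [f"(sort {mangle((sort_name, *args))})"]
    for fn_name, params, ret in functions:
        param_sorts = " ".join(mangle(substitute(p, env)) for p in params)
        ret_sort = mangle(substitute(ret, env))
        lines.append(f"(function {fn_name}_{suffix} ({param_sorts}) {ret_sort})")
    return lines


# Only the Option functions whose sole type variable is the sort's own parameter.
OPTION_FUNCTIONS = [
    ("OptionNone", [], ("Option", "T")),
    ("OptionSome", ["T"], ("Option", "T")),
]

flat = monomorphize("Option", ["T"], OPTION_FUNCTIONS, [I64])
nested = monomorphize("Option", ["T"], OPTION_FUNCTIONS, [("Option", I64)])
```

The sketch deliberately leaves out the hard part raised above: deciding which `args` to call it with. It also ignores functions with their own extra type variables (such as the `V` in a match or map), rewrite rules, and the fact that underscore mangling can collide for sorts with several parameters.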