Presentation as HTML

from IPython.display import YouTubeVideo

YouTubeVideo("Pbi2uV9vWPg")

EGraphs in Python

  • Overview of the ecosystem

  • What is an e-graph?

  • What is egglog?

  • What are some possible applications in the PyData world?


Saul Shanabrook - July 20, 2023

Open Source Data Science Ecosystem in Python

The term “ecosystem” is often used to describe the modern open-source scientific software. In biology, the term “ecosystem” is defined as a biological community of interacting organisms and their physical environment. Modern open-source scientific software development occurs in a similarly interconnected and interoperable fashion.

from Jupyter Meets the Earth: Ecosystem

Aims

  • How can the tools we build foster greater resiliency, collaboration, and interdependence in this ecosystem?

  • How can they help it stay flexible enough to adapt to the changing computational landscape to empower users and authors?

What role could egglog play?

  • Bring the programming languages community closer to this space, providing theoretical frameworks for thinking about composition and language.

  • A constrained type system could support decentralized interoperability and composition between data science libraries.

Other Python EGraph Libraries

TODO: Put this first; say it’s for library authors

Semantics of Python and egglog

  • Started with snake-egg

  • Didn’t want to reinvent the wheel; wanted to stay abreast of recent developments and research

  • Second piece that interests me

    • Unlike egg, egglog has some builtin sorts, and you can build user-defined sorts on top of them

    • No host-language conditions or data structures

    • More constrained, which helps with optimization

    • -> De-centers algorithms based on values and moves toward algorithms based on types. Everything becomes an interface.

    • Social dynamics: the goal is the ability to innovate and experiment while still supporting existing use cases

      • A new dataframe library comes out, supporting custom hardware. How do we use it without rewriting code?

      • How do we keep a healthy ecosystem within these tools? Power

      • If it’s too hard, that encourages centralized, monopolistic actors to step in and provide one-stop-shop solutions for users.

    • This is an active problem in the community, with efforts such as trying to standardize on interop.

    • Before getting too abstract, let’s go to an example

What is an e-graph?

E-graphs are these super wonderful data structures for managing equality and equivalence information. They are traditionally used inside of constraint solvers and automated theorem provers to implement congruence closure, an efficient algorithm for equational reasoning—but they can also be used to implement rewrite systems.

Talia Ringer - “Proof Automation” course

  • E-graphs come from the automated theorem proving world

  • They rely on congruence, which is a bit like how two triangles can match in shape and size without being the same triangle

  • Can be used for term rewriting

In abstract algebra, a congruence relation (or simply congruence) is an equivalence relation on an algebraic structure (such as a group, ring, or vector space) that is compatible with the structure in the sense that algebraic operations done with equivalent elements will yield equivalent elements.

Wikipedia - Congruence relation

  • Doing an operation on equivalent elements yields equivalent results

Congruence Closure

from __future__ import annotations
from egglog import *

egraph = EGraph()


@egraph.class_
class Structure(Expr):
    ...


a = egraph.constant("a", Structure)
b = egraph.constant("b", Structure)
c = egraph.constant("c", Structure)


@egraph.function
def operation(l: Structure, r: Structure) -> Structure:
    ...


a_b = egraph.let("a_b", operation(a, b))
a_c = egraph.let("a_c", operation(a, c))

egraph
[E-graph visualization: five e-classes, {a}, {b}, {c}, {operation(a, b), a_b}, {operation(a, c), a_c}]
  • Define a structure, define an operation on that structure

  • Define two elements, a_b and a_c, by applying the operation

  • Are they equal?

egraph.check(eq(a_b).to(a_c))
EggSmolError: Check failed: (= a_b a_c)
egraph.register(union(b).with_(c))
egraph.run(1)
egraph.check(eq(a_b).to(a_c))
egraph
[E-graph visualization: three e-classes, {a}, {b, c}, {operation(a, b), a_b, a_c}]

We just used egglog to

  • Define a structure

  • Define some operation on that structure

  • See it give us a congruence relation on that structure

    • “algebraic operations done with equivalent elements will yield equivalent elements”

    • We call a set of equivalent elements an “e-class”, e.g. {a} or {b, c}

    • Each e-node only stores pointers to its child e-classes, not to individual elements

  • Combining elements (with union) so that operations done on two elements of the same e-class yield elements of the same e-class
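To see the mechanics behind the congruence result above, here is a minimal pure-Python sketch, assuming a textbook design (hash-consed e-nodes over a union-find). ToyEGraph and its methods are hypothetical names for illustration, not egglog's API or internals. Unioning b and c triggers a rebuild, which notices that operation(a, b) and operation(a, c) now point at the same child e-classes and merges them.

```python
# Toy e-graph: hash-consed e-nodes plus a union-find over e-class ids.
# A teaching sketch only; egglog's real implementation is very different.


class ToyEGraph:
    def __init__(self):
        self.parent = {}  # union-find: e-class id -> parent e-class id
        self.memo = {}    # canonical e-node (op, child e-class ids) -> e-class id

    def find(self, x):
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def add(self, op, *children):
        # An e-node points at child e-classes, never at individual elements.
        node = (op, tuple(self.find(c) for c in children))
        if node not in self.memo:
            new_id = len(self.parent)
            self.parent[new_id] = new_id
            self.memo[node] = new_id
        return self.find(self.memo[node])

    def union(self, a, b):
        a, b = self.find(a), self.find(b)
        if a != b:
            self.parent[b] = a
            self.rebuild()

    def rebuild(self):
        # Re-canonicalize every e-node. Two e-nodes that now have the same op
        # and the same child e-classes are congruent, so merge their e-classes.
        changed = True
        while changed:
            changed = False
            new_memo = {}
            for (op, children), cls in self.memo.items():
                node = (op, tuple(self.find(c) for c in children))
                cls = self.find(cls)
                if node in new_memo and self.find(new_memo[node]) != cls:
                    self.parent[cls] = self.find(new_memo[node])
                    changed = True
                new_memo[node] = self.find(cls)
            self.memo = new_memo

    def equiv(self, a, b):
        return self.find(a) == self.find(b)


g = ToyEGraph()
a, b, c = g.add("a"), g.add("b"), g.add("c")
a_b = g.add("operation", a, b)
a_c = g.add("operation", a, c)
print(g.equiv(a_b, a_c))  # False
g.union(b, c)
print(g.equiv(a_b, a_c))  # True
```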

“Rewrite systems”?

@egraph.register
def _operation_commutative(left: Structure, right: Structure):
    yield rewrite(operation(left, right)).to(operation(right, left))


egraph.run(run().saturate())
egraph.check(eq(operation(a, b)).to(operation(b, a)))
egraph
[E-graph visualization: three e-classes, {a}, {b, c}, and one holding both operation nodes together with a_b and a_c]

egglog can also be used as a rewrite system:

  • Define rules which will be matched on the e-graph

  • The results will be added to the e-graph

  • Can keep running these rules until the e-graph is “saturated”, i.e. no rule adds anything new
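A rough sketch of what “running rules until saturation” means, under a big simplification: a single e-class is tracked as a plain Python set of equivalent terms, and the hypothetical commute rule is only matched at the root of each term. Real e-graphs share subterms between e-classes, but the stopping condition is the same idea: stop when a pass adds nothing new.

```python
# Toy sketch: one e-class tracked as a plain set of equivalent terms.
# Terms are nested tuples such as ("operation", "a", "b").


def commute(term):
    """operation(left, right) -> operation(right, left); None if no match."""
    if isinstance(term, tuple) and term[0] == "operation":
        _, left, right = term
        return ("operation", right, left)
    return None


def saturate(terms, rules, limit=100):
    """Apply every rule to every term until nothing new appears."""
    terms = set(terms)
    for iteration in range(limit):
        new = set()
        for term in terms:
            for rule in rules:  # rules only match at the root, enough here
                result = rule(term)
                if result is not None:
                    new.add(result)
        if new <= terms:  # no new information: the set is saturated
            return terms, iteration
        terms |= new  # applying a rule only adds terms, never removes them
    return terms, limit


equivalent, iterations = saturate({("operation", "a", "b")}, [commute])
print(sorted(equivalent))  # [('operation', 'a', 'b'), ('operation', 'b', 'a')]
print(iterations)  # 1
```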

@egraph.register
def _operation_absorbing(s: Structure):
    yield rewrite(operation(c, s)).to(c)


egraph.run(run().saturate())
egraph.extract(operation(c, a))
b
[E-graph visualization: two e-classes, {a} and one holding b, c, a_b, a_c, and both operation nodes]
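  • Define c as an absorbing (“zero”) element for this operation: operation(c, s) rewrites to c. Extracting operation(c, a) returns b, since b and c are in the same e-class

  • “extract” an expression to find the lowest-cost equivalent expression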
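Extraction can be sketched as a fixpoint over e-classes: keep lowering each e-class's best known cost until nothing changes, then rebuild the cheapest term. This is a hypothetical illustration assuming a cost of 1 per e-node (egglog's actual cost model may differ), with a hand-written e-graph that mirrors the one above, where b, c, and the operation nodes share one e-class.

```python
import math

# Hypothetical e-graph after the rules above: each e-class id maps to its
# e-nodes, written as (operator, child e-class ids).
classes = {
    "A": [("a", ())],
    "BC": [("b", ()), ("c", ()), ("operation", ("BC", "A")), ("operation", ("A", "BC"))],
}


def extract(classes, root):
    """Return (cost, term) for the cheapest term in e-class `root`.

    Assumes a simple cost model: every e-node costs 1 plus its children.
    """
    best = {cid: (math.inf, None) for cid in classes}
    changed = True
    while changed:  # iterate until no e-class gets a cheaper e-node
        changed = False
        for cid, nodes in classes.items():
            for op, children in nodes:
                cost = 1 + sum(best[child][0] for child in children)
                if cost < best[cid][0]:
                    best[cid] = (cost, (op, children))
                    changed = True

    def build(cid):
        op, children = best[cid][1]
        if not children:
            return op
        return f"{op}({', '.join(build(child) for child in children)})"

    return best[root][0], build(root)


print(extract(classes, "BC"))  # (1, 'b')
```

With this cost model b and c tie at cost 1, and the sketch keeps whichever it sees first.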

What are e-graphs?

  • Data structure to hold “congruence closure” i.e. sets of equivalent items

  • Can implement term rewriting system on them

    • The order in which rewrites are applied doesn’t matter, because applying a rule only adds information

    • Extract “best” expression from e-graph after applying rules

What is egglog?

  • We’ll show how some examples translate from the egglog language to the Python API

Sorts, expressions, and functions

%%egglog graph
(datatype Math
  (Num i64)
  (Var String)
  (Add Math Math)
  (Mul Math Math))

(define expr1 (Mul (Num 2) (Add (Var "x") (Num 3))))
(define expr2 (Add (Num 6) (Mul (Num 2) (Var "x"))))
[E-graph visualization: expr1 and expr2 in separate e-classes]
  • User defined sorts

  • Expressions

    • expr1 and expr2 are in their own e-classes, since we haven’t run any rules yet

  • The %%egglog magic lets you write egglog in a notebook and shows graphs and output inline.

egraph = EGraph()


@egraph.class_
class Num(Expr):
    @classmethod
    def var(cls, name: StringLike) -> Num:
        ...

    def __init__(self, value: i64Like) -> None:
        ...

    def __add__(self, other: Num) -> Num:
        ...

    def __mul__(self, other: Num) -> Num:
        ...


expr1 = egraph.let("expr1", Num(2) * (Num.var("x") + Num(3)))
expr2 = egraph.let("expr2", Num(6) + Num(2) * Num.var("x"))
egraph
[E-graph visualization: expr1 = Num(2) * (Num.var("x") + Num(3)) and expr2 = Num(6) + Num(2) * Num.var("x"), each in its own e-class]
  • Re-use existing Python classes and functions

    • So that both humans and computers understand the typing semantics

    • Humans can read __init__ and __add__ and know what they mean.

    • Static type checkers will flag Num("String") as an error.

    • Static type checking drives much of the API design of the library

  • Operator overloading supports infix operators

  • Function names are generated from the class names

    • The same operator on different types compiles to different functions with different signatures
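The node labels in the graphs above (for example Num.__add__ and Num.var) suggest that function names are built from the class name and the method name. A minimal sketch of such a scheme, assuming it works like Python's built-in __qualname__; the egg_name helper and the Matrix class are hypothetical, and this is not necessarily how egglog builds its names.

```python
from __future__ import annotations


class Num:
    def __add__(self, other: Num) -> Num: ...


class Matrix:
    def __add__(self, other: Matrix) -> Matrix: ...


# Hypothetical naming scheme: one function name per (class, method) pair,
# so the same + operator maps to a different function for each type.
def egg_name(method) -> str:
    return method.__qualname__


print(egg_name(Num.__add__))     # Num.__add__
print(egg_name(Matrix.__add__))  # Matrix.__add__
```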

Rewrite rules and checks

%%egglog graph continue
(rewrite (Add a b)
         (Add b a))
(rewrite (Mul a (Add b c))
         (Add (Mul a b) (Mul a c)))
(rewrite (Add (Num a) (Num b))
         (Num (+ a b)))
(rewrite (Mul (Num a) (Num b))
         (Num (* a b)))

(run 10)
(check (= expr1 expr2))
[E-graph visualization: expr1 and expr2 now in the same e-class]
  • expr1 and expr2 are now equivalent: they are in the same e-class

@egraph.register
def _(a: Num, b: Num, c: Num, i: i64, j: i64):
    yield rewrite(a + b).to(b + a)
    yield rewrite(a * (b + c)).to((a * b) + (a * c))
    yield rewrite(Num(i) + Num(j)).to(Num(i + j))
    yield rewrite(Num(i) * Num(j)).to(Num(i * j))


egraph.run(10)
egraph.check(eq(expr1).to(expr2))
egraph
[E-graph visualization: after the rewrites, expr1 and expr2 share an e-class]
  • Similar in Python: rewrite rules, run, check

  • Notice that all variables need types, unlike in egglog where they are inferred

    • Both so static type checkers can verify them

    • And so the runtime knows which methods are available on them

Extracting lowest cost expression

%%egglog continue output
(extract expr1)
Extracted with cost 8: (Mul (Num 2) (Add (Var "x") (Num 3)))
  • Extract lowest cost expr

egraph.extract(expr1)
(Num(2) * Num.var("x")) + Num(6)
  • Returns an expression object

  • Its string representation is Python syntax

Multipart Rules

%%egglog graph
(function fib (i64) i64)

(set (fib 0) 0)
(set (fib 1) 1)
(rule ((= f0 (fib x))
       (= f1 (fib (+ x 1))))
      ((set (fib (+ x 2)) (+ f0 f1))))

(run 7)
(check (= (fib 7) 13))
[E-graph visualization: the fib function table]
  • Rules that depend on facts and execute actions
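The fib rule can be read as a query plus an action. A naive fixpoint loop in plain Python shows the same behavior as the egglog cell above; run_fib is a hypothetical helper, and this is not how egglog evaluates rules internally.

```python
def run_fib(iterations):
    """Naive fixpoint loop for the fib rule (not egglog's evaluation strategy)."""
    fib = {0: 0, 1: 1}  # function table from (set (fib 0) 0) and (set (fib 1) 1)
    for _ in range(iterations):
        # Query: find every x where both (fib x) and (fib (+ x 1)) are defined,
        # matching against the table as it was at the start of this iteration.
        matches = [(x, f0, fib[x + 1]) for x, f0 in fib.items() if x + 1 in fib]
        # Action: (set (fib (+ x 2)) (+ f0 f1)) for every match.
        for x, f0, f1 in matches:
            fib[x + 2] = f0 + f1
    return fib


table = run_fib(7)
print(table[7])  # 13
```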

fib_egraph = EGraph()


@fib_egraph.function
def fib(x: i64Like) -> i64:
    ...


@fib_egraph.register
def _(f0: i64, f1: i64, x: i64):
    # Same base cases as the egglog version: fib(0) = 0, fib(1) = 1
    yield set_(fib(0)).to(i64(0))
    yield set_(fib(1)).to(i64(1))
    yield rule(
        eq(f0).to(fib(x)),
        eq(f1).to(fib(x + 1)),
    ).then(set_(fib(x + 2)).to(f0 + f1))


fib_egraph.run(7)
fib_egraph.check(eq(fib(7)).to(i64(13)))
  • set_ and eq are both type safe; this requires a builder syntax, e.g. set_(fib(0)).to(i64(0)) and eq(f0).to(fib(x))

Include & Modules

%%writefile path.egg
(relation path (i64 i64))
(relation edge (i64 i64))

(rule ((edge x y))
      ((path x y)))

(rule ((path x y) (edge y z))
      ((path x z)))
Overwriting path.egg
%%egglog
(include "path.egg")
(edge 1 2)
(edge 2 3)
(edge 3 4)
(run 3)
(check (path 1 3))
  • Include another file for reusability
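The two path rules compute a transitive closure. A hedged plain-Python sketch of running them for a fixed number of iterations (run_path_rules is a hypothetical helper, not part of egglog):

```python
def run_path_rules(edges, iterations):
    """Naive fixpoint for the two path rules, stopping early once saturated."""
    path = set()
    for _ in range(iterations):
        new = set(edges)  # (rule ((edge x y)) ((path x y)))
        # (rule ((path x y) (edge y z)) ((path x z)))
        new |= {(x, z) for (x, y) in path for (y2, z) in edges if y == y2}
        if new <= path:
            break  # nothing new: saturated
        path |= new
    return path


path = run_path_rules({(1, 2), (2, 3), (3, 4)}, iterations=3)
print((1, 3) in path)  # True
```

After three iterations the sketch has derived both (1, 3) and (1, 4).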

mod = Module()
path = mod.relation("path", i64, i64)
edge = mod.relation("edge", i64, i64)


@mod.register
def _(x: i64, y: i64, z: i64):
    yield rule(edge(x, y)).then(path(x, y))
    yield rule(path(x, y), edge(y, z)).then(path(x, z))
  • Modules work the same way in Python

  • A module supports defining rules and so on, but doesn’t actually run them; it just builds up a list of commands

egraph = EGraph([mod])
egraph.register(edge(1, 2), edge(2, 3), edge(3, 4))
egraph.run(3)
egraph.check(path(1, 3))
  • Then when we depend on them, it will run those commands first.

  • Allows distribution of code and others to re-use it, using existing Python import mechanisms.
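The deferred-commands idea behind modules can be sketched in a few lines. ToyModule and ToyEGraph are hypothetical classes that only record and replay command strings; they illustrate the pattern described above, not egglog's actual Module or EGraph.

```python
class ToyModule:
    """Records commands instead of running them."""

    def __init__(self, *deps):
        self.deps = list(deps)
        self.commands = []

    def register(self, *commands):
        self.commands.extend(commands)


class ToyEGraph(ToyModule):
    """Runs its dependencies' recorded commands first, then runs its own eagerly."""

    def __init__(self, *deps):
        super().__init__(*deps)
        self.executed = []
        seen = set()
        for dep in self.deps:
            self._replay(dep, seen)

    def _replay(self, module, seen):
        if id(module) in seen:  # a shared dependency is only replayed once
            return
        seen.add(id(module))
        for dep in module.deps:  # a module's own dependencies come first
            self._replay(dep, seen)
        for command in module.commands:
            self._execute(command)

    def _execute(self, command):
        self.executed.append(command)  # stand-in for sending it to the engine

    def register(self, *commands):
        for command in commands:
            self._execute(command)


base = ToyModule()
base.register("(relation edge (i64 i64))")
paths = ToyModule(base)
paths.register("(relation path (i64 i64))")
egraph = ToyEGraph(paths)
egraph.register("(edge 1 2)")
print(egraph.executed)
```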

Python Objects

egglog supports the builtin sorts i64, f64, String, Vec, Map, and Set.

But what if I want to store a Python object in my e-graph?

egraph = EGraph()
res = egraph.save_object(True)
# Saves reference to object and stores in e-graph as type and value hashes:
res
PyObject(277923772, 1)
egraph.load_object(res)
True

You can also evaluate arbitrary Python code:

empty_dict = egraph.save_object({})
locals_ = empty_dict.dict_update(egraph.save_object("x"), res)

egraph.load_object(egraph.extract(py_eval("not x", locals_, empty_dict)))
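Assuming py_eval(code, globals, locals) takes its arguments in the same order as Python's built-in eval, the cell above corresponds roughly to this plain-Python call, where locals_ (holding x = True) is passed as the globals mapping.

```python
empty_dict = {}
# Rough analogue of empty_dict.dict_update(egraph.save_object("x"), res)
locals_ = {**empty_dict, "x": True}
# "x" is not found in the (empty) locals mapping, so eval falls back to globals.
result = eval("not x", locals_, empty_dict)
print(result)  # False
```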

“Preserved” methods

…from egglog to Python…

egraph = EGraph()


@egraph.class_
class Bool(Expr):
    def to_py(self) -> PyObject:
        ...

    def __or__(self, other: Bool) -> Bool:
        ...

    # This will get executed eagerly
    @egraph.method(preserve=True)
    def __bool__(self) -> bool:
        print(self)
        egraph.register(self)
        egraph.run(run(limit=10).saturate())
        print(f"   -> {egraph.extract(self)}")
        return egraph.load_object(egraph.extract(self.to_py()))


TRUE = egraph.constant("TRUE", Bool)
FALSE = egraph.constant("FALSE", Bool)
bool(TRUE | FALSE)
TRUE | FALSE
   -> TRUE | FALSE
EggSmolError: Not found: fake expression Bool.to_py [Value { tag: "Bool", bits: 2 }]
@egraph.register
def _bool(x: Bool):
    return [
        set_(TRUE.to_py()).to(egraph.save_object(True)),
        set_(FALSE.to_py()).to(egraph.save_object(False)),
        rewrite(TRUE | x).to(TRUE),
        rewrite(FALSE | x).to(x),
    ]


bool(TRUE | FALSE)
TRUE | FALSE
   -> TRUE
True
x = egraph.constant("x", Bool)
if TRUE | x:
    print("it's true!")
TRUE | x
   -> TRUE
it's true!
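The preserved __bool__ works because Python calls __bool__ eagerly whenever an object is used in an if statement or passed to bool(). A toy pure-Python version of the same idea, assuming a hand-rolled Bool dataclass instead of egglog's; it simplifies with the two rewrite rules by direct recursion rather than by running an e-graph.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Bool:
    """Hypothetical lazy boolean expression: a constant, a variable, or an "or" node."""

    op: str  # "TRUE", "FALSE", a variable name, or "or"
    args: tuple[Bool, ...] = ()

    def __or__(self, other: Bool) -> Bool:
        return Bool("or", (self, other))  # build an expression, don't evaluate

    def simplify(self) -> Bool:
        # The two rewrite rules from above: TRUE | x -> TRUE, FALSE | x -> x
        if self.op == "or":
            left, right = (arg.simplify() for arg in self.args)
            if left == TRUE:
                return TRUE
            if left == FALSE:
                return right
            return Bool("or", (left, right))
        return self

    def __bool__(self) -> bool:
        # Called eagerly by `if` and `bool()`: simplify, then convert to Python.
        simplified = self.simplify()
        if simplified == TRUE:
            return True
        if simplified == FALSE:
            return False
        raise TypeError(f"cannot decide {simplified!r}")


TRUE, FALSE = Bool("TRUE"), Bool("FALSE")
x = Bool("x")
if TRUE | x:
    print("it's true!")
```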

Conversions

…from Python to egglog…

converter(bool, Bool, lambda x: TRUE if x else FALSE)

TRUE | False
TRUE | FALSE
  • Converters allow “upcasting” (here, from a Python bool to a Bool), which is very common in Python

  • Can get closer to mimicking regular Python APIs
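A converter can be thought of as a lookup table consulted when an operator receives a plain Python value. A minimal sketch, where CONVERTERS and upcast are hypothetical names and only the converter(bool, Bool, ...) call shape comes from the cell above.

```python
from __future__ import annotations

from typing import Any, Callable

# Hypothetical registry: (Python type, expression type) -> conversion function.
CONVERTERS: dict[tuple[type, type], Callable[[Any], Any]] = {}


def converter(from_type: type, to_type: type, fn: Callable[[Any], Any]) -> None:
    CONVERTERS[(from_type, to_type)] = fn


def upcast(value: Any, to_type: type) -> Any:
    """Return value as a to_type, converting it if a converter is registered."""
    if isinstance(value, to_type):
        return value
    try:
        fn = CONVERTERS[(type(value), to_type)]
    except KeyError:
        raise TypeError(f"cannot convert {type(value).__name__} to {to_type.__name__}") from None
    return fn(value)


class Bool:
    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return self.name

    def __or__(self, other: Any) -> Bool:
        other = upcast(other, Bool)  # upcast plain Python values before combining
        return Bool(f"{self.name} | {other.name}")


TRUE, FALSE = Bool("TRUE"), Bool("FALSE")
converter(bool, Bool, lambda x: TRUE if x else FALSE)
print(TRUE | False)  # TRUE | FALSE
```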

Fib Example

egraph = EGraph()


@egraph.class_
class Num(Expr):
    def __init__(self, i: i64Like) -> None:
        ...

    def __add__(self, other: Num) -> Num:
        ...


@egraph.function
def fib(x: i64Like) -> Num:
    ...


@egraph.register
def _fib(a: i64, b: i64, x: i64, f: Num):
    return [
        rewrite(Num(a) + Num(b)).to(Num(a + b)),
        rule(eq(f).to(fib(x)), x > 1).then(set_(fib(x)).to(fib(x - 1) + fib(x - 2))),
        set_(fib(0)).to(Num(0)),
        set_(fib(1)).to(Num(1)),
    ]


egraph
[E-graph visualization: fib(0) = Num(0) and fib(1) = Num(1)]
f4 = egraph.let("f4", fib(4))
egraph
[E-graph visualization: as above, plus f4 = fib(4)]
egraph.run(1)
egraph
[E-graph visualization: fib(4) = fib(3) + fib(2), with fib(2) and fib(3) added]
egraph.run(1)
egraph
[E-graph visualization: fib(2) and fib(3) are now also expressed as sums of earlier fib terms]
egraph.run(1)
egraph
[E-graph visualization: the fib(0) through fib(4) e-classes, connected by Num.__add__ nodes]
egraph.run(1)
egraph
[E-graph visualization: as above, plus Num(2) from constant folding]
egraph.run(1)
print(egraph.extract(f4))
egraph
Num(3)
[E-graph visualization: as above, plus Num(3), now in the same e-class as f4]
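The run(1) steps above can be modeled coarsely in plain Python. unfold_fib is a hypothetical helper that only tracks which fib(x) calls exist and which are already known numbers, so its round count is not meant to match egglog's iterations exactly.

```python
def unfold_fib(target, max_rounds=20):
    """Toy round-by-round model of the two fib rules (not a real e-graph)."""
    terms = {target}       # which fib(x) calls exist so far
    values = {0: 0, 1: 1}  # fib(x) calls already known to equal some Num(n)
    for round_number in range(1, max_rounds + 1):
        # Rule: fib(x), x > 1  ->  fib(x - 1) + fib(x - 2), which adds new fib calls
        new_terms = {y for x in terms if x > 1 for y in (x - 1, x - 2)}
        # Rule: Num(a) + Num(b) -> Num(a + b), once both summands are known numbers
        new_values = {
            x: values[x - 1] + values[x - 2]
            for x in terms
            if x > 1 and x - 1 in values and x - 2 in values
        }
        terms |= new_terms
        values.update(new_values)
        if target in values:
            return values[target], round_number
    raise RuntimeError(f"fib({target}) not resolved after {max_rounds} rounds")


value, rounds = unfold_fib(4)
print(value)  # 3
```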

A story about Arrays

  • This is one path through a huge maze of use cases.

  • Does not represent one killer example, but is an area I am familiar with from my previous work

1. Someone makes an NDArray library…

ndarray_mod = Module()
...
@ndarray_mod.class_
class Value(Expr):
    def __init__(self, v: i64Like) -> None:
        ...

    def __mul__(self, other: Value) -> Value:
        ...

    def __add__(self, other: Value) -> Value:
        ...


i, j = vars_("i j", i64)
ndarray_mod.register(
    rewrite(Value(i) * Value(j)).to(Value(i * j)),
    rewrite(Value(i) + Value(j)).to(Value(i + j)),
)


@ndarray_mod.class_
class Values(Expr):
    def __init__(self, v: Vec[Value]) -> None:
        ...

    def __getitem__(self, idx: Value) -> Value:
        ...

    def length(self) -> Value:
        ...

    def concat(self, other: Values) -> Values:
        ...


@ndarray_mod.register
def _values(vs: Vec[Value], other: Vec[Value]):
    yield rewrite(Values(vs)[Value(i)]).to(vs[i])
    yield rewrite(Values(vs).length()).to(Value(vs.length()))
    yield rewrite(Values(vs).concat(Values(other))).to(Values(vs.append(other)))


@ndarray_mod.class_
class NDArray(Expr):
    def __getitem__(self, idx: Values) -> Value:
        ...

    def shape(self) -> Values:
        ...


@ndarray_mod.function
def arange(n: Value) -> NDArray:
    ...
  • Basic

  • One function, arange, plus methods to get an array's shape and index into it

  • Very different from existing extension paradigms in Python (inheritance, multiple dispatch, dunder protocols).

    • An entirely open protocol.

    • Anyone else could define new ways to create arrays.

    • It is really about the mathematical definition; this comes from:

Restifo Mullin, Lenore Marie, “A mathematics of arrays” (1988). Electrical Engineering and Computer Science - Dissertations. 249.

@ndarray_mod.register
def _(n: Value, idx: Values, a: NDArray):
    yield rewrite(arange(n).shape()).to(Values(Vec(n)))
    yield rewrite(arange(n)[idx]).to(idx[Value(0)])
  • Rules to compute the shape of arange and to index into it.
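Here is a plain-Python model of what these rules compute. It is a naive bottom-up term rewriter over tuples. It is neither an e-graph nor egglog's API: egglog records that two terms are equal, but this sketch just replaces one with the other. The tuple encoding and the `step`/`simplify` names are hypothetical.

```python
# Plain-Python model of the NDArray module's rewrite rules (hypothetical encoding).
# Terms are tuples:
#   ("Value", int)                 ("Values", tuple_of_terms)
#   ("arange", n)                  ("shape", arr)      ("length", values)
#   ("getitem", container, index)  ("mul", a, b)  ("add", a, b)  ("concat", a, b)
# This is a naive rewriter, NOT an e-graph: each rule replaces a term
# instead of recording that the two terms are equal.


def step(term):
    """Apply the first matching rule at the root of `term`, else return it."""
    op = term[0]
    if op in ("mul", "add") and term[1][0] == "Value" and term[2][0] == "Value":
        i, j = term[1][1], term[2][1]
        return ("Value", i * j if op == "mul" else i + j)
    if op == "getitem" and term[1][0] == "Values" and term[2][0] == "Value":
        return term[1][1][term[2][1]]  # Values(vs)[Value(i)] -> vs[i]
    if op == "length" and term[1][0] == "Values":
        return ("Value", len(term[1][1]))  # Values(vs).length() -> Value(len(vs))
    if op == "concat" and term[1][0] == "Values" and term[2][0] == "Values":
        return ("Values", term[1][1] + term[2][1])  # concatenate the vectors
    if op == "shape" and term[1][0] == "arange":
        return ("Values", (term[1][1],))  # arange(n).shape() -> Values([n])
    if op == "getitem" and term[1][0] == "arange" and term[2][0] == "Values":
        return ("getitem", term[2], ("Value", 0))  # arange(n)[idx] -> idx[Value(0)]
    return term


def simplify(term):
    """Simplify children first, then rewrite at the root until nothing changes."""
    if term[0] == "Value":
        return term
    if term[0] == "Values":
        return ("Values", tuple(simplify(t) for t in term[1]))
    term = (term[0], *(simplify(t) for t in term[1:]))
    new = step(term)
    return term if new == term else simplify(new)


ten = ("arange", ("Value", 10))
print(simplify(("shape", ten)))                                  # ('Values', (('Value', 10),))
print(simplify(("getitem", ten, ("Values", (("Value", 7),)))))  # ('Value', 7)
```

The two results mirror what the next cells extract for `ten.shape()` and `ten[Values(Vec(Value(7)))]`.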

egraph = EGraph([ndarray_mod])
ten = egraph.let("ten", arange(Value(10)))
ten_shape = ten.shape()
egraph.register(ten_shape)

egraph.run(20)
egraph.display()
egraph.extract(ten_shape)
[Figure: e-graph with `ten` (`arange(Value(10))`), its `NDArray.shape`, and `Values.__init__` over a `Vec[Value]` holding `Value(10)`]
Values(Vec.empty().push(Value(10)))
ten_indexed = ten[Values(Vec(Value(7)))]
egraph.register(ten_indexed)

egraph.run(20)

egraph.display()
egraph.extract(ten_indexed)
[Figure: e-graph after indexing, with `NDArray.__getitem__` applied to `ten` and a `Values(Vec[Value])` holding `Value(7)`, plus the `Values.__getitem__` term with `Value(0)` produced by the `arange` rule]
Value(7)
  • Any user can try it now

2. Someone else decides to implement a cross product library

cross_mod = Module([ndarray_mod])


@cross_mod.function
def cross(l: NDArray, r: NDArray) -> NDArray:
    ...


@cross_mod.register
def _cross(l: NDArray, r: NDArray, idx: Values):
    yield rewrite(cross(l, r).shape()).to(l.shape().concat(r.shape()))
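    # Just noticed this is wrong! For an outer product, the index should be split
    # between l and r instead of being reused for both.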
    yield rewrite(cross(l, r)[idx]).to(l[idx] * r[idx])
  • Someone decides to add some functionality

  • A multiplicative cross product (an outer product)

  • The shape is the concatenation of the input shapes, and indexing multiplies each array's element at that index

  • A purely mathematical definition
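A small plain-Python model can check the intended semantics and show why the index rule above is flagged as wrong. It uses a hypothetical encoding in which an array is a shape plus an index-to-element function. It reads `cross` as an outer product: the first `len(l.shape)` index components go to `l` and the rest go to `r`.

```python
from dataclasses import dataclass
from typing import Callable, Tuple

Index = Tuple[int, ...]


@dataclass(frozen=True)
class Array:
    """Hypothetical array model: a shape plus a function from index to element."""
    shape: Index
    at: Callable[[Index], int]


def arange(n: int) -> Array:
    # arange(n)[(i,)] == i, consistent with the `arange(n)[idx] -> idx[0]` rule
    return Array((n,), lambda idx: idx[0])


def cross(l: Array, r: Array) -> Array:
    """Outer product: shapes concatenate and the index is split between l and r."""
    k = len(l.shape)
    return Array(l.shape + r.shape, lambda idx: l.at(idx[:k]) * r.at(idx[k:]))


def cross_same_index(l: Array, r: Array, idx: Index) -> int:
    # What the rule flagged as wrong computes: the full index reused on both sides.
    return l.at(idx) * r.at(idx)


c = cross(arange(10), arange(11))
print(c.shape)                                           # (10, 11)
print(c.at((2, 3)))                                      # 6
print(cross_same_index(arange(10), arange(11), (2, 3)))  # 4
```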

egraph = EGraph([cross_mod])
egraph.simplify(cross(arange(Value(10)), arange(Value(11))).shape(), 10)
Values(Vec.empty().push(Value(11)).push(Value(10)))

3. I write my wonderful data science application using it

def my_special_app(x: Value) -> Value:
    return cross(arange(x), arange(x))[Values(Vec(x))]


egraph = EGraph([cross_mod])

egraph.simplify(my_special_app(Value(10)), 10)
Value(100)
  • A different person installs the cross module

  • They implement an application using their complicated algorithm

… but it's too slow…

for i in range(100):
    egraph.simplify(my_special_app(Value(i)), 10)
  • Too slow in an inner loop

  • Is there a way we could optimize it?

4. Someone else writes a library for delayed execution

py_mod = Module([ndarray_mod])


@py_mod.function
def py_value(s: StringLike) -> Value:
    ...


...
  • Meanwhile, someone else builds on the original module to write a different execution semantics

  • It builds up a Python expression string instead of evaluating eagerly

@py_mod.register
def _py_value(l: String, r: String):
    yield rewrite(py_value(l) + py_value(r)).to(py_value(join(l, " + ", r)))
    yield rewrite(py_value(l) * py_value(r)).to(py_value(join(l, " * ", r)))


@py_mod.function
def py_values(s: StringLike) -> Values:
    ...


@py_mod.register
def _py_values(l: String, r: String):
    yield rewrite(py_values(l)[py_value(r)]).to(py_value(join(l, "[", r, "]")))
    yield rewrite(py_values(l).length()).to(py_value(join("len(", l, ")")))
    yield rewrite(py_values(l).concat(py_values(r))).to(py_values(join(l, " + ", r)))


@py_mod.function
def py_ndarray(s: StringLike) -> NDArray:
    ...


@py_mod.register
def _py_ndarray(l: String, r: String):
    yield rewrite(py_ndarray(l)[py_values(r)]).to(py_value(join(l, "[", r, "]")))
    yield rewrite(py_ndarray(l).shape()).to(py_values(join(l, ".shape")))
    yield rewrite(arange(py_value(l))).to(py_ndarray(join("np.arange(", l, ")")))
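Together, these rules replace eager evaluation with building Python source text. A plain-Python analogue makes this easier to see. The hypothetical `PyValue`, `PyValues`, and `PyNDArray` classes below overload the same operations as the rules and produce strings instead of numbers. They mirror the rules but are not egglog.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PyValue:
    """A scalar known only by the Python source text that computes it."""
    src: str

    def __add__(self, other: "PyValue") -> "PyValue":
        return PyValue(f"{self.src} + {other.src}")

    def __mul__(self, other: "PyValue") -> "PyValue":
        return PyValue(f"{self.src} * {other.src}")


@dataclass(frozen=True)
class PyValues:
    """A tuple of values known only by its source text."""
    src: str

    def __getitem__(self, idx: PyValue) -> PyValue:
        return PyValue(f"{self.src}[{idx.src}]")

    def length(self) -> PyValue:
        return PyValue(f"len({self.src})")

    def concat(self, other: "PyValues") -> "PyValues":
        return PyValues(f"{self.src} + {other.src}")


@dataclass(frozen=True)
class PyNDArray:
    """An array known only by its source text."""
    src: str

    def __getitem__(self, idx: PyValues) -> PyValue:
        return PyValue(f"{self.src}[{idx.src}]")

    def shape(self) -> PyValues:
        return PyValues(f"{self.src}.shape")


def py_arange(n: PyValue) -> PyNDArray:
    return PyNDArray(f"np.arange({n.src})")


x = PyValue("x")
print((x * x + x).src)                    # x * x + x
print(py_arange(x).shape().length().src)  # len(np.arange(x).shape)
```

Like the rules it mirrors, this sketch adds no parentheses. As a result, `(x + x) * x` produces the text `x + x * x`, which Python evaluates differently. A fuller version would need to parenthesize subexpressions.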

5. I can use it to JIT compile my application!

egraph = EGraph([cross_mod, py_mod])
egraph.simplify(my_special_app(py_value("x")), 10)
py_value("x * x")
  • I pull in the third-party library

  • Add it to my e-graph

  • Now I can compile lazily

  • py_mod never needed to know about the cross product, yet it works with it

… and I can add JIT compilation support for the other library I am using, without changing either library:

@egraph.register
def _(l: String, r: String):
    yield rewrite(cross(py_ndarray(l), py_ndarray(r))).to(
        py_ndarray(join("np.multiply.outer(", l, ", ", r, ")"))
    )
KeyError: "Callable ref FunctionRef(name='py_ndarray') not found"
egraph.run(20)
egraph.graphviz().render(outfile="big_graph.svg", format="svg")
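The talk does not show how the generated text is finally executed. One hypothetical way to "compile lazily" is to turn the extracted string, such as `x * x` above, into a Python function once. That function can then be reused in the slow loop from step 3. `compile_expr` is an illustrative helper built on Python's `eval`; it is not part of egglog.

```python
from typing import Callable


def compile_expr(src: str, arg: str = "x") -> Callable[[int], int]:
    """Turn a generated expression string into a one-argument Python function.

    eval executes arbitrary code, so this is only suitable for trusted,
    generated source text.
    """
    return eval(f"lambda {arg}: {src}", {})


# Source text extracted for my_special_app(py_value("x"))
app = compile_expr("x * x")

# The inner loop now calls a plain Python function instead of
# re-running the e-graph for every input.
results = [app(i) for i in range(100)]
print(results[10])  # 100
```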

Takeaways…

…from this totally realistic example.

  • The declarative nature of egglog could facilitate decentralized library collaboration and experimentation.

    • Focusing library authors on types rather than values encourages interoperability.

  • Pushing power down, empowering users and library authors

  • Could allow greater collaboration between the PL community and the Python data science library community

Arrays in the “Real World”

What would it take to make this example work with egglog?

from sklearn import datasets
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn import config_context

iris = datasets.load_iris()

X = iris.data
y = iris.target


def fit(X, y):
    with config_context(array_api_dispatch=True):
        lda = LinearDiscriminantAnalysis(n_components=2)
        X_r2 = lda.fit(X, y).transform(X)
        return X_r2


fit(X, y)[:5]
array([[ 8.06179978, -0.30042062],
       [ 7.12868772,  0.78666043],
       [ 7.48982797,  0.26538449],
       [ 6.81320057,  0.67063107],
       [ 8.13230933, -0.51446253]])

Could we execute this symbolically?

@egraph.class_
class NDArray(Expr):
    @classmethod
    def var(cls, name: StringLike) -> NDArray:
        ...

    ...
X_arr = NDArray.var("X")
y_arr = NDArray.var("y")
fit(X_arr, y_arr)

Started working on this yesterday…

Provide egglog with at least enough metadata about the inputs (dtypes, shapes, sizes) to get through sklearn’s sanity checks, which need to be executed eagerly:

egraph.register(
    rewrite(X_arr.dtype).to(convert(X.dtype, DType)),
    rewrite(y_arr.dtype).to(convert(y.dtype, DType)),
    rewrite(isfinite(sum(X_arr)).bool()).to(TRUE),
    rewrite(isfinite(sum(y_arr)).bool()).to(TRUE),
    rewrite(X_arr.shape).to(convert(X.shape, TupleInt)),
    rewrite(y_arr.shape).to(convert(y.shape, TupleInt)),
    rewrite(X_arr.size).to(Int(X.size)),
    rewrite(y_arr.size).to(Int(y.size)),
    rewrite(unique_values(y_arr).shape).to(TupleInt(Int(3)))
)
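These rules let an array stay symbolic while still answering eager metadata queries. Below is a minimal plain-Python sketch of that split. `SymbolicArray` is a hypothetical class, not the `NDArray` defined above. The shapes are those of the iris data (150 samples, 4 features), and the dtype strings are placeholders.

```python
from dataclasses import dataclass
from math import prod
from typing import Tuple


@dataclass(frozen=True)
class SymbolicArray:
    """Array whose elements are unknown but whose metadata is known eagerly."""
    name: str
    shape: Tuple[int, ...]
    dtype: str

    @property
    def size(self) -> int:
        return prod(self.shape)

    @property
    def ndim(self) -> int:
        return len(self.shape)


# Iris: 150 samples with 4 features; the dtype strings are placeholders.
X_sym = SymbolicArray("X", (150, 4), "float64")
y_sym = SymbolicArray("y", (150,), "int64")

# A consistency check in the spirit of sklearn's input validation can be
# answered from metadata alone, without any element values.
print(X_sym.shape[0] == y_sym.shape[0], X_sym.size, y_sym.ndim)  # True 600 1
```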

Define all the required Array API functions:

@egraph.function
def reshape(x: NDArray, shape: TupleInt, copy: OptionalBool = OptionalBool.none) -> NDArray:
    ...


@egraph.register
def _reshape(x: NDArray, y: NDArray, shape: TupleInt, copy: OptionalBool, i: Int, s: String):
    return [
        # dtype of result is same as input
        rewrite(
            reshape(x, shape, copy).dtype
        ).to(x.dtype),
        # dimensions of output are the same as length of shape
        rewrite(
            reshape(x, shape, copy).shape.length()
        ).to(shape.length()),
        # Reshaping to a single dimension (-1) gives a shape of (number of elements,)
        rewrite(
            reshape(x, TupleInt(Int(-1)), copy).shape
        ).to(TupleInt(x.size)),
        # Reshaping a 1-D array to (-1,) is a no-op
        rule(
            eq(y).to(reshape(x, TupleInt(Int(-1)), copy)),
            eq(x.shape).to(TupleInt(i)),
        ).then(
            union(x).with_(y)
        )
    ]
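Each of these rules encodes a fact about concrete reshapes. The hypothetical `reshape_shape` helper below computes the result shape for a known input shape, so the shape-related rules can be checked directly. As in NumPy and the Array API, `-1` means "infer this dimension". The helper covers only the shape facts, not the dtype rule or the array identity that `union` asserts.

```python
from math import prod
from typing import Tuple

Shape = Tuple[int, ...]


def reshape_shape(shape: Shape, new_shape: Shape) -> Shape:
    """Shape after reshaping an array of `shape` to `new_shape`.

    At most one entry of `new_shape` may be -1; it is inferred from the size.
    """
    size = prod(shape)
    if new_shape.count(-1) > 1:
        raise ValueError("can only infer one dimension")
    known = prod(d for d in new_shape if d != -1)
    if -1 in new_shape:
        if known == 0 or size % known:
            raise ValueError(f"cannot reshape size {size} into {new_shape}")
        new_shape = tuple(size // known if d == -1 else d for d in new_shape)
    if prod(new_shape) != size:
        raise ValueError(f"cannot reshape size {size} into {new_shape}")
    return new_shape


for s in [(150,), (150, 4), (2, 3, 5)]:
    # Reshaping to (-1,) gives a single dimension equal to the number of elements
    assert reshape_shape(s, (-1,)) == (prod(s),)
    # The output has as many dimensions as the requested shape has entries
    assert len(reshape_shape(s, (-1, 1))) == 2

# Reshaping a 1-D array to (-1,) leaves its shape unchanged
assert reshape_shape((150,), (-1,)) == (150,)
```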

Here are some examples of rewrites executing during sklearn’s checks:

asarray(reshape(asarray(NDArray.var("y")), (TupleInt(Int(-1)) + TupleInt.EMPTY))).shape[Int(0)] == asarray(NDArray.var("X")).shape[Int(0)]
  -> NDArray.var("y").size == NDArray.var("X").shape[Int(0)]
     -> TRUE

asarray(asarray(reshape(asarray(NDArray.var("y")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).ndim == Int(2)
  -> FALSE
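Both results follow from plain arithmetic on the metadata registered earlier (`X` is 150 by 4, `y` has 150 entries). The small check below reproduces them without egglog:

```python
from math import prod

X_shape, y_shape = (150, 4), (150,)  # iris shapes given to egglog above

# reshape(y, (-1,)) has shape (y.size,), so its first dimension is y.size
y_flat_shape = (prod(y_shape),)

first_check = y_flat_shape[0] == X_shape[0]  # y.size == X.shape[0]
second_check = len(y_flat_shape) == 2        # is the reshaped y 2-D?
print(first_check, second_check)             # True False
```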

That’s as far as I got!

Conclusion

  • e-graphs are a data structure we can use to build term rewriting systems

  • egglog is a language, and Python library, for building e-graphs

  • Looking forward to seeing how it might be used in the PyData ecosystem
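For readers who want to see the data structure itself, here is a deliberately minimal e-graph in plain Python. It has three parts:

- a union-find over e-class ids
- a hashcons mapping canonical e-nodes to e-classes
- a naive `rebuild` that restores congruence

`TinyEGraph` is a teaching sketch. It has none of egglog's indexing, rules, or extraction.

```python
from typing import Dict, List, Tuple

ENode = Tuple[str, Tuple[int, ...]]  # (operator, child e-class ids)


class TinyEGraph:
    def __init__(self) -> None:
        self.parent: List[int] = []           # union-find parent pointers
        self.hashcons: Dict[ENode, int] = {}  # canonical e-node -> e-class id

    def find(self, a: int) -> int:
        while self.parent[a] != a:
            self.parent[a] = self.parent[self.parent[a]]  # path halving
            a = self.parent[a]
        return a

    def canonicalize(self, node: ENode) -> ENode:
        op, children = node
        return (op, tuple(self.find(c) for c in children))

    def add(self, op: str, *children: int) -> int:
        node = self.canonicalize((op, children))
        if node in self.hashcons:
            return self.find(self.hashcons[node])
        eid = len(self.parent)
        self.parent.append(eid)
        self.hashcons[node] = eid
        return eid

    def union(self, a: int, b: int) -> int:
        a, b = self.find(a), self.find(b)
        if a != b:
            self.parent[b] = a
        return a

    def rebuild(self) -> None:
        """Re-canonicalize every e-node, merging classes that became congruent."""
        changed = True
        while changed:
            changed = False
            new_hashcons: Dict[ENode, int] = {}
            for node, eid in self.hashcons.items():
                node, eid = self.canonicalize(node), self.find(eid)
                if node in new_hashcons and self.find(new_hashcons[node]) != eid:
                    self.union(new_hashcons[node], eid)
                    changed = True
                new_hashcons[node] = self.find(eid)
            self.hashcons = new_hashcons


g = TinyEGraph()
a, b = g.add("a"), g.add("b")
fa, fb = g.add("f", a), g.add("f", b)
print(g.find(fa) == g.find(fb))  # False: f(a) and f(b) start in different e-classes
g.union(a, b)                    # assert a = b
g.rebuild()                      # congruence: f(a) = f(b) must now hold too
print(g.find(fa) == g.find(fb))  # True
```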

pip install egglog

New contributions, experiments, and conversations are welcome…

Come say hello at github.com/egraphs-good/egglog-python and egraphs.zulipchat.com!