03 - E-class Analysis
This tutorial is translated from egglog.
Datalog is a relational language for deductive reasoning. In the last lesson, we wrote our first equality saturation program in egglog, but you can also write rules for deductive reasoning à la Datalog. In this lesson, we will write several classic Datalog programs in egglog. One of the benefits of egglog being a language for program optimization is that it can talk about terms natively, so in egglog we get Datalog with terms for free.
In this lesson, we learn how to combine the power of equality saturation and Datalog. We will show how we can define program analyses using Datalog-style deductive reasoning, how EqSat-style rewrite rules can make the program analyses more accurate, and how accurate program analyses can enable more powerful rewrites.
Our first example will continue with the path example in lesson 2.
In this case, there is a path from e1 to e2 if e1 is less than or equal to e2.
# mypy: disable-error-code="empty-body"
from __future__ import annotations
from collections.abc import Iterable
from typing import TypeAlias
from egglog import *
class Num(Expr):
    # Expression sort used throughout this lesson: rational constants, named
    # variables, and arithmetic built from them.
    # In this example we use big 🐀 (`BigRat`) to represent numbers.
    # You can find a list of primitive types in the standard library in [`builtins.py`](https://github.com/egraphs-good/egglog-python/blob/main/python/egglog/builtins.py)

    # Constant wrapping a rational value.
    def __init__(self, value: BigRatLike) -> None: ...

    # Variable identified by `name`.
    @classmethod
    def var(cls, name: StringLike) -> Num: ...

    # Arithmetic constructors; the reflected variants let a plain string or
    # rational appear on the left-hand side of the operator.
    def __add__(self, other: NumLike) -> Num: ...
    def __radd__(self, other: NumLike) -> Num: ...
    def __mul__(self, other: NumLike) -> Num: ...
    def __rmul__(self, other: NumLike) -> Num: ...
    def __truediv__(self, other: NumLike) -> Num: ...

    # `a <= b` returns `Unit`, i.e. it is a relation: the fact is either
    # present in the e-graph or it is not.
    def __le__(self, other: NumLike) -> Unit: ...

    # Relation marking an expression as known to be non-zero; it is derived
    # from the lower/upper bound analysis by the rules later in this file.
    @property
    def non_zero(self) -> Unit: ...
# Anything accepted where a `Num` is expected: a `Num` itself, a string
# (treated as a variable name), or a rational literal.
NumLike: TypeAlias = Num | StringLike | BigRatLike
# Implicit conversions: a `BigRat` becomes a constant via `Num(...)` ...
converter(BigRat, Num, Num)
# ... and a `String` becomes a variable via `Num.var(...)`.
converter(String, Num, Num.var)
Let’s define some BigRat constants that will be useful later.
# The rationals 0/1, 1/1 and 2/1, shared by the rules and examples below.
zero, one, two = (BigRat(numerator, 1) for numerator in range(3))
We define a less-than-or-equal-to relation between two expressions.
a.__le__(b) means that a <= b for all possible values of variables.
We define rules to deduce the le relation.
egraph = EGraph()


@egraph.register
def _(
    e1: Num,
    e2: Num,
    e3: Num,
    n1: BigRat,
    n2: BigRat,
    x: String,
    e1a: Num,
    e1b: Num,
    e2a: Num,
    e2b: Num,
) -> Iterable[RewriteOrRule]:
    """Yield the deduction rules for the `<=` relation on `Num`."""
    # Transitivity: two chained facts give a third.
    yield rule(e1 <= e2, e2 <= e3).then(e1 <= e3)

    # Base case for `Num`: two constants compare like their underlying rationals.
    yield rule(e1 == Num(n1), e2 == Num(n2), n1 <= n2).then(e1 <= e2)

    # Base case for `Var`: every variable is `<=` itself.
    yield rule(e1 == Num.var(x)).then(e1 <= e1)  # noqa: PLR0124

    # Recursive case for `Add`: `<=` is preserved when both operands are bounded pairwise.
    left_sum = e1a + e1b
    right_sum = e2a + e2b
    yield rule(e1 == left_sum, e2 == right_sum, e1a <= e2a, e1b <= e2b).then(e1 <= e2)
Note that we have not defined any rules for multiplication. This would require a more complex analysis on the positivity of the expressions.
On the other hand, these rules by themselves are pretty weak. For example, they cannot deduce x + 1 <= 2 + x.
But EqSat-style axiomatic rules make these rules more powerful:
@egraph.register
def _(x: Num, y: Num, z: Num, a: BigRat, b: BigRat) -> Iterable[RewriteOrRule]:
    """Yield the algebraic axioms for `+` and `*`, plus constant folding."""
    # Associativity, registered in both directions.
    for right_nested, left_nested in (
        (x + (y + z), (x + y) + z),
        (x * (y * z), (x * y) * z),
    ):
        yield birewrite(right_nested).to(left_nested)

    # Commutativity.
    for original, swapped in ((x + y, y + x), (x * y, y * x)):
        yield rewrite(original).to(swapped)

    # Distributivity of `*` over `+`.
    yield rewrite(x * (y + z)).to((x * y) + (x * z))

    # Identity elements.
    yield rewrite(x + zero).to(x)
    yield rewrite(x * one).to(x)

    # Constant folding on the underlying rationals.
    yield rewrite(Num(a) + Num(b)).to(Num(a + b))
    yield rewrite(Num(a) * Num(b)).to(Num(a * b))
To check our rules
# expr1 = y + (2 + x); the bare string "x" goes through the String -> Num converter.
expr1 = egraph.let("expr1", Num.var("y") + (Num(two) + "x"))
# expr2 = x + y + 1 + 2
expr2 = egraph.let("expr2", Num.var("x") + Num.var("y") + Num(one) + Num(two))
# Before any rule has run, the `<=` fact is not in the e-graph yet ...
egraph.check_fail(expr1 <= expr2)
# ... and after running the registered rules to saturation it is.
egraph.run(run().saturate())
egraph.check(expr1 <= expr2)
# Bare expression: presumably displays the e-graph when this is the last line of a notebook cell.
egraph
A useful special case of the <= analysis is if an expression is upper bounded or lower bounded by certain numbers, i.e., interval analysis:
# Upper bound of `e`: when a second value is set for the same `e`,
# the merge keeps the smaller (tighter) of the two.
@function(merge=lambda current, candidate: current.min(candidate))
def upper_bound(e: Num) -> BigRat: ...


# Lower bound of `e`: when a second value is set for the same `e`,
# the merge keeps the larger (tighter) of the two.
@function(merge=lambda current, candidate: current.max(candidate))
def lower_bound(e: Num) -> BigRat: ...
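In the above functions, unlike <=, we define upper bound and lower bound as functions from
expressions to a unique number.
This is because we are always interested in the tightest upper bound
and lower bounds, so whenever a new bound is derived for an expression it is merged with the existing one: the merge keeps the minimum for upper bounds and the maximum for lower bounds. The following rules seed the bounds from <= facts that involve a constant: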
@egraph.register
def _(e: Num, n: BigRat) -> Iterable[RewriteOrRule]:
    """Seed the interval analysis from `<=` facts that involve a constant."""
    constant = Num(n)
    # e <= n  ==>  n is an upper bound of e
    yield rule(e <= constant).then(set_(upper_bound(e)).to(n))
    # n <= e  ==>  n is a lower bound of e
    yield rule(constant <= e).then(set_(lower_bound(e)).to(n))
We can define more specific rules for obtaining the upper and lower bounds of an expression based on the upper and lower bounds of its children.
@egraph.register
def _(
    e: Num,
    e1: Num,
    e2: Num,
    u1: BigRat,
    u2: BigRat,
    l1: BigRat,
    l2: BigRat,
) -> Iterable[RewriteOrRule]:
    """Propagate bounds from operands to their sums, products and squares."""
    # Addition: the bounds of a sum are the sums of the operands' bounds.
    total = e1 + e2
    yield rule(e == total, upper_bound(e1) == u1, upper_bound(e2) == u2).then(
        set_(upper_bound(e)).to(u1 + u2)
    )
    yield rule(e == total, lower_bound(e1) == l1, lower_bound(e2) == l2).then(
        set_(lower_bound(e)).to(l1 + l2)
    )

    # Multiplication: the bounds are the smallest and the largest of the four
    # products of the operands' endpoints.
    low_low, low_up, up_low, up_up = l1 * l2, l1 * u2, u1 * l2, u1 * u2
    yield rule(
        e == (e1 * e2),
        l1 == lower_bound(e1),
        l2 == lower_bound(e2),
        u1 == upper_bound(e1),
        u2 == upper_bound(e2),
    ).then(
        set_(lower_bound(e)).to(low_low.min(low_up.min(up_low.min(up_up)))),
        set_(upper_bound(e)).to(low_low.max(low_up.max(up_low.max(up_up)))),
    )

    # Squares: an expression multiplied by itself is never negative.
    yield rule(e == e1 * e1).then(set_(lower_bound(e)).to(zero))
The interval analysis is not only useful for numerical tools like Herbie, but it can also guard certain optimization rules, making EqSat-based rewriting more powerful!
For example, we are interested in non-zero expressions
@egraph.register
def _(e: Num, e2: Num) -> Iterable[RewriteOrRule]:
    """Derive `non_zero` from the bounds and use it to guard division rewrites."""
    # An expression strictly above or strictly below zero cannot be zero.
    strictly_positive = lower_bound(e) > zero
    strictly_negative = upper_bound(e) < zero
    for sign_fact in (strictly_positive, strictly_negative):
        yield rule(sign_fact).then(e.non_zero)

    # Cancelling a division is only applied when the divisor is known to be non-zero.
    yield rewrite(e / e).to(Num(one), e.non_zero)
    yield rewrite(e * (e2 / e)).to(e2, e.non_zero)
This non-zero analysis lets us optimize expressions that contain division safely. 2 * (x / (1 + 2 / 2)) is equivalent to x
# expr3 = 2 * (x / (1 + 2 / 2))
expr3 = egraph.let("expr3", Num(two) * (Num.var("x") / (Num(one) + (Num(two) / Num(two)))))
# expr4 = x
expr4 = egraph.let("expr4", Num.var("x"))
# The two expressions are not yet known to be equal ...
egraph.check_fail(expr3 == expr4)
# ... but after saturation the guarded division rewrites have unified them.
egraph.run(run().saturate())
egraph.check(expr3 == expr4)
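As another example, consider (x + 1)^2 + 2. It is a square plus a positive constant, so the interval analysis can derive a strictly positive lower bound for it; this marks it as non-zero, which lets us simplify the expression divided by itself to 1: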
# expr5 = (x + 1) * (x + 1) + 2
expr5 = egraph.let("expr5", (Num.var("x") + Num(one)) * (Num.var("x") + Num(one)) + Num(two))
# expr6 = expr5 / expr5
expr6 = egraph.let("expr6", expr5 / expr5)
egraph.run(run().saturate())
# Holds once expr5 has been marked non_zero, which enables the guarded `e / e -> 1` rewrite.
egraph.check(expr6 == Num(one))
Debugging tips!
function_size is used to return the size of a table and all_function_sizes to return the size of every table.
This is useful for debugging performance, by seeing how the table sizes evolve as the iteration count increases.
# Size of the `<=` table.
egraph.function_size(Num.__le__)
38
# Size of every table in the e-graph.
egraph.all_function_sizes()
[(· + ·, 2150),
(Num, 4),
(· <= ·, 38),
(· * ·, 168),
(· / ·, 3),
(·.non_zero, 28),
(Num.var, 2),
(lower_bound, 39),
(upper_bound, 4)]
function_values extracts every instance of a constructor, function, or relation in the e-graph.
It takes the maximum number of instances to extract as a second argument, so as not to spend time
printing millions of rows. function_values is particularly useful when debugging small e-graphs.
# Extract instances of the `<=` relation; the second argument (15) is the maximum number to extract.
list(egraph.function_values(Num.__le__, 15))
[Num(BigRat(BigInt.from_string("2"), BigInt.from_string("1"))) <= Num(BigRat(BigInt.from_string("2"), BigInt.from_string("1"))),
Num(BigRat(BigInt.from_string("1"), BigInt.from_string("1"))) <= Num(BigRat(BigInt.from_string("2"), BigInt.from_string("1"))),
Num(BigRat(BigInt.from_string("1"), BigInt.from_string("1"))) <= Num(BigRat(BigInt.from_string("1"), BigInt.from_string("1"))),
Num.var("y") <= Num.var("y"),
Num.var("x") <= Num.var("x"),
_Num_1 = Num(BigRat(BigInt.from_string("2"), BigInt.from_string("1"))) + Num.var("x")
_Num_1 <= _Num_1,
_Num_1 = Num.var("y") + Num(BigRat(BigInt.from_string("2"), BigInt.from_string("1")))
_Num_1 <= _Num_1,
Num.var("y") + Num(BigRat(BigInt.from_string("1"), BigInt.from_string("1"))) <= Num.var("y") + Num(BigRat(BigInt.from_string("2"), BigInt.from_string("1"))),
_Num_1 = Num.var("y") + Num(BigRat(BigInt.from_string("1"), BigInt.from_string("1")))
_Num_1 <= _Num_1,
Num.var("x") + Num.var("y") <= Num.var("x") + Num.var("y"),
Num(BigRat(BigInt.from_string("3"), BigInt.from_string("1"))) <= Num(BigRat(BigInt.from_string("3"), BigInt.from_string("1"))),
Num(BigRat(BigInt.from_string("2"), BigInt.from_string("1"))) <= Num(BigRat(BigInt.from_string("3"), BigInt.from_string("1"))),
Num(BigRat(BigInt.from_string("1"), BigInt.from_string("1"))) <= Num(BigRat(BigInt.from_string("3"), BigInt.from_string("1"))),
_Num_1 = Num.var("y") + (Num(BigRat(BigInt.from_string("2"), BigInt.from_string("1"))) + Num.var("x"))
_Num_1 <= _Num_1,
_Num_1 = Num.var("x") + Num.var("y") + Num(BigRat(BigInt.from_string("1"), BigInt.from_string("1")))
_Num_1 <= _Num_1]
extract_multiple can also be used to extract several different “variants” of the
first argument; the number of variants to extract is given as the second argument. This is useful when trying to figure out why one e-class is failing to be unioned with another.
# Extract 3 variants of expr3.
egraph.extract_multiple(expr3, 3)
[Num.var("x"),
Num.var("x") * Num(BigRat(BigInt.from_string("1"), BigInt.from_string("1"))),
Num(BigRat(BigInt.from_string("1"), BigInt.from_string("1"))) * Num.var("x")]