Note
Go to the end to download the full example code.
N-Dimensional Arrays
Example of building an NDArray in the vein of Mathematics of Arrays.
arange(Value(10)).shape() == Values(Vec(Value(10))) ➡ Values(Vec[Value](Value(10)))
arange(Value(10))[Values(Vec(Value(0)))] == Value(0) ➡ Value(0)
arange(Value(10))[Values(Vec(Value(1)))] == Value(1) ➡ Value(1)
py_ndarray("x").shape() == py_values("x.shape") ➡ py_values("x.shape")
arange(py_value("x"))[py_values("y")] == py_value("np.arange(x)[y]") ➡ py_value("np.arange(x)[y]")
cross(arange(Value(10)), arange(Value(11))).shape() == Values(Vec(Value(10), Value(11))) ➡ Values(Vec[Value](Value(10), Value(11)))
cross(py_ndarray("x"), py_ndarray("y")).shape() == py_values("x.shape + y.shape") ➡ py_values("x.shape + y.shape")
cross(py_ndarray("x"), py_ndarray("y"))[py_values("idx")] == py_value("x[idx] * y[idx]") ➡ py_value("x[idx] * y[idx]")
cross(py_ndarray("x"), py_ndarray("y"))[py_values("idx")] == py_value("np.multiply.outer(x, y)[idx]") ➡ py_value("x[idx] * y[idx]")
from __future__ import annotations
from egglog import *
# Single module-level e-graph: every rewrite rule in this script is registered
# on it and every assertion is run against it.
egraph = EGraph()
class Value(Expr):
    """
    A scalar value expression that supports `*` and `+` with other values.
    """

    # Build a value from an integer-like (`i64Like`) argument.
    def __init__(self, v: i64Like) -> None: ...

    # Product of two values, itself a `Value` expression.
    def __mul__(self, other: Value) -> Value: ...

    # Sum of two values, itself a `Value` expression.
    def __add__(self, other: Value) -> Value: ...
# Pattern variables ranging over i64; `i` is also used by the `Values`
# indexing rule registered further down.
i, j = vars_("i j", i64)
# Constant folding: arithmetic on two literal values is rewritten to a single
# literal value holding the i64 result.
egraph.register(
    rewrite(Value(i) * Value(j)).to(Value(i * j)),
    rewrite(Value(i) + Value(j)).to(Value(i + j)),
)
class Values(Expr):
    """
    A sequence of `Value` expressions (used in this file for shapes and indices).
    """

    # Build a sequence from a vector of values.
    def __init__(self, v: Vec[Value]) -> None: ...

    # Element at position `idx`, as a `Value`.
    def __getitem__(self, idx: Value) -> Value: ...

    # Number of elements, as a `Value`.
    def length(self) -> Value: ...

    # Sequence formed from this one followed by `other`.
    def concat(self, other: Values) -> Values: ...
@egraph.register
def _values(vs: Vec[Value], other: Vec[Value]):
    """
    Rewrite rules for `Values` built directly from vectors.

    Indexing, `length()` and `concat()` on vector-backed `Values` are pushed
    down onto the underlying `Vec` operations.
    """
    wrapped = Values(vs)
    yield from (
        # Indexing with a literal position reads straight out of the vector.
        rewrite(wrapped[Value(i)]).to(vs[i]),
        # The length is the vector's own length, re-wrapped as a Value.
        rewrite(wrapped.length()).to(Value(vs.length())),
        # Concatenating two vector-backed sequences joins the two vectors.
        rewrite(wrapped.concat(Values(other))).to(Values(vs.append(other))),
    )
    # TODO: rules for a symbolic `l.concat(r)` are not implemented yet, e.g.
    #   l.concat(r).length() -> l.length() + r.length()
    #   l.concat(r)[idx]     -> (no right-hand side drafted)
class NDArray(Expr):
    """
    An n-dimensional array.

    An array is described by two operations only: its shape and the value
    found at a given index.
    """

    # Value stored at the (multi-component) index `idx`.
    def __getitem__(self, idx: Values) -> Value: ...

    # Shape of the array, as a sequence of values.
    def shape(self) -> Values: ...
@function
def arange(n: Value) -> NDArray:
    """
    Array constructor taking a single `Value` argument `n`.
    """
    ...
@egraph.register
def _ndarray_arange(n: Value, idx: Values):
    """
    Rewrite rules giving `arange(n)` its shape and indexing behaviour.
    """
    rng = arange(n)
    # The shape is the one-element sequence containing `n`.
    shape_rule = rewrite(rng.shape()).to(Values(Vec(n)))
    # The element at `idx` is the first component of `idx` itself.
    index_rule = rewrite(rng[idx]).to(idx[Value(0)])
    yield shape_rule
    yield index_rule
def assert_simplifies(left: Expr, right: Expr) -> None:
    """
    Simplify `left`, print the result, and check that it equals `right`.

    `left` is registered in the shared e-graph, the e-graph is run for 30
    iterations, and the expression extracted for `left` is printed next to
    the expected `right`. Finally the e-graph is checked for `left == right`.
    """
    egraph.register(left)
    egraph.run(30)
    res = egraph.extract(left)
    # Report before checking so the extracted form is shown even if the
    # check below does not pass.
    report = f"{left} == {right} ➡ {res}"
    print(report)
    expected_equality = eq(left).to(right)
    egraph.check(expected_equality)
# Concrete checks for `arange`: its shape is the one-element sequence [10],
# and indexing it at [0] / [1] yields 0 / 1 respectively.
assert_simplifies(arange(Value(10)).shape(), Values(Vec(Value(10))))
assert_simplifies(arange(Value(10))[Values(Vec(Value(0)))], Value(0))
assert_simplifies(arange(Value(10))[Values(Vec(Value(1)))], Value(1))
@function
def py_value(s: StringLike) -> Value:
    """
    A `Value` described by the string `s` (Python source text in this file's usages).
    """
    ...
@egraph.register
def _py_value(l: String, r: String):
    """
    Rewrite rules that fold arithmetic on two `py_value`s into one `py_value`.

    The resulting string is the two operand strings joined by the operator.
    """
    lhs, rhs = py_value(l), py_value(r)
    # Addition first, then multiplication; each pairs the operator text with
    # the expression pattern it replaces.
    for op_text, pattern in ((" + ", lhs + rhs), (" * ", lhs * rhs)):
        yield rewrite(pattern).to(py_value(join(l, op_text, r)))
@function
def py_values(s: StringLike) -> Values:
    """
    A `Values` sequence described by the string `s` (Python source text in this file's usages).
    """
    ...
@egraph.register
def _py_values(l: String, r: String):
    """
    Rewrite rules turning operations on `py_values` into source strings.
    """
    seq = py_values(l)
    # seq[r]  ->  "l[r]"
    yield rewrite(seq[py_value(r)]).to(py_value(join(l, "[", r, "]")))
    # seq.length()  ->  "len(l)"
    yield rewrite(seq.length()).to(py_value(join("len(", l, ")")))
    # seq.concat(other)  ->  "l + r"
    yield rewrite(seq.concat(py_values(r))).to(py_values(join(l, " + ", r)))
@function
def py_ndarray(s: StringLike) -> NDArray:
    """
    An `NDArray` described by the string `s` (Python source text in this file's usages).
    """
    ...
@egraph.register
def _py_ndarray(l: String, r: String):
    """
    Rewrite rules turning operations on `py_ndarray` into source strings.
    """
    arr = py_ndarray(l)
    # arr[r]  ->  "l[r]"
    yield rewrite(arr[py_values(r)]).to(py_value(join(l, "[", r, "]")))
    # arr.shape()  ->  "l.shape"
    yield rewrite(arr.shape()).to(py_values(join(l, ".shape")))
    # arange over a py_value  ->  the array "np.arange(l)"
    arange_source = join("np.arange(", l, ")")
    yield rewrite(arange(py_value(l))).to(py_ndarray(arange_source))
# Symbolic checks: shape and indexing of string-described arrays come out as
# the corresponding source strings.
assert_simplifies(py_ndarray("x").shape(), py_values("x.shape"))
assert_simplifies(arange(py_value("x"))[py_values("y")], py_value("np.arange(x)[y]"))
# Disabled alternative expectation. NOTE(review): presumably this does not
# hold because the arange indexing rule produces `py_values("y")[Value(0)]`
# and no rule turns a literal `Value(0)` into a `py_value` -- confirm.
# assert_simplifies(arange(py_value("x"))[py_values("y")], py_value("y[0]"))
@function
def cross(l: NDArray, r: NDArray) -> NDArray:
    """
    Combine two arrays `l` and `r` into a new array.
    """
    ...
@egraph.register
def _cross(l: NDArray, r: NDArray, idx: Values):
    """
    Rewrite rules giving `cross(l, r)` its shape and indexing behaviour.
    """
    product = cross(l, r)
    # The shape is the operands' shapes concatenated.
    combined_shape = l.shape().concat(r.shape())
    yield rewrite(product.shape()).to(combined_shape)
    # Indexing multiplies the operands' elements, both looked up with the
    # same, complete `idx`.
    # NOTE(review): given the concatenated shape, one would expect `idx` to
    # be split between `l` and `r` -- confirm this simplification is intended.
    yield rewrite(product[idx]).to(l[idx] * r[idx])
# Checks for `cross`: the shape of two aranges is [10, 11]; for symbolic
# arrays the shape and the indexed element come out as source strings.
assert_simplifies(cross(arange(Value(10)), arange(Value(11))).shape(), Values(Vec(Value(10), Value(11))))
assert_simplifies(cross(py_ndarray("x"), py_ndarray("y")).shape(), py_values("x.shape + y.shape"))
assert_simplifies(cross(py_ndarray("x"), py_ndarray("y"))[py_values("idx")], py_value("x[idx] * y[idx]"))
@egraph.register
def _cross_py(l: String, r: String):
    """
    Rewrite `cross` of two string-described arrays into a single
    string-described array calling `np.multiply.outer`.
    """
    outer_source = join("np.multiply.outer(", l, ", ", r, ")")
    pattern = cross(py_ndarray(l), py_ndarray(r))
    yield rewrite(pattern).to(py_ndarray(outer_source))
# With the `np.multiply.outer` rule registered, the same indexed expression
# is also equal to indexing the outer-product source string.
assert_simplifies(cross(py_ndarray("x"), py_ndarray("y"))[py_values("idx")], py_value("np.multiply.outer(x, y)[idx]"))
Total running time of the script: (0 minutes 0.052 seconds)