Note
Go to the end to download the full example code.
Higher Order Functions#
EGraph()
from __future__ import annotations
from collections.abc import Callable
from egglog import *
class Math(Expr):
def __init__(self, i: i64Like) -> None: ...
def __add__(self, other: Math) -> Math: ...
class MathList(Expr):
def __init__(self) -> None: ...
def append(self, i: Math) -> MathList: ...
def map(self, f: Callable[[Math], Math]) -> MathList: ...
@ruleset
def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math]):
yield rewrite(Math(i) + Math(j)).to(Math(i + j))
yield rewrite(xs.append(x).map(f)).to(xs.map(f).append(f(x)))
yield rewrite(MathList().map(f)).to(MathList())
@function(ruleset=math_ruleset)
def incr_list(xs: MathList) -> MathList:
return xs.map(lambda x: x + Math(1))
egraph = EGraph()
y = egraph.let("y", incr_list(MathList().append(Math(1)).append(Math(2))))
egraph.run(math_ruleset.saturate())
egraph.check(eq(y).to(MathList().append(Math(2)).append(Math(3))))
egraph
Total running time of the script: (0 minutes 0.006 seconds)