# mypy: disable-error-code="empty-body"
"""
Multiset example based off of egglog version
============================================
"""
'\nMultiset example based off of egglog version\n============================================\n'
from __future__ import annotations
from egglog import *
# Expression sort for the arithmetic terms being rewritten. Method bodies are
# left as ``...``: these are declarations only (hence the file-level mypy
# ``empty-body`` suppression at the top of the file).
class Math(Expr):
    # Integer literal, e.g. Math(3).
    def __init__(self, x: i64Like) -> None: ...
    def __add__(self, other: MathLike) -> Math: ...
    # Reflected variants let a plain int appear on the left, as in ``6 + 2 * x``.
    def __radd__(self, other: MathLike) -> Math: ...
    def __mul__(self, other: MathLike) -> Math: ...
    def __rmul__(self, other: MathLike) -> Math: ...
# Anything accepted where a Math is expected: a Math term or an i64-like value.
MathLike = Math | i64Like
# Register a conversion from i64 to Math using the Math constructor, so i64
# values can be passed in MathLike positions. Plain ints such as the ``2`` in
# ``2 * (x + 3)`` presumably reach Math via egglog's own int -> i64 conversion.
converter(i64, Math, Math)
# Sum of all elements of a multiset of terms; the multiset argument means the
# order of the summands is irrelevant.
# NOTE: this name shadows Python's builtin ``sum`` within this module.
@function
def sum(xs: MultiSetLike[Math, MathLike]) -> Math: ...
# Product of all elements of a multiset of terms; the multiset argument means
# the order of the factors is irrelevant.
@function
def product(xs: MultiSetLike[Math, MathLike]) -> Math: ...
# Declared square operation.
# NOTE(review): not referenced by any rule or expression visible in this file;
# possibly a leftover — confirm before removing.
@function
def square(x: Math) -> Math: ...
x = constant("x", Math)
expr1 = 2 * (x + 3)
expr2 = 6 + 2 * x
# Rewrite rules for the example. The parameters are the pattern variables the
# rules below match on (presumably turned into egglog variables by ``@ruleset``).
@ruleset
def math_ruleset(a: Math, b: Math, c: Math, i: i64, j: i64, xs: MultiSet[Math], ys: MultiSet[Math], zs: MultiSet[Math]):
    # Turn binary + and * into sum/product over a two-element multiset, so the
    # order of the operands no longer matters.
    yield rewrite(a + b).to(sum(MultiSet(a, b)))
    yield rewrite(a * b).to(product(MultiSet(a, b)))
    # A sum/product of exactly one element equals that element, and the empty
    # sum/product equals its identity (0 for sums, 1 for products), so these
    # forms collapse back to ordinary terms.
    yield rule(a == sum(xs), xs.length() == i64(1)).then(a == xs.pick())
    yield rule(a == product(xs), xs.length() == i64(1)).then(a == xs.pick())
    yield rewrite(sum(MultiSet[Math]())).to(Math(0))
    yield rewrite(product(MultiSet[Math]())).to(Math(1))
    # Distributivity, p * (s1 + s2) = p*s1 + p*s2: when product `b` has the sum
    # `a` among at least two factors, `b` equals the sum, over each summand of
    # `a`, of the remaining factors `zs` multiplied by that summand.
    yield rule(
        b == product(ys),
        a == sum(xs),
        ys.contains(a),
        ys.length() > 1,
        zs == ys.remove(a),
    ).then(
        # The lambda's `x` is a per-summand parameter; it is unrelated to the
        # module-level constant `x`.
        b == sum(xs.map(lambda x: product(zs.insert(x)))),
    )
    # Constant folding for sums: two literals Math(i) and Math(j) in the same
    # sum are replaced by the single literal Math(i + j). `c` is looked up in
    # `ys` (`xs` with `b` removed). NOTE(review): assuming `remove` drops one
    # occurrence, a literal repeated in `xs` can pair with itself.
    yield rule(
        a == sum(xs),
        b == Math(i),
        xs.contains(b),
        ys == xs.remove(b),
        c == Math(j),
        ys.contains(c),
    ).then(
        a == sum(ys.remove(c).insert(Math(i + j))),
    )
    # Constant folding for products: the same pattern, combining the two
    # literals into Math(i * j).
    yield rule(
        a == product(xs),
        b == Math(i),
        xs.contains(b),
        ys == xs.remove(b),
        c == Math(j),
        ys.contains(c),
    ).then(
        a == product(ys.remove(c).insert(Math(i * j))),
    )
# Build an e-graph holding both expressions, run the ruleset to saturation, and
# check that expr1 and expr2 were proven equal.
egraph = EGraph()
egraph.register(expr1, expr2)
egraph.run(math_ruleset.saturate())
egraph.check(expr1 == expr2)