# mypy: disable-error-code="empty-body"

"""
Join Tree (custom costs)
========================

Example of using custom cost functions for jointree.

From https://egraphs.zulipchat.com/#narrow/stream/328972-general/topic/How.20can.20I.20find.20the.20tree.20associated.20with.20an.20extraction.3F
"""
# NOTE: an echoed repr of the module docstring above stood here as a bare string statement; it is kept only as this comment because nothing but the docstring, comments and blank lines may precede the `from __future__` import below.
from __future__ import annotations
from egglog import *
# Expression type for join trees: a tree is either a named leaf (built by the
# constructor) or the join of two trees (built by `join`). All bodies are `...`,
# i.e. these are declarations only.
class JoinTree(Expr):
    # Leaf constructor: builds a tree from a relation name.
    def __init__(self, name: StringLike) -> None: ...

    # Joins this tree with `other`, yielding a new JoinTree.
    def join(self, other: JoinTree) -> JoinTree: ...

    # i64-valued `size` of a tree, exposed as a property. The merge callback
    # resolves an old and a new value to their minimum (`old.min(new)`).
    # NOTE(review): presumably the callback runs when two sizes are set for the
    # same tree — confirm against egglog's `method(merge=...)` documentation.
    @method(merge=lambda old, new: old.min(new))  # type:ignore[prop-decorator]
    @property
    def size(self) -> i64: ...
# Leaf relations: one JoinTree per single-letter relation name, "a" through "f".
ra, rb, rc, rd, re, rf = (JoinTree(label) for label in "abcdef")

# The query under study: a left-deep chain joining all six relations in order.
query = (
    ra.join(rb)
    .join(rc)
    .join(rd)
    .join(re)
    .join(rf)
)

egraph = EGraph()

# Record the size of every leaf relation in the e-graph, in declaration order.
egraph.register(
    *(
        set_(relation.size).to(row_count)
        for relation, row_count in (
            (ra, 50),
            (rb, 200),
            (rc, 10),
            (rd, 123),
            (re, 10000),
            (rf, 1),
        )
    )
)
@egraph.register
def _rules(s: String, a: JoinTree, b: JoinTree, c: JoinTree, asize: i64, bsize: i64):
    """Yield the cost/size rules and the algebraic rewrites for join trees.

    The parameters are pattern variables used inside the rules: ``s`` stands for
    a relation name, ``a``/``b``/``c`` for arbitrary join trees, and
    ``asize``/``bsize`` for the sizes bound to ``a`` and ``b``.
    """
    leaf = JoinTree(s)
    joined = a.join(b)
    joined_size = asize * bsize

    # A leaf relation costs its size minus 1, because its string argument
    # contributes a cost of 1 on its own.
    yield rule(leaf.size == asize).then(set_cost(leaf, asize - 1))

    # A join's size and its cost are both the product of its operands' sizes.
    yield rule(joined, a.size == asize, b.size == bsize).then(
        set_(joined.size).to(joined_size),
        set_cost(joined, joined_size),
    )

    # Commutativity: a ⋈ b  ->  b ⋈ a
    yield rewrite(joined).to(b.join(a))

    # Associativity: (a ⋈ b) ⋈ c  ->  a ⋈ (b ⋈ c)
    yield rewrite(joined.join(c)).to(a.join(b.join(c)))
# Add the query expression itself to the e-graph.
egraph.register(query)
# Run the registered rules.
# NOTE(review): 1000 is presumably an iteration limit — confirm against EGraph.run.
egraph.run(1000)
# Print what extraction returns for the query, then for the query's `size`.
print(egraph.extract(query))
print(egraph.extract(query.size))
# Captured output of the two print() calls above:
# JoinTree("a").join(JoinTree("e")).join(JoinTree("b").join(JoinTree("d").join(JoinTree("c").join(JoinTree("f")))))
# i64(123000000000)