{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "```{post} 2023-11-17\n", "\n", "```\n", "\n", "# [PyTensor](https://github.com/pymc-devs/pytensor) Chat\n", "\n", "Ricardo Vieira reached out to ask if we could chat about egglog and explore whether it could be used inside\n", "of PyTensor for rewriting.\n", "\n", "We set up a call and he agreed to record it, so that we could share what we talked about with others:\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "tags": [ "hide-input" ] }, "outputs": [ { "data": { "image/jpeg": "", "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from IPython.display import YouTubeVideo\n", "\n", "YouTubeVideo(\"8rb841pBhf0\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It was great to get some feedback on the Python bindings and see where the rough patches are. So thank you, Ricardo, for\n", "being game to explore this together!\n", "\n", "Some of the main takeaways for me were:\n", "\n", "- Having generic user-defined sorts in egglog would be very useful, so that libraries like PyTensor don't\n", " have to reimplement collection types for every sort. If we had them, we could, say, implement an `egglog.std.Tuple`\n", " class that would work like a tuple, and if you had a user-defined `Int` class, you could write `Tuple[Int]`.\n", "- It was interesting to see how Ricardo started implementing the Op types at the end, as custom classes, and translating\n", " the pure functions to them. 
It's a nice example of how you can write multiple interfaces, depending on the user,\n", " and write rewrites in whichever one you are more comfortable with, as long as you can convert to and from them.\n", "\n", "Some further things we could explore in the future are:\n", "\n", "- implementing loop nesting in egglog\n", "- converting between the existing PyTensor types and egglog types programmatically, so that we could play with rewrites\n", " without having to rewrite their whole project.\n", "\n", "If anyone else who works on a Python library thinks they could benefit from egglog, or has other questions, feel free\n", "to reach out!\n", "\n", "A cleaned-up version of the notebook is below:\n", "\n", "---\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZmFq-i9By7V4" }, "outputs": [], "source": [ "%%capture\n", "!pip install egglog\n", "!pip install anywidget" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "itgwmxIGy929" }, "outputs": [], "source": [ "from __future__ import annotations\n", "from typing import ClassVar, Tuple\n", "\n", "from functools import partial\n", "from egglog import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fBDfQ7Y5y_8d" }, "outputs": [], "source": [ "import numpy as np\n", "import pytensor\n", "import pytensor.tensor as pt" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xWZfT2oq8a3U" }, "outputs": [], "source": [ "egraph = EGraph()\n", "\n", "\n", "@egraph.class_\n", "class Int(Expr):\n", " def __init__(self, value: i64Like) -> None:\n", " ...\n", "\n", " @classmethod\n", " def var(cls, name: StringLike) -> Int:\n", " ...\n", "\n", "\n", "converter(i64, Int, Int)\n", "\n", "\n", "@egraph.class_\n", "class IntTuple(Expr):\n", " def __init__(self, first: Int) -> None:\n", " ...\n", "\n", " @classmethod\n", " def empty(cls) -> IntTuple:\n", " ...\n", "\n", " def __add__(self, other: 
IntTuple) -> IntTuple:\n", " ...\n", "\n", " def length(self) -> Int:\n", " ...\n", "\n", " def __getitem__(self, i: Int) -> Int:\n", " ...\n", "\n", "\n", "converter(\n", " tuple,\n", " IntTuple,\n", " lambda x: (\n", " IntTuple(convert(x[0], Int)) + convert(x[1:], IntTuple)\n", " if len(x) > 1\n", " else (IntTuple(convert(x[0], Int)) if x else IntTuple.empty())\n", " ),\n", ")\n", "converter(int, IntTuple, lambda i: IntTuple(Int(i64(i))))\n", "converter(i64, IntTuple, lambda i: IntTuple(Int(i)))\n", "converter(Int, IntTuple, lambda i: IntTuple(i))\n", "\n", "\n", "@egraph.register\n", "def int_tuple_rules(int_tuple: IntTuple, i: i64, j: i64):\n", " # Handle tuple concatenation and access\n", " yield rewrite(IntTuple(i)[0]).to(Int(i))\n", " yield rewrite((IntTuple(i) + int_tuple)[0]).to(Int(i))\n", " yield rewrite((IntTuple(i) + int_tuple)[j]).to(int_tuple[Int(j - 1)], j > 0)\n", "\n", "\n", "Shape = IntTuple\n", "\n", "\n", "@egraph.class_\n", "class Tensor(Expr):\n", " def __init__(self, name: StringLike, shape: Shape) -> None:\n", " ...\n", "\n", " @property\n", " def shape(self) -> Shape:\n", " ...\n", "\n", " @property\n", " def ndim(self) -> i64:\n", " ...\n", "\n", "\n", "@egraph.register\n", "def inline_tensor_shape(x: Tensor, name: String, shape: Shape, i: Int):\n", " yield rewrite(Tensor(name, shape).shape).to(shape)\n", "\n", "\n", "@egraph.class_\n", "class UnaryOp(Expr):\n", " def __call__(self, x: Tensor) -> Tensor:\n", " ...\n", "\n", "\n", "@egraph.class_\n", "class BinaryOp(Expr):\n", " def __call__(self, x: Tensor, y: Tensor) -> Tensor:\n", " ...\n", "\n", "\n", "@egraph.function(cost=1)\n", "def Squeeze(axis: IntTuple) -> UnaryOp:\n", " ...\n", "\n", "\n", "def split_merge_reorder_axis_op(op, x: Tensor, axis: IntTuple, i: i64, j: i64):\n", " # Split into consecutive axis applications\n", " yield (rewrite(op(axis=(i,) + axis)(x)).to(op(axis=(i,))(op(axis=axis)(x))))\n", " # Swap consecutive axis applications\n", " yield (\n", " 
rewrite(op(axis=(i,))(op(axis=(j,))(x))).to(\n", " op(axis=(j - 1,))(op(axis=(i,))(x)),\n", " i < j,\n", " )\n", " )\n", " yield (\n", " rewrite(op(axis=(i,))(op(axis=(j,))(x))).to(\n", " op(axis=(j,))(op(axis=(i + 1,))(x)),\n", " i > j,\n", " )\n", " )\n", " # Merge from consecutive axis applications\n", " yield (\n", " rewrite(op(axis=(i,))(op(axis=(j,))(x))).to(\n", " op(axis=(i, j))(x),\n", " i < j,\n", " )\n", " )\n", " yield (\n", " rewrite(op(axis=(i,))(op(axis=(j,) + axis)(x))).to(\n", " op(axis=(i,) + (j + axis))(x),\n", " i < j,\n", " )\n", " )\n", "\n", "\n", "@egraph.register\n", "def squeeze_rules(\n", " x: Tensor,\n", " axis: IntTuple,\n", " i: i64,\n", " j: i64,\n", "):\n", " yield from split_merge_reorder_axis_op(Squeeze, x, axis, i, j)\n", "\n", " # Squeeze.shape\n", " yield (rewrite(Squeeze(axis=(i,))(x).shape[j]).to(x.shape[j + 1], i <= j))\n", " yield (rewrite(Squeeze(axis=(i,))(x).shape[j]).to(x.shape[j], i > j))\n", "\n", "\n", "@egraph.class_\n", "class OpType(Expr):\n", " ...\n", "\n", "\n", "ScalarAdd = egraph.constant(\"ScalarAdd\", OpType)\n", "ScalarMul = egraph.constant(\"ScalarMul\", OpType)\n", "ScalarDiv = egraph.constant(\"ScalarDiv\", OpType)\n", "\n", "\n", "@egraph.function(cost=10)\n", "def Reduce(scalar_op: OpType, axis: IntTuple) -> UnaryOp:\n", " ...\n", "\n", "\n", "@egraph.function(cost=5)\n", "def Elemwise(scalar_op: OpType) -> BinaryOp:\n", " ...\n", "\n", "\n", "Mul = Elemwise(ScalarMul)\n", "Div = Elemwise(ScalarDiv)\n", "Sum = partial(Reduce, ScalarAdd)\n", "\n", "\n", "@egraph.register\n", "def sum_rules(\n", " x: Tensor,\n", " y: Tensor,\n", " z: Tensor,\n", " i: i64,\n", " j: i64,\n", " axis: IntTuple,\n", " axis2: IntTuple,\n", " op: OpType,\n", "):\n", " any_reduce_op = partial(Reduce, op)\n", " yield from split_merge_reorder_axis_op(any_reduce_op, x, axis, i, j)\n", "\n", " # Introduce shape[i] needed for removing useless reductions\n", " yield (rule(eq(y).to(any_reduce_op(axis=(i,))(x))).then(x.shape[i]))\n", 
"\n", " # Remove useless reductions\n", " yield (rewrite(any_reduce_op(axis=(i,))(x)).to(Squeeze(axis=(i,))(x), eq(x.shape[i]).to(Int(1))))\n", "\n", " # Introduce shape[i], needed for factoring multiplication/division out of the sum\n", " yield (rule(eq(z).to(Sum(axis=(i,))(Elemwise(op)(x, y)))).then(y.shape[i]))\n", "\n", " # Factor multiplication/division out of the sum\n", " for elemwise_op in (Mul, Div):\n", " yield (\n", " rewrite(Sum(axis=(i,))(elemwise_op(x, y))).to(\n", " elemwise_op(\n", " Sum(axis=(i,))(x),\n", " Squeeze(axis=(i,))(y),\n", " ),\n", " eq(y.shape[i]).to(Int(1)),\n", " )\n", " )\n", "\n", "\n", "x = Tensor(\"x\", (Int.var(\"x_dim_0\"), 5, 7))\n", "y = Tensor(\"y\", (1, 5, 1))\n", "expr = Sum(axis=(0, 2))(Div(x, y))\n", "\n", "egraph.register(expr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "rAcuuaK55_ks", "outputId": "c3fe65ef-038b-4b4f-cb7f-a3e9fddc6301" }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "%3\n", "\n", "\n", "outer_cluster_8\n", "\n", "\n", "cluster_8\n", "\n", "\n", "\n", "outer_cluster_3\n", "\n", "\n", "cluster_3\n", "\n", "\n", "\n", "outer_cluster_35\n", "\n", "\n", "cluster_35\n", "\n", "\n", "\n", "outer_cluster_40\n", "\n", "\n", "cluster_40\n", "\n", "\n", "\n", "outer_cluster_51\n", "\n", "\n", "cluster_51\n", "\n", "\n", "\n", "outer_cluster_13\n", "\n", "\n", "cluster_13\n", "\n", "\n", "\n", "outer_cluster_71\n", "\n", "\n", "cluster_71\n", "\n", "\n", "\n", "outer_cluster_62\n", "\n", "\n", "cluster_62\n", "\n", "\n", "\n", "outer_cluster_9\n", "\n", "\n", "cluster_9\n", "\n", "\n", "\n", "outer_cluster_37\n", "\n", "\n", "cluster_37\n", "\n", "\n", "\n", "outer_cluster_1\n", "\n", "\n", "cluster_1\n", "\n", "\n", "\n", "outer_cluster_66\n", "\n", "\n", "cluster_66\n", "\n", "\n", "\n", "outer_cluster_11\n", "\n", "\n", "cluster_11\n", "\n", "\n", "\n", 
"outer_cluster_61\n", "\n", "\n", "cluster_61\n", "\n", "\n", "\n", "outer_cluster_41\n", "\n", "\n", "cluster_41\n", "\n", "\n", "\n", "outer_cluster_10\n", "\n", "\n", "cluster_10\n", "\n", "\n", "\n", "outer_cluster_2\n", "\n", "\n", "cluster_2\n", "\n", "\n", "\n", "outer_cluster_60\n", "\n", "\n", "cluster_60\n", "\n", "\n", "\n", "outer_cluster_65\n", "\n", "\n", "cluster_65\n", "\n", "\n", "\n", "outer_cluster_5\n", "\n", "\n", "cluster_5\n", "\n", "\n", "\n", "outer_cluster_19\n", "\n", "\n", "cluster_19\n", "\n", "\n", "\n", "outer_cluster_4\n", "\n", "\n", "cluster_4\n", "\n", "\n", "\n", "outer_cluster_20\n", "\n", "\n", "cluster_20\n", "\n", "\n", "\n", "outer_cluster_12\n", "\n", "\n", "cluster_12\n", "\n", "\n", "\n", "outer_cluster_29\n", "\n", "\n", "cluster_29\n", "\n", "\n", "\n", "outer_cluster_36\n", "\n", "\n", "cluster_36\n", "\n", "\n", "\n", "outer_cluster_48\n", "\n", "\n", "cluster_48\n", "\n", "\n", "\n", "outer_cluster_15\n", "\n", "\n", "cluster_15\n", "\n", "\n", "\n", "outer_cluster_50\n", "\n", "\n", "cluster_50\n", "\n", "\n", "\n", "outer_cluster_70\n", "\n", "\n", "cluster_70\n", "\n", "\n", "\n", "outer_cluster_34\n", "\n", "\n", "cluster_34\n", "\n", "\n", "\n", "outer_cluster_39\n", "\n", "\n", "cluster_39\n", "\n", "\n", "\n", "outer_cluster_14\n", "\n", "\n", "cluster_14\n", "\n", "\n", "\n", "outer_cluster_0\n", "\n", "\n", "cluster_0\n", "\n", "\n", "\n", "outer_cluster_7\n", "\n", "\n", "cluster_7\n", "\n", "\n", "\n", "outer_cluster_String-1618856268753343266\n", "\n", "\n", "cluster_String-1618856268753343266\n", "\n", "\n", "\n", "outer_cluster_String-313937328870544552\n", "\n", "\n", "cluster_String-313937328870544552\n", "\n", "\n", "\n", "outer_cluster_String-6766843149853318713\n", "\n", "\n", "cluster_String-6766843149853318713\n", "\n", "\n", "\n", "outer_cluster_22\n", "\n", "\n", "cluster_22\n", "\n", "\n", "\n", "outer_cluster_45\n", "\n", "\n", "cluster_45\n", "\n", "\n", "\n", "outer_cluster_23\n", "\n", 
"\n", "cluster_23\n", "\n", "\n", "\n", "outer_cluster_58\n", "\n", "\n", "cluster_58\n", "\n", "\n", "\n", "outer_cluster_52\n", "\n", "\n", "cluster_52\n", "\n", "\n", "\n", "outer_cluster_54\n", "\n", "\n", "cluster_54\n", "\n", "\n", "\n", "outer_cluster_56\n", "\n", "\n", "cluster_56\n", "\n", "\n", "\n", "outer_cluster_24\n", "\n", "\n", "cluster_24\n", "\n", "\n", "\n", "outer_cluster_32\n", "\n", "\n", "cluster_32\n", "\n", "\n", "\n", "outer_cluster_43\n", "\n", "\n", "cluster_43\n", "\n", "\n", "\n", "outer_cluster_27\n", "\n", "\n", "cluster_27\n", "\n", "\n", "\n", "outer_cluster_17\n", "\n", "\n", "cluster_17\n", "\n", "\n", "\n", "outer_cluster_67\n", "\n", "\n", "cluster_67\n", "\n", "\n", "\n", "outer_cluster_26\n", "\n", "\n", "cluster_26\n", "\n", "\n", "\n", "outer_cluster_25\n", "\n", "\n", "cluster_25\n", "\n", "\n", "\n", "outer_cluster_53\n", "\n", "\n", "cluster_53\n", "\n", "\n", "\n", "outer_cluster_6\n", "\n", "\n", "cluster_6\n", "\n", "\n", "\n", "outer_cluster_44\n", "\n", "\n", "cluster_44\n", "\n", "\n", "\n", "outer_cluster_57\n", "\n", "\n", "cluster_57\n", "\n", "\n", "\n", "outer_cluster_31\n", "\n", "\n", "cluster_31\n", "\n", "\n", "\n", "outer_cluster_i64-4208978898528913939\n", "\n", "\n", "cluster_i64-4208978898528913939\n", "\n", "\n", "\n", "outer_cluster_i64-5871781006564002453\n", "\n", "\n", "cluster_i64-5871781006564002453\n", "\n", "\n", "\n", "outer_cluster_i64-10912160959110460649\n", "\n", "\n", "cluster_i64-10912160959110460649\n", "\n", "\n", "\n", "outer_cluster_i64-11743562013128004906\n", "\n", "\n", "cluster_i64-11743562013128004906\n", "\n", "\n", "\n", "outer_cluster_i64-0\n", "\n", "\n", "cluster_i64-0\n", "\n", "\n", "\n", "\n", "Elemwise-4208978898528913939:s->ScalarDiv-0\n", "\n", "\n", "\n", "\n", "\n", "Int___init__-11743562013128004906:s->i64-11743562013128004906\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-2427221081489190903:s->Tensor_shape-10964134587551653303\n", "\n", "\n", "\n", 
"\n", "\n", "IntTuple___getitem__-2427221081489190903:s->Int___init__-0\n", "\n", "\n", "\n", "\n", "\n", "Tensor_shape-10964134587551653303:s->UnaryOp___call__-8391183972302725451\n", "\n", "\n", "\n", "\n", "\n", "Int___init__-0:s->i64-0\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-14535344522394967184:s->Tensor_shape-3429551472952562336\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-14535344522394967184:s->IntTuple___getitem__-8609920828969388276\n", "\n", "\n", "\n", "\n", "\n", "Tensor_shape-3429551472952562336:s->UnaryOp___call__-10760501220358611293\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-8609920828969388276:s->Int___init__-11743562013128004906\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-8609920828969388276:s->IntTuple___add__-3515898189561353625\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-6451836211519040652:s->Int___init__-0\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-6451836211519040652:s->IntTuple___add__-16058157226959687954\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___add__-16058157226959687954:s->IntTuple___init__-15952540911656918845\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___add__-16058157226959687954:s->IntTuple___add__-14593153318476900644\n", "\n", "\n", "\n", "\n", "\n", "Int___init__-4208978898528913939:s->i64-4208978898528913939\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-6611095213683386216:s->Int___init__-0\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-6611095213683386216:s->Tensor_shape-10184707161975301700\n", "\n", "\n", "\n", "\n", "\n", "Tensor_shape-10184707161975301700:s->UnaryOp___call__-10053986080337813965\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-18195398224647045558:s->Int___init__-11743562013128004906\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-18195398224647045558:s->Tensor_shape-7586556743040283621\n", "\n", "\n", "\n", "\n", "\n", 
"Tensor_shape-7586556743040283621:s->Tensor___init__-1575189424824699418\n", "\n", "\n", "\n", "\n", "\n", "Int_var-1618856268753343266:s->String-1618856268753343266\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-7613425605146757649:s->Int___init__-11743562013128004906\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-7613425605146757649:s->Tensor_shape-5923754635005195107\n", "\n", "\n", "\n", "\n", "\n", "Tensor_shape-5923754635005195107:s->BinaryOp___call__-4194065315749295272\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-14456597069198672103:s->Tensor_shape-12678910324027934471\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-14456597069198672103:s->IntTuple___getitem__-10395939792580839918\n", "\n", "\n", "\n", "\n", "\n", "Tensor_shape-12678910324027934471:s->UnaryOp___call__-9097699112323522779\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-10395939792580839918:s->IntTuple___add__-5343794467401528509\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-10395939792580839918:s->IntTuple___getitem__-1906738768387841566\n", "\n", "\n", "\n", "\n", "\n", "Int___init__-10912160959110460649:s->i64-10912160959110460649\n", "\n", "\n", "\n", "\n", "\n", "Int___init__-5871781006564002453:s->i64-5871781006564002453\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-14040062012979708074:s->Int___init__-0\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-14040062012979708074:s->Tensor_shape-3481525101393754990\n", "\n", "\n", "\n", "\n", "\n", "Tensor_shape-3481525101393754990:s->UnaryOp___call__-3344968852115526449\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___add__-5343794467401528509:s->IntTuple___init__-7690503999922668929\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___add__-5343794467401528509:s->IntTuple___init__-9249358851075372135\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-1906738768387841566:s->Int___init__-0\n", "\n", "\n", "\n", "\n", "\n", 
"IntTuple___getitem__-1906738768387841566:s->Tensor_shape-51973628441192654\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-10514990538936846041:s->IntTuple___getitem__-14040062012979708074\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-10514990538936846041:s->Tensor_shape-5975728263446387761\n", "\n", "\n", "\n", "\n", "\n", "Tensor_shape-5975728263446387761:s->UnaryOp___call__-4195639504827048526\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___add__-3515898189561353625:s->IntTuple___add__-5343794467401528509\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___add__-3515898189561353625:s->IntTuple___init__-7690503999922668929\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-9335705567684163424:s->Int___init__-0\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-9335705567684163424:s->IntTuple___init__-7690503999922668929\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___init__-7690503999922668929:s->IntTuple___getitem__-10395939792580839918\n", "\n", "\n", "\n", "\n", "\n", "Tensor_shape-51973628441192654:s->Tensor___init__-14358371401161707118\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-910243544565210939:s->Int___init__-0\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-910243544565210939:s->Tensor_shape-5923754635005195107\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___init__-15952540911656918845:s->Int_var-1618856268753343266\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___init__-5871781006564002453:s->Int___init__-0\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-3344968852115526449:s->Tensor___init__-14358371401161707118\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-3344968852115526449:s->Squeeze-5040379952546458196\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-9097699112323522779:s->Tensor___init__-1575189424824699418\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-9097699112323522779:s->Reduce-11743562013128004906\n", "\n", "\n", "\n", "\n", "\n", 
"IntTuple___add__-15318938057191675792:s->IntTuple___init__-5871781006564002453\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___add__-15318938057191675792:s->IntTuple___init__-17615343019692007359\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___init__-17615343019692007359:s->Int___init__-11743562013128004906\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___init__-9249358851075372135:s->Int___init__-10912160959110460649\n", "\n", "\n", "\n", "\n", "\n", "Tensor___init__-14358371401161707118:s->Tensor_shape-51973628441192654\n", "\n", "\n", "\n", "\n", "\n", "Tensor___init__-14358371401161707118:s->String-6766843149853318713\n", "\n", "\n", "\n", "\n", "\n", "BinaryOp___call__-4194065315749295272:s->Elemwise-4208978898528913939\n", "\n", "\n", "\n", "\n", "\n", "BinaryOp___call__-4194065315749295272:s->Tensor___init__-14358371401161707118\n", "\n", "\n", "\n", "\n", "\n", "BinaryOp___call__-4194065315749295272:s->Tensor___init__-1575189424824699418\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-4195639504827048526:s->Tensor___init__-14358371401161707118\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-4195639504827048526:s->Squeeze-11743562013128004906\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___add__-14593153318476900644:s->IntTuple___init__-9249358851075372135\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___add__-14593153318476900644:s->IntTuple___init__-2546176790493825425\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___init__-2546176790493825425:s->Int___init__-4208978898528913939\n", "\n", "\n", "\n", "\n", "\n", "Tensor___init__-1575189424824699418:s->Tensor_shape-7586556743040283621\n", "\n", "\n", "\n", "\n", "\n", "Tensor___init__-1575189424824699418:s->String-313937328870544552\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-10053986080337813965:s->Tensor___init__-1575189424824699418\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-10053986080337813965:s->Reduce-5040379952546458196\n", "\n", "\n", "\n", "\n", "\n", 
"UnaryOp___call__-8391183972302725451:s->BinaryOp___call__-4194065315749295272\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-8391183972302725451:s->Reduce-5040379952546458196\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-10760501220358611293:s->BinaryOp___call__-4194065315749295272\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-10760501220358611293:s->Reduce-11743562013128004906\n", "\n", "\n", "\n", "\n", "\n", "Squeeze-11743562013128004906:s->IntTuple___init__-5871781006564002453\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-3636617994337743549:s->UnaryOp___call__-4195639504827048526\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-3636617994337743549:s->Squeeze-883374682458736911\n", "\n", "\n", "\n", "\n", "\n", "Squeeze-883374682458736911:s->IntTuple___init__-7690503999922668929\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-7625190977779610862:s->UnaryOp___call__-3344968852115526449\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-7625190977779610862:s->Squeeze-11743562013128004906\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-2183379458487809452:s->Tensor___init__-14358371401161707118\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-2183379458487809452:s->Squeeze-10912160959110460649\n", "\n", "\n", "\n", "\n", "\n", "Squeeze-10912160959110460649:s->IntTuple___add__-15318938057191675792\n", "\n", "\n", "\n", "\n", "\n", "Reduce-5040379952546458196:s->IntTuple___init__-17615343019692007359\n", "\n", "\n", "\n", "\n", "\n", "Reduce-5040379952546458196:s->ScalarAdd-0\n", "\n", "\n", "\n", "\n", "\n", "Squeeze-5040379952546458196:s->IntTuple___init__-17615343019692007359\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-15458927460662043536:s->UnaryOp___call__-9097699112323522779\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-15458927460662043536:s->Reduce-883374682458736911\n", "\n", "\n", "\n", "\n", "\n", "Reduce-883374682458736911:s->IntTuple___init__-7690503999922668929\n", "\n", "\n", "\n", "\n", "\n", 
"Reduce-883374682458736911:s->ScalarAdd-0\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-13202730753970051410:s->UnaryOp___call__-10053986080337813965\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-13202730753970051410:s->Reduce-11743562013128004906\n", "\n", "\n", "\n", "\n", "\n", "Reduce-11743562013128004906:s->IntTuple___init__-5871781006564002453\n", "\n", "\n", "\n", "\n", "\n", "Reduce-11743562013128004906:s->ScalarAdd-0\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-593394598656903612:s->Tensor___init__-1575189424824699418\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-593394598656903612:s->Reduce-10912160959110460649\n", "\n", "\n", "\n", "\n", "\n", "Reduce-10912160959110460649:s->IntTuple___add__-15318938057191675792\n", "\n", "\n", "\n", "\n", "\n", "Reduce-10912160959110460649:s->ScalarAdd-0\n", "\n", "\n", "\n", "\n", "\n", "BinaryOp___call__-7183646024568558754:s->Elemwise-4208978898528913939\n", "\n", "\n", "\n", "\n", "\n", "BinaryOp___call__-7183646024568558754:s->UnaryOp___call__-7625190977779610862\n", "\n", "\n", "\n", "\n", "\n", "BinaryOp___call__-7183646024568558754:s->UnaryOp___call__-13202730753970051410\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-7296576659238450322:s->BinaryOp___call__-4194065315749295272\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-7296576659238450322:s->Reduce-10912160959110460649\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-5720121267812153097:s->Reduce-11743562013128004906\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-5720121267812153097:s->BinaryOp___call__-17648236956133224341\n", "\n", "\n", "\n", "\n", "\n", "BinaryOp___call__-17648236956133224341:s->Elemwise-4208978898528913939\n", "\n", "\n", "\n", "\n", "\n", "BinaryOp___call__-17648236956133224341:s->UnaryOp___call__-3344968852115526449\n", "\n", "\n", "\n", "\n", "\n", "BinaryOp___call__-17648236956133224341:s->UnaryOp___call__-10053986080337813965\n", "\n", "\n", "\n", "\n", "\n", 
"UnaryOp___call__-6261542238027864055:s->Reduce-883374682458736911\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-6261542238027864055:s->BinaryOp___call__-4681938636454801376\n", "\n", "\n", "\n", "\n", "\n", "BinaryOp___call__-4681938636454801376:s->Elemwise-4208978898528913939\n", "\n", "\n", "\n", "\n", "\n", "BinaryOp___call__-4681938636454801376:s->UnaryOp___call__-9097699112323522779\n", "\n", "\n", "\n", "\n", "\n", "BinaryOp___call__-4681938636454801376:s->UnaryOp___call__-4195639504827048526\n", "\n", "\n", "\n", "\n", "\n", "Elemwise-4208978898528913939\n", "\n", "\n", "Elemwise\n", "\n", "\n", "\n", "\n", "\n", "\n", "ScalarDiv-0\n", "\n", "\n", "ScalarDiv\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int___init__-11743562013128004906\n", "\n", "\n", "Int___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "i64-11743562013128004906\n", "\n", "\n", "2\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-2427221081489190903\n", "\n", "\n", "IntTuple___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Tensor_shape-10964134587551653303\n", "\n", "\n", "Tensor_shape\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int___init__-0\n", "\n", "\n", "Int___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-14535344522394967184\n", "\n", "\n", "IntTuple___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Tensor_shape-3429551472952562336\n", "\n", "\n", "Tensor_shape\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-8609920828969388276\n", "\n", "\n", "IntTuple___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-6451836211519040652\n", "\n", "\n", "IntTuple___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___add__-16058157226959687954\n", "\n", "\n", "IntTuple___add__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int___init__-4208978898528913939\n", "\n", "\n", "Int___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "i64-4208978898528913939\n", "\n", "\n", "7\n", "\n", "\n", "\n", "\n", "\n", "\n", 
"IntTuple___getitem__-6611095213683386216\n", "\n", "\n", "IntTuple___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Tensor_shape-10184707161975301700\n", "\n", "\n", "Tensor_shape\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-18195398224647045558\n", "\n", "\n", "IntTuple___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Tensor_shape-7586556743040283621\n", "\n", "\n", "Tensor_shape\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int_var-1618856268753343266\n", "\n", "\n", "Int_var\n", "\n", "\n", "\n", "\n", "\n", "\n", "String-1618856268753343266\n", "\n", "\n", ""x_dim_0"\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-7613425605146757649\n", "\n", "\n", "IntTuple___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Tensor_shape-5923754635005195107\n", "\n", "\n", "Tensor_shape\n", "\n", "\n", "\n", "\n", "\n", "\n", "i64-0\n", "\n", "\n", "0\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-14456597069198672103\n", "\n", "\n", "IntTuple___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Tensor_shape-12678910324027934471\n", "\n", "\n", "Tensor_shape\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-10395939792580839918\n", "\n", "\n", "IntTuple___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int___init__-10912160959110460649\n", "\n", "\n", "Int___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "i64-10912160959110460649\n", "\n", "\n", "5\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int___init__-5871781006564002453\n", "\n", "\n", "Int___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "i64-5871781006564002453\n", "\n", "\n", "1\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-14040062012979708074\n", "\n", "\n", "IntTuple___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Tensor_shape-3481525101393754990\n", "\n", "\n", "Tensor_shape\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___add__-5343794467401528509\n", "\n", "\n", "IntTuple___add__\n", "\n", "\n", "\n", "\n", "\n", "\n", 
"IntTuple___getitem__-1906738768387841566\n", "\n", "\n", "IntTuple___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-10514990538936846041\n", "\n", "\n", "IntTuple___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Tensor_shape-5975728263446387761\n", "\n", "\n", "Tensor_shape\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___add__-3515898189561353625\n", "\n", "\n", "IntTuple___add__\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-9335705567684163424\n", "\n", "\n", "IntTuple___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___init__-7690503999922668929\n", "\n", "\n", "IntTuple___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Tensor_shape-51973628441192654\n", "\n", "\n", "Tensor_shape\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___getitem__-910243544565210939\n", "\n", "\n", "IntTuple___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___init__-15952540911656918845\n", "\n", "\n", "IntTuple___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___init__-5871781006564002453\n", "\n", "\n", "IntTuple___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-3344968852115526449\n", "\n", "\n", "UnaryOp___call__\n", "\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-9097699112323522779\n", "\n", "\n", "UnaryOp___call__\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___add__-15318938057191675792\n", "\n", "\n", "IntTuple___add__\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___init__-17615343019692007359\n", "\n", "\n", "IntTuple___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___init__-9249358851075372135\n", "\n", "\n", "IntTuple___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Tensor___init__-14358371401161707118\n", "\n", "\n", "Tensor___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "BinaryOp___call__-4194065315749295272\n", "\n", "\n", "BinaryOp___call__\n", "\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-4195639504827048526\n", "\n", "\n", 
"UnaryOp___call__\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___add__-14593153318476900644\n", "\n", "\n", "IntTuple___add__\n", "\n", "\n", "\n", "\n", "\n", "\n", "IntTuple___init__-2546176790493825425\n", "\n", "\n", "IntTuple___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Tensor___init__-1575189424824699418\n", "\n", "\n", "Tensor___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-10053986080337813965\n", "\n", "\n", "UnaryOp___call__\n", "\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-8391183972302725451\n", "\n", "\n", "UnaryOp___call__\n", "\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-10760501220358611293\n", "\n", "\n", "UnaryOp___call__\n", "\n", "\n", "\n", "\n", "\n", "\n", "ScalarAdd-0\n", "\n", "\n", "ScalarAdd\n", "\n", "\n", "\n", "\n", "\n", "\n", "String-313937328870544552\n", "\n", "\n", ""x"\n", "\n", "\n", "\n", "\n", "\n", "\n", "String-6766843149853318713\n", "\n", "\n", ""y"\n", "\n", "\n", "\n", "\n", "\n", "\n", "Squeeze-11743562013128004906\n", "\n", "\n", "Squeeze\n", "\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-3636617994337743549\n", "\n", "\n", "UnaryOp___call__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Squeeze-883374682458736911\n", "\n", "\n", "Squeeze\n", "\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-7625190977779610862\n", "\n", "\n", "UnaryOp___call__\n", "\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-2183379458487809452\n", "\n", "\n", "UnaryOp___call__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Squeeze-10912160959110460649\n", "\n", "\n", "Squeeze\n", "\n", "\n", "\n", "\n", "\n", "\n", "Reduce-5040379952546458196\n", "\n", "\n", "Reduce\n", "\n", "\n", "\n", "\n", "\n", "\n", "Squeeze-5040379952546458196\n", "\n", "\n", "Squeeze\n", "\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-15458927460662043536\n", "\n", "\n", "UnaryOp___call__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Reduce-883374682458736911\n", "\n", "\n", "Reduce\n", "\n", "\n", "\n", "\n", "\n", "\n", 
"UnaryOp___call__-13202730753970051410\n", "\n", "\n", "UnaryOp___call__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Reduce-11743562013128004906\n", "\n", "\n", "Reduce\n", "\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-593394598656903612\n", "\n", "\n", "UnaryOp___call__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Reduce-10912160959110460649\n", "\n", "\n", "Reduce\n", "\n", "\n", "\n", "\n", "\n", "\n", "BinaryOp___call__-7183646024568558754\n", "\n", "\n", "BinaryOp___call__\n", "\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-7296576659238450322\n", "\n", "\n", "UnaryOp___call__\n", "\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-5720121267812153097\n", "\n", "\n", "UnaryOp___call__\n", "\n", "\n", "\n", "\n", "\n", "\n", "BinaryOp___call__-17648236956133224341\n", "\n", "\n", "BinaryOp___call__\n", "\n", "\n", "\n", "\n", "\n", "\n", "UnaryOp___call__-6261542238027864055\n", "\n", "\n", "UnaryOp___call__\n", "\n", "\n", "\n", "\n", "\n", "\n", "BinaryOp___call__-4681938636454801376\n", "\n", "\n", "BinaryOp___call__\n", "\n", "\n", "\n", "\n", "\n", "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "egraph.run(20)\n", "egraph.display(n_inline_leaves=0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 70 }, "id": "xa_NKCF74p4H", "outputId": "ab6a53d7-7862-4a54-a996-2308631a038c" }, "outputs": [ { "data": { "text/html": [ "
Reduce(ScalarAdd, IntTuple(Int(0)) + IntTuple(Int(2)))(\n",
              "    Elemwise(ScalarDiv)(Tensor("x", IntTuple(Int.var("x_dim_0")) + (IntTuple(Int(5)) + IntTuple(Int(7)))), Tensor("y", IntTuple(Int(1)) + (IntTuple(Int(5)) + IntTuple(Int(1)))))\n",
              ")\n",
              "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{n}{Reduce}\\PY{p}{(}\\PY{n}{ScalarAdd}\\PY{p}{,} \\PY{n}{IntTuple}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{IntTuple}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{(}\n", " \\PY{n}{Elemwise}\\PY{p}{(}\\PY{n}{ScalarDiv}\\PY{p}{)}\\PY{p}{(}\\PY{n}{Tensor}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{x}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{n}{IntTuple}\\PY{p}{(}\\PY{n}{Int}\\PY{o}{.}\\PY{n}{var}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{x\\PYZus{}dim\\PYZus{}0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{p}{(}\\PY{n}{IntTuple}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{5}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{IntTuple}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{7}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Tensor}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{y}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{n}{IntTuple}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{p}{(}\\PY{n}{IntTuple}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{5}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{IntTuple}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{p}{)}\n", "\\end{Verbatim}\n" ], "text/plain": [ "Reduce(ScalarAdd, IntTuple(Int(0)) + IntTuple(Int(2)))(\n", " Elemwise(ScalarDiv)(Tensor(\"x\", IntTuple(Int.var(\"x_dim_0\")) + (IntTuple(Int(5)) + IntTuple(Int(7)))), Tensor(\"y\", IntTuple(Int(1)) + (IntTuple(Int(5)) + IntTuple(Int(1)))))\n", ")" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "egraph.extract(expr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "uVjDI5hr6fZ3", "outputId": "4b05f145-14cc-41a5-a5e5-0a2db7019adc" }, "outputs": [ { "data": { "text/plain": [ "[Elemwise(ScalarDiv)(\n", " Reduce(ScalarAdd, 
IntTuple(Int(0)) + IntTuple(Int(2)))(Tensor(\"x\", IntTuple(Int.var(\"x_dim_0\")) + (IntTuple(Int(5)) + IntTuple(Int(7))))),\n", " Squeeze(IntTuple(Int(0)) + IntTuple(Int(2)))(Tensor(\"y\", IntTuple(Int(1)) + (IntTuple(Int(5)) + IntTuple(Int(1))))),\n", " ),\n", " Reduce(ScalarAdd, IntTuple(Int(0)) + IntTuple(Int(2)))(\n", " Elemwise(ScalarDiv)(Tensor(\"x\", IntTuple(Int.var(\"x_dim_0\")) + (IntTuple(Int(5)) + IntTuple(Int(7)))), Tensor(\"y\", IntTuple(Int(1)) + (IntTuple(Int(5)) + IntTuple(Int(1)))))\n", " ),\n", " Reduce(ScalarAdd, IntTuple(Int(0)))(\n", " Reduce(ScalarAdd, IntTuple(Int(2)))(\n", " Elemwise(ScalarDiv)(\n", " Tensor(\"x\", IntTuple(Int.var(\"x_dim_0\")) + (IntTuple(Int(5)) + IntTuple(Int(7)))), Tensor(\"y\", IntTuple(Int(1)) + (IntTuple(Int(5)) + IntTuple(Int(1))))\n", " )\n", " )\n", " ),\n", " Reduce(ScalarAdd, IntTuple(Int(1)))(\n", " Reduce(ScalarAdd, IntTuple(Int(0)))(\n", " Elemwise(ScalarDiv)(\n", " Tensor(\"x\", IntTuple(Int.var(\"x_dim_0\")) + (IntTuple(Int(5)) + IntTuple(Int(7)))), Tensor(\"y\", IntTuple(Int(1)) + (IntTuple(Int(5)) + IntTuple(Int(1))))\n", " )\n", " )\n", " )]" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "egraph.extract_multiple(expr, 10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mw6a1kafJNax" }, "outputs": [], "source": [ "# y has length 1 along axes 0 and 2, so it is constant across the summed axes\n", "# and the division can be applied after the sum, using y with those axes squeezed away\n", "egraph.check(\n", "    eq(Sum(axis=(0, 2))(Div(x, y))).to(\n", "        Div(\n", "            Sum(axis=(0, 2))(x),\n", "            Squeeze(axis=(0, 2))(y),\n", "        )\n", "    )\n", ")" ] } ], "metadata": { "mystnb": { "execution_mode": "off" }, "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 
3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }