{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "```{post} 2023-11-12\n", ":author: Saul\n", "```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Indexing pushdown\n", "\n", "I got this question today and thought I would write up some examples explaining the current state of things for this sort of indexing pushdown:\n", "\n", "Q: How would I do an `exp(vec)[idx] -> exp(vec[idx])` rewrite?\n", "\n", "A: You can easily write this rewrite so that both expressions are added to the e-graph, but it's currently difficult to make extraction choose the right-hand expression over the left-hand one. The example below uses `abs` in place of `exp`, but the idea is the same:\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[e-graph diagram from egraph.display(): nodes NDArray_var \"x\", NDArray_var \"idx\", ndarray_index, two NDArray___getitem__ nodes and two ndarray-abs nodes]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "abs(NDArray.var(\"x\")[ndarray_index(NDArray.var(\"idx\"))])\n", "abs(NDArray.var(\"x\"))[ndarray_index(NDArray.var(\"idx\"))]\n" ] } ], "source": [ "from __future__ import annotations\n", "\n", "from egglog.exp.array_api import *\n", "\n", "egraph = EGraph([array_api_module])\n", "\n", "\n", "@egraph.register\n", "def _pushdown_abs(x: NDArray, idx: IndexKey):\n", "    # Push the indexing inside `abs`: abs(x)[idx] -> abs(x[idx])\n", "    yield rewrite(abs(x)[idx]).to(abs(x[idx]))\n", "\n", "\n", "res = abs(NDArray.var(\"x\"))[NDArray.var(\"idx\")]\n", "egraph.register(res)\n", "egraph.run(100)\n", "egraph.display()\n", "\n", "# Both expressions end up in the same e-class, so both can be extracted\n", "for e in egraph.extract_multiple(res, 10):\n", "    print(e)" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "Here we see that both expressions are extracted. If we ask for only the lowest-cost one, which of the two is selected is non-deterministic, because they have the same cost.\n", "\n", "See [this issue](https://github.com/egraphs-good/egglog/issues/256#issuecomment-1807185387) on the egglog tracker for how this could be resolved.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "Another way to resolve this would be to define the _semantics_ of these two operations. We can make up a mathematical abstraction for arrays and define both indexing and the `abs` function in terms of that abstraction. Then, when we compose them, we can look at the composition of the abstractions and see whether normalizing that form to a canonical one also achieves the indexing pushdown optimization.\n", "\n", "In this case, we can pick an abstraction where each array is defined by:\n", "\n", "1. A shape, `x.shape`\n", "2. A dtype, `x.dtype`\n", "3. A mapping from indices to values, `x.index(idx)`. 
This is similar to regular indexing `x[idx]`, but it returns only the inner value rather than a scalar array, so that the definition is not recursive.\n", "\n", "So in this case, our question becomes: what are the shape, dtype, and indexed values of `abs(x)[idx]`?\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[e-graph diagram from egraph.display(): nodes NDArray_var \"x\", NDArray_var \"idx\", ndarray_index, NDArray___getitem__, ndarray-abs, NDArray_shape, three NDArray_dtype nodes, three NDArray_index nodes, translate_index, an_index and value_abs]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Resulting shapes:\n", " abs(NDArray.var(\"x\"))[ndarray_index(NDArray.var(\"idx\"))].shape\n", "Resulting dtypes:\n", " abs(NDArray.var(\"x\"))[ndarray_index(NDArray.var(\"idx\"))].dtype\n", " abs(NDArray.var(\"x\")).dtype\n", " NDArray.var(\"x\").dtype\n", "Resulting indexing:\n", " value_abs(NDArray.var(\"x\").index(translate_index(abs(NDArray.var(\"x\")), ndarray_index(NDArray.var(\"idx\")), an_index())))\n", " abs(NDArray.var(\"x\"))[ndarray_index(NDArray.var(\"idx\"))].index(an_index())\n", " abs(NDArray.var(\"x\")).index(translate_index(abs(NDArray.var(\"x\")), ndarray_index(NDArray.var(\"idx\")), an_index()))\n" ] } ], "source": [ "egraph = EGraph([array_api_module])\n", "\n", "\n", "@egraph.function(cost=0)\n", "def value_abs(v: Value) -> Value:\n", "    \"\"\"Absolute value of a scalar value\"\"\"\n", 
"\n", "\n", "@egraph.register\n", "def _define_abs(x: NDArray, ti: TupleInt):\n", "    # dtype after taking the absolute value is the same dtype\n", "    yield rewrite(abs(x).dtype).to(x.dtype)\n", "    # shape after taking the absolute value is the same shape\n", "    yield rewrite(abs(x).shape).to(x.shape)\n", "    # Indexing into the absolute value is the same as indexing into the original and then taking the absolute value\n", "    yield rewrite(abs(x).index(ti)).to(value_abs(x.index(ti)))\n", "\n", "\n", "@egraph.function(cost=0)\n", "def translate_index(x: NDArray, y: IndexKey, z: TupleInt) -> TupleInt:\n", "    \"\"\"Translates an index `z` into `x[y]` into the corresponding index directly into `x`\"\"\"\n", "\n", "\n", "@egraph.register\n", "def _define_indexing(x: NDArray, idx: IndexKey, ti: TupleInt):\n", "    # dtype after indexing is the same dtype\n", "    yield rewrite(x[idx].dtype).to(x.dtype)\n", "    # indexing is pushed down to the source array, after some translation\n", "    yield rewrite(x[idx].index(ti)).to(x.index(translate_index(x, idx, ti)))\n", "    # Shape is more complicated, so we omit it for now\n", "\n", "\n", "@egraph.function\n", "def an_index() -> TupleInt:\n", "    \"\"\"Some index into an array\"\"\"\n", "\n", "\n", "egraph.register(res.shape, res.dtype, res.index(an_index()))\n", "egraph.run(100)\n", "egraph.display()\n", "\n", "\n", "print(\"Resulting shapes:\")\n", "for e in egraph.extract_multiple(res.shape, 10):\n", "    print(\" \", e)\n", "print(\"Resulting dtypes:\")\n", "for e in egraph.extract_multiple(res.dtype, 10):\n", "    print(\" \", e)\n", "print(\"Resulting indexing:\")\n", "for e in egraph.extract_multiple(res.index(an_index()), 10):\n", "    print(\" \", e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For the same reason as above, extracting only the lowest-cost expression would pick one non-deterministically (at least for indexing). But if we look at all of them, we can see that the `dtype` is pushed down to the inner array `x`, and so is the indexing: `value_abs(x.index(translate_index(...)))` indexes directly into `x`.\n" ] } ], "metadata": { "kernelspec": { "display_name": "egglog-python", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }