{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{post} 2023-11-12\n",
":author: Saul\n",
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Indexing pushdown\n",
"\n",
"I got a question today about this sort of indexing pushdown, so I thought I would write up some examples explaining where things currently stand:\n",
"\n",
"Q: How would I write an `exp(vec)[idx] -> exp(vec[idx])` rewrite?\n",
"\n",
"A: Writing a rewrite that adds both expressions to the e-graph is easy, but it's currently difficult to make extraction pick the right-hand expression over the left-hand one. Here is an example, using `abs` in place of `exp`:\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"abs(NDArray.var(\"x\")[ndarray_index(NDArray.var(\"idx\"))])\n",
"abs(NDArray.var(\"x\"))[ndarray_index(NDArray.var(\"idx\"))]\n"
]
}
],
"source": [
"from __future__ import annotations\n",
"\n",
"from egglog.exp.array_api import *\n",
"\n",
"egraph = EGraph([array_api_module])\n",
"\n",
"\n",
"# Rewrite rule: indexing into abs(x) is the same as taking abs of the indexed array\n",
"@egraph.register\n",
"def _pushdown_abs(x: NDArray, idx: IndexKey):\n",
" yield rewrite(abs(x)[idx]).to(abs(x[idx]))\n",
"\n",
"\n",
"# Build abs(x)[idx], then run the rules for up to 100 iterations\n",
"res = abs(NDArray.var(\"x\"))[NDArray.var(\"idx\")]\n",
"egraph.register(res)\n",
"egraph.run(100)\n",
"egraph.display()\n",
"\n",
"# Print up to 10 equivalent expressions from the e-class of res\n",
"for e in egraph.extract_multiple(res, 10):\n",
" print(e)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both expressions are extracted. If we instead ask for only the lowest-cost one, which of the two is selected is non-deterministic, because they have the same cost.\n",
"\n",
"See [this issue](https://github.com/egraphs-good/egglog/issues/256#issuecomment-1807185387) on the egglog tracker for how this could be resolved.\n"
]
},
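{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why the choice is arbitrary, here is a toy model of cost-based extraction in plain Python. It is a sketch, not egglog's actual extractor: it assumes expressions are nested tuples and that every node costs 1. Under that assumption both candidates cost the same, and Python's `min` simply returns whichever one it encounters first.\n",
"\n",
"```python\n",
"# Toy model of cost-based extraction (an assumption, NOT egglog's extractor).\n",
"# An expression is a leaf name (str) or a tuple (operator, *children).\n",
"def cost(expr):\n",
"    # Every node, including leaves, costs 1, so cost is the node count.\n",
"    if isinstance(expr, tuple):\n",
"        return 1 + sum(cost(child) for child in expr[1:])\n",
"    return 1\n",
"\n",
"\n",
"abs_then_index = ('getitem', ('abs', 'x'), 'idx')  # abs(x)[idx]\n",
"index_then_abs = ('abs', ('getitem', 'x', 'idx'))  # abs(x[idx])\n",
"print(cost(abs_then_index), cost(index_then_abs))  # 4 4\n",
"\n",
"# On a tie, min() keeps the first candidate, so only the order decides.\n",
"print(min([abs_then_index, index_then_abs], key=cost)[0])  # getitem\n",
"print(min([index_then_abs, abs_then_index], key=cost)[0])  # abs\n",
"```\n",
"\n",
"Making the choice deterministic requires either a cost model under which the pushed-down form is strictly cheaper or an explicit tie-breaking rule.\n"
]
},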
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"\n",
"Another way to resolve this would be to define the _semantics_ of these two operations. We can make up a mathematical abstraction for arrays and define both indexing and the `abs` function in terms of that abstraction. Then, when we compose the two operations, we can check whether normalizing the composed form to a canonical one also achieves the indexing pushdown optimization.\n",
"\n",
"In this case, we can pick an abstraction where each array is defined by:\n",
"\n",
"1. A shape `.shape`\n",
"2. A dtype `.dtype`\n",
"3. A mapping from indices to values, `x.index(idx)`. This is similar to regular indexing `x[idx]`, but it returns the inner value rather than a scalar array, so that the definition is not recursive.\n",
"\n",
"So our question becomes: what are the shape, dtype, and indexed values of `abs(x)[idx]`?\n"
]
},
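{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal plain-Python sketch of this abstraction (independent of egglog; `ArraySem`, `from_list`, and `sem_abs` are hypothetical names), an array is just a shape, a dtype, and a function from index tuples to scalar values. Defining `abs` in these terms means keeping the shape and dtype and applying the scalar absolute value at every index:\n",
"\n",
"```python\n",
"from dataclasses import dataclass\n",
"from typing import Callable, Tuple\n",
"\n",
"Index = Tuple[int, ...]\n",
"\n",
"\n",
"@dataclass(frozen=True)\n",
"class ArraySem:\n",
"    # Hypothetical model: an array is its shape, dtype, and index -> value map.\n",
"    shape: Index\n",
"    dtype: str\n",
"    index: Callable[[Index], float]\n",
"\n",
"\n",
"def from_list(values, dtype='float64'):\n",
"    # Build a 1-D array from a Python list of numbers.\n",
"    return ArraySem((len(values),), dtype, lambda i: values[i[0]])\n",
"\n",
"\n",
"def sem_abs(x):\n",
"    # Same shape and dtype; each value is the scalar absolute value.\n",
"    return ArraySem(x.shape, x.dtype, lambda i: abs(x.index(i)))\n",
"\n",
"\n",
"x = from_list([-1.5, 2.0, -3.0])\n",
"y = sem_abs(x)\n",
"print(y.shape, y.dtype, [y.index((i,)) for i in range(y.shape[0])])\n",
"# (3,) float64 [1.5, 2.0, 3.0]\n",
"```\n"
]
},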
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Resulting shapes:\n",
" abs(NDArray.var(\"x\"))[ndarray_index(NDArray.var(\"idx\"))].shape\n",
"Resulting dtypes:\n",
" abs(NDArray.var(\"x\"))[ndarray_index(NDArray.var(\"idx\"))].dtype\n",
" abs(NDArray.var(\"x\")).dtype\n",
" NDArray.var(\"x\").dtype\n",
"Resulting indexing:\n",
" value_abs(NDArray.var(\"x\").index(translate_index(abs(NDArray.var(\"x\")), ndarray_index(NDArray.var(\"idx\")), an_index())))\n",
" abs(NDArray.var(\"x\"))[ndarray_index(NDArray.var(\"idx\"))].index(an_index())\n",
" abs(NDArray.var(\"x\")).index(translate_index(abs(NDArray.var(\"x\")), ndarray_index(NDArray.var(\"idx\")), an_index()))\n"
]
}
],
"source": [
"egraph = EGraph([array_api_module])\n",
"\n",
"\n",
"@egraph.function(cost=0)\n",
"def value_abs(v: Value) -> Value:\n",
" \"\"\"Absolute value of a scalar value\"\"\"\n",
"\n",
"\n",
"@egraph.register\n",
"def _define_abs(x: NDArray, ti: TupleInt):\n",
" # dtype after taking absolute value is same dtype\n",
" yield rewrite(abs(x).dtype).to(x.dtype)\n",
" # shape after taking absolute value is same shape\n",
" yield rewrite(abs(x).shape).to(x.shape)\n",
" # Indexing into absolute value is same as indexing into original and then taking the absolute value\n",
" yield rewrite(abs(x).index(ti)).to(value_abs(x.index(ti)))\n",
"\n",
"\n",
"@egraph.function(cost=0)\n",
"def translate_index(x: NDArray, y: IndexKey, z: TupleInt) -> TupleInt:\n",
" \"\"\"Translates indexing `z` into `x[y]` into an indexing directly into `x`\"\"\"\n",
"\n",
"\n",
"@egraph.register\n",
"def _define_indexing(x: NDArray, idx: IndexKey, ti: TupleInt):\n",
" # dtype after indexing is same dtype\n",
" yield rewrite(x[idx].dtype).to(x.dtype)\n",
" # Indexing is pushed down to the source array, after translating the index\n",
" yield rewrite(x[idx].index(ti)).to(x.index(translate_index(x, idx, ti)))\n",
" # The shape rule is more involved, so we omit it for now\n",
"\n",
"\n",
"@egraph.function\n",
"def an_index() -> TupleInt:\n",
" \"\"\"Some index into an array\"\"\"\n",
"\n",
"\n",
"# Register the shape, dtype, and a symbolic indexed value of abs(x)[idx]\n",
"egraph.register(res.shape, res.dtype, res.index(an_index()))\n",
"egraph.run(100)\n",
"egraph.display()\n",
"\n",
"\n",
"print(\"Resulting shapes:\")\n",
"for e in egraph.extract_multiple(res.shape, 10):\n",
" print(\" \", e)\n",
"print(\"Resulting dtypes:\")\n",
"for e in egraph.extract_multiple(res.dtype, 10):\n",
" print(\" \", e)\n",
"print(\"Resulting indexing:\")\n",
"for e in egraph.extract_multiple(res.index(an_index()), 10):\n",
" print(\" \", e)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
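"For the same reason as above, if we extract only the lowest-cost expression, which one is selected is non-deterministic (at least for indexing). But if we look at all of them, we can see that both the `dtype` and the indexing are pushed down to the inner array `x`: the dtype reduces to `x.dtype`, and the indexed value becomes `value_abs` applied to an index into `x`. The shape is not simplified, since we did not define a shape rule for indexing.\n"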
]
},
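{
"cell_type": "markdown",
"metadata": {},
"source": [
"To sanity-check these rules outside of egglog, here is a self-contained plain-Python sketch for one concrete case: a 1-D array indexed by a list of integer positions. All names are hypothetical, and its `translate_index` is a simplified counterpart of the uninterpreted function above, assuming that position `(i,)` of `x[idx]` maps to position `(idx[i],)` of `x`. The sketch checks that `abs(x)[idx]` and `abs(x[idx])` agree on shape, dtype, and every value.\n",
"\n",
"```python\n",
"from dataclasses import dataclass\n",
"from typing import Callable, Tuple\n",
"\n",
"\n",
"@dataclass(frozen=True)\n",
"class ArraySem:\n",
"    # Hypothetical model: shape, dtype, and a map from index tuples to values.\n",
"    shape: Tuple[int, ...]\n",
"    dtype: str\n",
"    index: Callable[[Tuple[int, ...]], float]\n",
"\n",
"\n",
"def from_list(values, dtype='float64'):\n",
"    return ArraySem((len(values),), dtype, lambda i: values[i[0]])\n",
"\n",
"\n",
"def sem_abs(x):\n",
"    # abs keeps shape and dtype and maps each value through the scalar abs.\n",
"    return ArraySem(x.shape, x.dtype, lambda i: abs(x.index(i)))\n",
"\n",
"\n",
"def translate_index(idx, i):\n",
"    # Only handles a list of integer positions into a 1-D array:\n",
"    # position (i,) of x[idx] is position (idx[i],) of x.\n",
"    return (idx[i[0]],)\n",
"\n",
"\n",
"def sem_getitem(x, idx):\n",
"    # Indexing keeps the dtype; the shape is the length of idx\n",
"    # (true only for this 1-D integer-position case).\n",
"    return ArraySem((len(idx),), x.dtype, lambda i: x.index(translate_index(idx, i)))\n",
"\n",
"\n",
"def same_semantics(a, b):\n",
"    # Compare shape, dtype, and the value at every 1-D position.\n",
"    return (a.shape == b.shape and a.dtype == b.dtype\n",
"            and all(a.index((i,)) == b.index((i,)) for i in range(a.shape[0])))\n",
"\n",
"\n",
"x = from_list([-1.5, 2.0, -3.0, 4.0])\n",
"idx = [3, 0, 2, 2]\n",
"lhs = sem_getitem(sem_abs(x), idx)  # abs(x)[idx]\n",
"rhs = sem_abs(sem_getitem(x, idx))  # abs(x[idx])\n",
"print(same_semantics(lhs, rhs), [lhs.index((i,)) for i in range(lhs.shape[0])])\n",
"# True [4.0, 1.5, 3.0, 3.0]\n",
"```\n",
"\n",
"In this special case the omitted shape rule is easy to state (the result has the length of the index list); general indexing with slices, multiple axes, and broadcasting is what makes it harder.\n"
]
},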
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "egglog-python",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}