Indexing pushdown#
I got this question today, and I thought I would write up some examples to explain the current state of things, for this sort of indexing pushdown:
Q: How would I do a exp(vec)[idx] -> exp(vec[idx])
rewrite?
A: You can easily write this rewrite to add both expressions to the graph, but it’s currently difficult to extract out the right expression over the left one:
from __future__ import annotations
from egglog.exp.array_api import *
egraph = EGraph([array_api_module])
@egraph.register
def _pushdown_abs(x: NDArray, idx: IndexKey):
yield rewrite(abs(x)[idx]).to(abs(x[idx]))
res = abs(NDArray.var("x"))[NDArray.var("idx")]
egraph.register(res)
egraph.run(100)
egraph.display()
for e in egraph.extract_multiple(res, 10):
print(e)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[1], line 5
1 from __future__ import annotations
3 from egglog.exp.array_api import *
----> 5 egraph = EGraph([array_api_module])
8 @egraph.register
9 def _pushdown_abs(x: NDArray, idx: IndexKey):
10 yield rewrite(abs(x)[idx]).to(abs(x[idx]))
NameError: name 'array_api_module' is not defined
We see here it extracts out the two objects. If we ask it to just extract out the lowest cost one, it will be non-deterministic which is selected.
See this issue on the egglog tracker for how this could be resolved.
Another way to resolve this, would be to actually try to define the semantics of these two operations. We can make up a mathematical abstraction for arrays, and define both indexing and the abs
function in terms of that abstraction. Then when we compose them, we can look at the composition of the abstractions, to see if that normalizing that form to a canonical one can also achieve the predicate pushdown optimization.
In this case, we can pick an abstraction where each array is defined by:
A shape
.shape
A dtype
.dtype
A mapping from indices to values
x.index(idx)
. Similar to regular indexingx[idx]
, but only returns the inner value, not a scalar array, so that the defintion is not recursive.
So in this case, our question would be, what’s the shape, dtype, and indexed value of abs(x)[idx]
?
egraph = EGraph([array_api_module])
@egraph.function(cost=0)
def value_abs(v: Value) -> Value:
"""Absolute value of a scalar value"""
@egraph.register
def _define_abs(x: NDArray, ti: TupleInt):
# dtype after taking absolute value is same dtype
yield rewrite(abs(x).dtype).to(x.dtype)
# shape after taking absolute value is same shape
yield rewrite(abs(x).shape).to(x.shape)
# Indexing into absolute value is same as indexing into original and then taking the absolute value
yield rewrite(abs(x).index(ti)).to(value_abs(x.index(ti)))
@egraph.function(cost=0)
def translate_index(x: NDArray, y: IndexKey, z: TupleInt) -> TupleInt:
"""Translates indexing `z` into `x[y]` into an indexing directly into `x`"""
@egraph.register
def _define_indexing(x: NDArray, idx: IndexKey, ti: TupleInt):
# dtype after indexing is same dtype
yield rewrite(x[idx].dtype).to(x.dtype)
# indxing is pushed down to source array, after some translation
yield rewrite(x[idx].index(ti)).to(x.index(translate_index(x, idx, ti)))
# Shape is more complicated and we will omit for now
@egraph.function
def an_index() -> TupleInt:
"""Some index into an array"""
egraph.register(res.shape, res.dtype, res.index(an_index()))
egraph.run(100)
egraph.display()
print("Resulting shapes:")
for e in egraph.extract_multiple(res.shape, 10):
print(" ", e)
print("Resulting dtypes:")
for e in egraph.extract_multiple(res.dtype, 10):
print(" ", e)
print("Resulting indexing:")
for e in egraph.extract_multiple(res.index(an_index()), 10):
print(" ", e)
Resulting shapes:
abs(NDArray.var("x"))[ndarray_index(NDArray.var("idx"))].shape
Resulting dtypes:
abs(NDArray.var("x"))[ndarray_index(NDArray.var("idx"))].dtype
abs(NDArray.var("x")).dtype
NDArray.var("x").dtype
Resulting indexing:
value_abs(NDArray.var("x").index(translate_index(abs(NDArray.var("x")), ndarray_index(NDArray.var("idx")), an_index())))
abs(NDArray.var("x"))[ndarray_index(NDArray.var("idx"))].index(an_index())
abs(NDArray.var("x")).index(translate_index(abs(NDArray.var("x")), ndarray_index(NDArray.var("idx")), an_index()))
For the same reason as above, if we extract out the lowest cost one, it will be non-deterministic which is selected (at least for indexing), but if we look at all of them, we can see that the dtype
is pushed down to the inner value, and also the indexing is as well.