Show code cell source
from IPython.display import YouTubeVideo
YouTubeVideo("I2ICNT56Rdc")
EGRAPHS Community Call Talk#
Egglog as a tool for building optimizing, composable, type-safe DSLs in Python
…to help drive the theoretical development of e-graphs while also having an impact on (large) real-world communities.
Now that I have this great e-graph library in Python, what extra mechanisms do I need to make it useful in existing Python code?
This talk will go through a few techniques we developed and also point to how bringing in use cases from scientific Python can help drive further theoretical research.
Optimizing Scikit-learn with Numba#
We are going to work through the different pieces needed to optimize a Scikit-learn pipeline using Numba and egglog.
from __future__ import annotations
import sklearn
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# Tell sklearn to treat arrays as following array API
sklearn.set_config(array_api_dispatch=True)
X_np, y_np = make_classification(random_state=0, n_samples=1000000)
# Assumption: I want to optimize calling this many times on data similar to that above
def run_lda(x, y):
    """Fit a fresh LDA model on ``(x, y)`` and return ``x`` projected onto its discriminant axes.

    ``x`` and ``y`` may be NumPy arrays or egglog symbolic ``NDArray``s; with
    ``array_api_dispatch`` enabled, scikit-learn only uses Array API operations on them.
    """
    model = LinearDiscriminantAnalysis()
    model.fit(x, y)
    return model.transform(x)
We can do this using egglog to generate Python code and Numba to JIT compile it to LLVM, resulting in a speedup:
# The first thing we need to do is create our symbolic arrays and get back our symbolic output
from egglog.exp.array_api import *
# `np` is used below; import it explicitly rather than relying on the
# `egglog.exp.array_api` star import to happen to provide it.
import numpy as np

# Symbolic stand-in for X, plus what we know about the concrete data. These
# assumptions let data-dependent checks inside scikit-learn be resolved eagerly.
X_arr = NDArray.var("X")
assume_dtype(X_arr, X_np.dtype)
assume_shape(X_arr, X_np.shape)
assume_isfinite(X_arr)

# Symbolic stand-in for y; its values are restricted to the labels seen in y_np.
y_arr = NDArray.var("y")
assume_dtype(y_arr, y_np.dtype)
assume_shape(y_arr, y_np.shape)
assume_value_one_of(y_arr, tuple(map(int, np.unique(y_np))))  # type: ignore[arg-type]

# Trace the scikit-learn pipeline symbolically to get back an egglog expression.
with EGraph():
    res = run_lda(X_arr, y_arr)
res
_NDArray_1 = NDArray.var("X")
assume_dtype(_NDArray_1, DType.float64)
assume_shape(_NDArray_1, TupleInt(Int(1000000)) + TupleInt(Int(20)))
assume_isfinite(_NDArray_1)
_NDArray_2 = NDArray.var("y")
assume_dtype(_NDArray_2, DType.int64)
assume_shape(_NDArray_2, TupleInt(Int(1000000)))
assume_value_one_of(_NDArray_2, TupleValue(Value.int(Int(0))) + TupleValue(Value.int(Int(1))))
_NDArray_3 = asarray(reshape(asarray(_NDArray_2), TupleInt(Int(-1))))
_NDArray_4 = astype(unique_counts(_NDArray_3)[Int(1)], asarray(_NDArray_1).dtype) / NDArray.scalar(Value.float(Float(1000000.0)))
_NDArray_5 = zeros(
TupleInt(unique_inverse(_NDArray_3)[Int(0)].shape[Int(0)]) + TupleInt(asarray(_NDArray_1).shape[Int(1)]),
OptionalDType.some(asarray(_NDArray_1).dtype),
OptionalDevice.some(asarray(_NDArray_1).device),
)
_MultiAxisIndexKey_1 = MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice()))
_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(0))) + _MultiAxisIndexKey_1)
_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0)))
_NDArray_5[_IndexKey_1] = mean(asarray(_NDArray_1)[ndarray_index(unique_inverse(_NDArray_3)[Int(1)] == NDArray.scalar(Value.int(Int(0))))], _OptionalIntOrTuple_1)
_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(1))) + _MultiAxisIndexKey_1)
_NDArray_5[_IndexKey_2] = mean(asarray(_NDArray_1)[ndarray_index(unique_inverse(_NDArray_3)[Int(1)] == NDArray.scalar(Value.int(Int(1))))], _OptionalIntOrTuple_1)
_NDArray_6 = unique_values(concat(TupleNDArray(unique_values(asarray(_NDArray_3)))))
_NDArray_7 = concat(
TupleNDArray(asarray(_NDArray_1)[ndarray_index(_NDArray_3 == _NDArray_6[IndexKey.int(Int(0))])] - _NDArray_5[_IndexKey_1])
+ TupleNDArray(asarray(_NDArray_1)[ndarray_index(_NDArray_3 == _NDArray_6[IndexKey.int(Int(1))])] - _NDArray_5[_IndexKey_2]),
OptionalInt.some(Int(0)),
)
_NDArray_8 = std(_NDArray_7, _OptionalIntOrTuple_1)
_NDArray_8[ndarray_index(std(_NDArray_7, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))
_TupleNDArray_1 = svd(
sqrt(asarray(NDArray.scalar(Value.float(Float(1.0) / Float.from_int(asarray(_NDArray_1).shape[Int(0)] - _NDArray_6.shape[Int(0)]))))) * (_NDArray_7 / _NDArray_8), FALSE
)
_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))
_NDArray_9 = (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(_Slice_1)) + _MultiAxisIndexKey_1)] / _NDArray_8).T / _TupleNDArray_1[
Int(1)
][IndexKey.slice(_Slice_1)]
_TupleNDArray_2 = svd(
(
sqrt(
(NDArray.scalar(Value.int(asarray(_NDArray_1).shape[Int(0)])) * _NDArray_4)
* NDArray.scalar(Value.float(Float(1.0) / Float.from_int(_NDArray_6.shape[Int(0)] - Int(1))))
)
* (_NDArray_5 - (_NDArray_4 @ _NDArray_5)).T
).T
@ _NDArray_9,
FALSE,
)
(
(asarray(_NDArray_1) - (_NDArray_4 @ _NDArray_5))
@ (
_NDArray_9
@ _TupleNDArray_2[Int(2)].T[
IndexKey.multi_axis(
_MultiAxisIndexKey_1
+ MultiAxisIndexKey(
MultiAxisIndexKeyItem.slice(
Slice(
OptionalInt.none,
OptionalInt.some(
sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32))
.to_value()
.to_int
),
)
)
)
)
]
)
)[IndexKey.multi_axis(_MultiAxisIndexKey_1 + MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(_NDArray_6.shape[Int(0)] - Int(1))))))]
In order to run this, scikit-learn treated these objects as “array like”, meaning they conformed to the Array API.
Conversions: From Python to egglog#
Use conversions if you want your egglog API to be called with existing Python objects, without manually upcasting them
We will see this in our example:
class LinearDiscriminantAnalysis:
    # Excerpt of scikit-learn's LDA (elided with `...`). `xp` is the Array API
    # namespace of the inputs — here, presumably egglog's array API module.
    ...

    def fit(self, X, y):
        ...
        _, cnts = xp.unique_counts(y) # non-negative ints
        # Class priors = per-class counts / number of samples. The plain Python
        # float `float(y.shape[0])` is upcast to an egglog scalar by a converter.
        self.priors_ = xp.astype(cnts, X.dtype) / float(y.shape[0])
Ends up resulting in this expression:
astype(unique_counts(_NDArray_3)[Int(1)], asarray(_NDArray_1).dtype) / NDArray.scalar(Value.float(Float(1000000.0)))
How?
We have exposed a global conversion logic, where if you pass an arg to egglog and it isn’t the correct type, it will try to upcast the arg to the required egglog type.
There is a graph of all conversions and it will find the shortest path from the input to the desired type and automatically upcast to that.
Example#
For example, in indexing, if we use a slice (e.g. `1:10:2`), we convert it into our custom egglog `Slice` expression:
class Slice(Expr):
    # egglog counterpart of Python's `slice`. Each bound is an optional Int and
    # defaults to `OptionalInt.none`, like an omitted bound in `a[start:stop:step]`.
    def __init__(
        self,
        start: OptionalInt = OptionalInt.none,
        stop: OptionalInt = OptionalInt.none,
        step: OptionalInt = OptionalInt.none,
    ) -> None: ...
# Register an implicit upcast: wherever an egglog `Slice` is expected, a Python
# `slice` is accepted and its start/stop/step are each converted to `OptionalInt`.
converter(
    slice,
    Slice,
    lambda s: Slice(*(convert(bound, OptionalInt) for bound in (s.start, s.stop, s.step))),
)
class A(Expr):
    # Toy expression type whose indexing takes an egglog `Slice`.
    def __init__(self) -> None: ...
    def __getitem__(self, s: Slice) -> Int: ...


# The Python slice below is upcast to `Slice` by the converter registered for `slice`.
A()[:1:2] # Python desugars this to A()[slice(None, 1, 2)]
A()[Slice(OptionalInt.none, OptionalInt.some(Int(1)), OptionalInt.some(Int(2)))]
Preserved Methods#
If you need your egglog objects to interact with Python control flow, you can use preserved methods to stop, compile, and return an eager result to Python
In the `fit` method in sklearn, there are complicated analyses that must be done eagerly, like this one, which depends on knowing the priors, which are based on the counts of the classes in the training data we provided:
class LinearDiscriminantAnalysis:
    # Excerpt of scikit-learn's LDA.fit (elided with `...`).
    def fit(self, x, y):
        ...
        # Python `if` needs a real bool, so the symbolic comparison below must be
        # evaluated eagerly (via a preserved `__bool__`), using the assumed
        # dtype/shape/values of the inputs.
        if xp.abs(xp.sum(self.priors_) - 1.0) > 1e-5:
            warnings.warn("The priors do not sum to 1. Renormalizing", UserWarning)
            self.priors_ = self.priors_ / self.priors_.sum()
That is why we have to provide the metadata about the arrays, so we can reduce this expression to a boolean, using some interval analysis:
_NDArray_1 = NDArray.var("y")
assume_dtype(_NDArray_1, DType.int64)
assume_shape(_NDArray_1, TupleInt(Int(1000000)))
assume_value_one_of(_NDArray_1, TupleValue(Value.int(Int(0))) + TupleValue(Value.int(Int(1))))
_NDArray_2 = NDArray.var("X")
assume_dtype(_NDArray_2, DType.float64)
assume_shape(_NDArray_2, TupleInt(Int(1000000)) + TupleInt(Int(20)))
assume_isfinite(_NDArray_2)
(
abs(
sum(astype(unique_counts(asarray(reshape(asarray(_NDArray_1), TupleInt(Int(-1)))))[Int(1)], asarray(_NDArray_2).dtype) / NDArray.scalar(Value.float(Float(1000000.0))))
- NDArray.scalar(Value.float(Float(1.0)))
)
> NDArray.scalar(Value.float(Float(1e-05)))
).to_value().to_bool.bool
But how does it move through the control flow if the expression is lazy? We can implement “preserved methods”, which evaluate an expression eagerly by adding it to an e-graph and evaluating it:
Example#
class Boolean(Expr):
    # Symbolic boolean expression that can also be consumed by Python control flow.
    # Can be constructed from and convert to a primitive egglog bool:
    def __init__(self, b: BoolLike) -> None: ...
    @property
    def bool(self) -> Bool: ...
    # support boolean ops
    def __and__(self, other: Boolean) -> Boolean: ...
    # Can be treated like a Python bool
    # `preserve=True` keeps this as a regular Python method, so `if expr:` runs it
    # eagerly and gets back a Python `bool`.
    @method(preserve=True)
    def __bool__(self) -> bool:
        # A fresh e-graph, so evaluating does not touch any other graph.
        egraph = EGraph()
        egraph.register(self)
        # `bool_rewrites` is the module-level ruleset defined after this class; it is
        # looked up when this runs. Saturating lets the rules set `self.bool`.
        egraph.run(bool_rewrites.saturate())
        return egraph.eval(self.bool)
# Rule variables: `x` ranges over `Boolean` expressions, `y` over primitive `Bool`s.
x = var("x", Boolean)
y = var("y", Bool)
bool_rewrites = ruleset(
    # When an expression is known to equal `Boolean(y)`, record the primitive `y`
    # as its `.bool`, so `Boolean.__bool__` can extract it with `egraph.eval`.
    rule(eq(x).to(Boolean(y))).then(set_(x.bool).to(y)),
    # Minimal logic for this demo: True & True simplifies to True.
    rewrite(Boolean(True) & Boolean(True)).to(Boolean(True)),
)
expr = Boolean(True) & Boolean(True)
expr
Boolean(True) & Boolean(True)
if expr:
print("yep it's true")
yep it's true
Mutations#
Mark a function or method as mutating its first argument, to translate it into a pure function that still behaves imperatively.
Another pattern that comes up a lot in Python is methods that mutate their arguments. But egglog is a pure functional language, so how do we support that?
Well we can convert functions that mutate an arg into one that returns a modified value of that argument. That way, we can keep using it with existing imperative methods and things work as they should.
For example, arrays support `__setitem__`, which scikit-learn uses:
class LinearDiscriminantAnalysis:
    # Excerpt of scikit-learn's LDA (elided with `...`).
    ...

    def _solve_svd(self, X, y):
        ...
        # 1) within (univariate) scaling by with classes std-dev
        std = xp.std(Xc, axis=0)
        # avoid division by zero in normalization
        # This `__setitem__` mutates `std`; egglog records it as a new, modified
        # array in the graph rather than mutating the original expression.
        std[std == 0] = 1.0
This will be translated to the following expressions, where there will be a new array created in the graph for the modified version:
_NDArray_8 = std(_NDArray_7, _OptionalIntOrTuple_1)
_NDArray_8[ndarray_index(std(_NDArray_7, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))
Example#
We can see a simpler example of this below:
class ListOfInts(Expr):
    # Toy symbolic list of egglog `Int`s.
    def __init__(self) -> None: ...
    def __getitem__(self, i: i64Like) -> Int: ...
    # Item assignment: egglog models this as producing a modified copy of the list
    # that the assigned-to name then refers to.
    def __setitem__(self, i: i64Like, v: Int) -> None: ...


xs = ListOfInts()
# After this statement, `xs` refers to the updated list expression.
xs[0] = Int(1)
egraph = EGraph()
egraph.register(xs[0])
egraph.display()
Subsumption#
Mark a rewrite as subsumed to replace a smaller expression with a larger one.
Now that we have a program, what do we do with it? Well first, we can optimize it by running rewrites, including ones that translate array operations into forms that Numba supports.
We have added “subsumption” to egglog, to support directional rewrites, so that the left hand side is not extractable and not matchable. This is handy when we want to extract a value with more expressions or a higher cost, in a particular instance:
@array_api_numba_ruleset.register
def _mean(x: NDArray, i: Int):
    # Rewrite `mean` along a single integer axis `i` into `sum / count`, a form the
    # Numba ruleset targets. Both rewrites subsume their left-hand side, so the
    # `mean(...)` form is no longer matched or extracted.
    # (The unused rule variable `y` from the original signature was dropped.)
    axis = OptionalIntOrTuple.some(IntOrTuple.int(i))
    res = sum(x, axis) / NDArray.scalar(Value.int(x.shape[i]))
    # Third argument appears to be keepdims: without it, the reduced axis is dropped...
    yield rewrite(mean(x, axis, FALSE), subsume=True).to(res)
    # ...with it, the reduced axis is restored with size 1.
    yield rewrite(mean(x, axis, TRUE), subsume=True).to(expand_dims(res, i))
We can optimize this with the Numba rules, and we can see this rule applied in the `_NDArray_9` line:
from egglog.exp.array_api_numba import array_api_numba_schedule
simplified_res = EGraph().simplify(res, array_api_numba_schedule)
simplified_res
_NDArray_1 = NDArray.var("X")
assume_dtype(_NDArray_1, DType.float64)
assume_shape(_NDArray_1, TupleInt(Int(1000000)) + TupleInt(Int(20)))
assume_isfinite(_NDArray_1)
_NDArray_2 = NDArray.var("y")
assume_dtype(_NDArray_2, DType.int64)
assume_shape(_NDArray_2, TupleInt(Int(1000000)))
assume_value_one_of(_NDArray_2, TupleValue(Value.int(Int(0))) + TupleValue(Value.int(Int(1))))
_NDArray_3 = astype(
NDArray.vector(TupleValue(sum(_NDArray_2 == NDArray.scalar(Value.int(Int(0)))).to_value()) + TupleValue(sum(_NDArray_2 == NDArray.scalar(Value.int(Int(1)))).to_value())),
DType.float64,
) / NDArray.scalar(Value.float(Float.rational(Rational(1000000, 1))))
_NDArray_4 = zeros(TupleInt(Int(2)) + TupleInt(Int(20)), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device))
_MultiAxisIndexKey_1 = MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice()))
_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(0))) + _MultiAxisIndexKey_1)
_NDArray_5 = _NDArray_1[ndarray_index(_NDArray_2 == NDArray.scalar(Value.int(Int(0))))]
_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0)))
_NDArray_4[_IndexKey_1] = sum(_NDArray_5, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_5.shape[Int(0)]))
_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(1))) + _MultiAxisIndexKey_1)
_NDArray_6 = _NDArray_1[ndarray_index(_NDArray_2 == NDArray.scalar(Value.int(Int(1))))]
_NDArray_4[_IndexKey_2] = sum(_NDArray_6, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_6.shape[Int(0)]))
_NDArray_7 = concat(TupleNDArray(_NDArray_5 - _NDArray_4[_IndexKey_1]) + TupleNDArray(_NDArray_6 - _NDArray_4[_IndexKey_2]), OptionalInt.some(Int(0)))
_NDArray_8 = square(_NDArray_7 - expand_dims(sum(_NDArray_7, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_7.shape[Int(0)]))))
_NDArray_9 = sqrt(sum(_NDArray_8, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_8.shape[Int(0)])))
_NDArray_10 = copy(_NDArray_9)
_NDArray_10[ndarray_index(_NDArray_9 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float.rational(Rational(1, 1))))
_TupleNDArray_1 = svd(sqrt(NDArray.scalar(Value.float(Float.rational(Rational(1, 999998))))) * (_NDArray_7 / _NDArray_10), FALSE)
_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))
_NDArray_11 = (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(_Slice_1)) + _MultiAxisIndexKey_1)] / _NDArray_10).T / _TupleNDArray_1[
Int(1)
][IndexKey.slice(_Slice_1)]
_TupleNDArray_2 = svd(
(sqrt((NDArray.scalar(Value.int(Int(1000000))) * _NDArray_3) * NDArray.scalar(Value.float(Float.rational(Rational(1, 1))))) * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T).T
@ _NDArray_11,
FALSE,
)
(
(_NDArray_1 - (_NDArray_3 @ _NDArray_4))
@ (
_NDArray_11
@ _TupleNDArray_2[Int(2)].T[
IndexKey.multi_axis(
_MultiAxisIndexKey_1
+ MultiAxisIndexKey(
MultiAxisIndexKeyItem.slice(
Slice(
OptionalInt.none,
OptionalInt.some(
sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32))
.to_value()
.to_int
),
)
)
)
)
]
)
)[IndexKey.multi_axis(_MultiAxisIndexKey_1 + MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(Int(1))))))]
Program Gen#
Generate an imperative program from your e-graph with replacement rules that walk the graph in a fixed order
Now that we have a program, what do we do with it?
We showed how we can use eager evaluation to get a result, but what if we don’t want to do the computation in egglog, and instead want to export a program that we can execute back in Python or, in this case, feed to Numba?
For this, we have designed a `Program` object, which we can use to convert a functional egglog expression back into imperative Python code:
from egglog.exp.array_api_program_gen import *
egraph = EGraph()
fn_program = egraph.let(
"fn_program",
ndarray_function_two(simplified_res, NDArray.var("X"), NDArray.var("y")),
)
egraph.run(array_api_program_gen_schedule)
fn = egraph.eval(fn_program.py_object)
fn
<function __fn(X, y)>
import inspect
print(inspect.getsource(fn))
def __fn(X, y):
assert X.dtype == np.dtype(np.float64)
assert X.shape == (1000000, 20,)
assert np.all(np.isfinite(X))
assert y.dtype == np.dtype(np.int64)
assert y.shape == (1000000,)
assert set(np.unique(y)) == set((0, 1,))
_0 = y == np.array(0)
_1 = np.sum(_0)
_2 = y == np.array(1)
_3 = np.sum(_2)
_4 = np.array((_1, _3,)).astype(np.dtype(np.float64))
_5 = _4 / np.array(float(1000000))
_6 = np.zeros((2, 20,), dtype=np.dtype(np.float64))
_7 = np.sum(X[_0], axis=0)
_8 = _7 / np.array(X[_0].shape[0])
_6[0, :] = _8
_9 = np.sum(X[_2], axis=0)
_10 = _9 / np.array(X[_2].shape[0])
_6[1, :] = _10
_11 = _5 @ _6
_12 = X - _11
_13 = np.sqrt(np.array(float(1 / 999998)))
_14 = X[_0] - _6[0, :]
_15 = X[_2] - _6[1, :]
_16 = np.concatenate((_14, _15,), axis=0)
_17 = np.sum(_16, axis=0)
_18 = _17 / np.array(_16.shape[0])
_19 = np.expand_dims(_18, 0)
_20 = _16 - _19
_21 = np.square(_20)
_22 = np.sum(_21, axis=0)
_23 = _22 / np.array(_21.shape[0])
_24 = np.sqrt(_23)
_25 = _24 == np.array(0)
_24[_25] = np.array(float(1))
_26 = _16 / _24
_27 = _13 * _26
_28 = np.linalg.svd(_27, full_matrices=False)
_29 = _28[1] > np.array(0.0001)
_30 = _29.astype(np.dtype(np.int32))
_31 = np.sum(_30)
_32 = _28[2][:_31, :] / _24
_33 = _32.T / _28[1][:_31]
_34 = np.array(1000000) * _5
_35 = _34 * np.array(float(1))
_36 = np.sqrt(_35)
_37 = _6 - _11
_38 = _36 * _37.T
_39 = _38.T @ _33
_40 = np.linalg.svd(_39, full_matrices=False)
_41 = np.array(0.0001) * _40[1][0]
_42 = _40[1] > _41
_43 = _42.astype(np.dtype(np.int32))
_44 = np.sum(_43)
_45 = _33 @ _40[2].T[:, :_44]
_46 = _12 @ _45
return _46[:, :1]
From there we can finish by JIT-compiling the function with Numba and calling it with our original values:
from numba import njit
njit(fn)(X_np, y_np)
/tmp/egglog-826e83db-ad16-4a78-84a8-5b3f5644b50a.py:56: NumbaPerformanceWarning: '@' is faster on contiguous arrays, called on (Array(float64, 2, 'C', False, aligned=True), Array(float64, 2, 'A', False, aligned=True))
_45 = _33 @ _40[2].T[:, :_44]
array([[ 0.64233002],
[ 0.63661245],
[-1.603293 ],
...,
[-1.1506433 ],
[ 0.71687176],
[-1.51119579]])
These program rewrites work by first translating the code into an intermediate IR of just program assignments and statements, and then turning this into source. It does this by walking the graph first top to bottom, recording the first parent that saw every child. Then it goes from bottom to top, building up a larger set of statements as well as an expression representing each node. If a node has already been emitted somewhere else, we record that, and we always wait for all the children to complete before moving forward. That way, we can enforce some ordering and we know that everything will be emitted only once.
Example#
Here is a small example. Let’s say we want to compile a function like this:
def __fn(x, y):
    # Target program for the Program-gen example: mutate `x` in place, then use it.
    x[1] = 10
    z = x + x
    return sum(z) + y
We can build up a functional program for this and then compile it to Python source:
from egglog.exp.program_gen import *
# Function arguments, marked as identifiers.
x = Program("x", is_identifier=True)
y = Program("y", is_identifier=True)
# To reference x, we need to first emit the statement
x_modified = Program("x").statement(x + "[1] = 10")
# Both operands are the same node, so the `x[1] = 10` statement is emitted only
# once; `.assign()` binds the sum to a fresh variable (`_0` in the output below).
z = (x_modified + " + " + x_modified).assign()
res = Program("sum(") + z + ") + " + y
fn = res.function_two(x, y)
# Compile, run the program-gen rules to a fixed point, then read back the source text.
egraph = EGraph()
egraph.register(fn.compile())
egraph.run(program_gen_ruleset.saturate())
print(egraph.eval(fn.statements))
def __fn(x, y):
x[1] = 10
_0 = x + x
return sum(_0) + y
This works by first going top to bottom with `compile`, which puts a total ordering on all nodes by defining one, and only one, parent expression for each expression.
Then, from the bottom up, for each node we compute an expression string and a string of statements.
egraph
PyObject - Python objects in EGraphs#
We can also add Python objects directly to the e-graph as primitives and run rewrite rules that call back into Python
Future Work#
The next milestone use case is to be able to optimize functional array programs and rewrite them.
To implement this we need to at least support functions as values and ideally also generic types.
Example#
There is a concrete example provided by Siu from the Numba project.
We would want users to be able to write code like this:
def linalg_norm_loopnest_egglog(X: enp.NDArray, axis: enp.TupleInt) -> enp.NDArray:
    # Proposed (not yet implemented) user-facing API: the norm of X over `axis`,
    # built elementwise from a function of the outer index plus an explicit loop nest.
    # peel off the outer shape for result array
    outshape = ShapeAPI(X.shape).deselect(axis).to_tuple()
    # get only the inner shape for reduction
    reduce_axis = ShapeAPI(X.shape).select(axis).to_tuple()
    return enp.NDArray.from_fn(
        outshape,
        X.dtype,
        # For each outer index k: accumulate real(conj(x) * x) over the reduced
        # index space starting from 0.0, then take the sqrt.
        # NOTE(review): `X[i + k]` appears to place the reduced index before the
        # outer index, i.e. it assumes `axis` names the leading dimensions — confirm.
        lambda k: enp.sqrt(
            LoopNestAPI.from_tuple(reduce_axis)
            .unwrap()
            .reduce(lambda carry, i: carry + enp.real(enp.conj(x := X[i + k]) * x), init=0.0)
        ).to_value(),
    )
Which would then be rewritten to:
def linalg_norm_array_api(X: enp.NDArray, axis: enp.TupleInt) -> enp.NDArray:
    # Proposed rewrite target: the same norm expressed with array-level primitives
    # (range / filter / product / map / sum) instead of an explicit loop nest.
    # Outer (non-reduced) dimensions, and the result shape they select from X.shape.
    outdim = enp.range_(X.ndim).filter(lambda x: ~axis.contains(x))
    outshape = convert(convert(X.shape, enp.NDArray)[outdim], enp.TupleInt)
    row_axis, col_axis = axis
    return enp.NDArray.from_fn(
        outshape,
        X.dtype,
        # For each outer index k: sum real(conj(x) * x) over (presumably) every
        # (row, col) pair of the two reduced axes, then take the sqrt.
        # NOTE(review): `X[rc + k]` appears to put the reduced indices before the
        # outer index, which only matches `axis == (0, 1)` — confirm intended semantics.
        lambda k: enp.sqrt(
            enp.int_product(enp.range_(X.shape[row_axis]), enp.range_(X.shape[col_axis]))
            .map_to_ndarray(lambda rc: enp.real(enp.conj(x := X[rc + k]) * x))
            .sum()
        ).to_value(),
    )
And finally could be lowered to Python as:
def linalg_norm_low_level(
    X: np.ndarray[tuple, np.dtype[np.float64]], axis: tuple[int, int]
) -> np.ndarray[tuple, np.dtype[np.float64]]:
    """Frobenius norm of ``X`` over the two axes in ``axis``, written as explicit loops.

    This is the hand-lowered form of the egglog examples above.

    Parameters
    ----------
    X:
        Array with at least 3 dimensions.
    axis:
        Two (distinct) axes of ``X`` to reduce over. Negative values count from the end.

    Returns
    -------
    Array whose shape is ``X.shape`` with the two reduced axes removed (the remaining
    axes keep their original order). Each entry is ``sqrt(sum(real(conj(x) * x)))``
    over the reduced axes. The result has ``X``'s dtype.

    Raises
    ------
    AssertionError
        If ``X.ndim < 3`` or ``len(axis) != 2``.
    """
    # Only the case of X.ndim >= 3 with a 2-tuple axis is lowered here.
    assert X.ndim >= 3
    assert len(axis) == 2
    # Normalize negative axes so the "outer dims" filter and index placement agree.
    row_axis, col_axis = (a % X.ndim for a in axis)
    outdim = [dim for dim in range(X.ndim) if dim not in (row_axis, col_axis)]
    outshape = tuple(X.shape[dim] for dim in outdim)
    res = np.zeros(outshape, dtype=X.dtype)
    # Full index into X. Each reduced/outer position is written at its real axis,
    # instead of assuming the reduced axes come first.
    idx = [0] * X.ndim
    for k in np.ndindex(outshape):
        for dim, pos in zip(outdim, k):
            idx[dim] = pos
        tmp = 0.0
        for row in range(X.shape[row_axis]):
            idx[row_axis] = row
            for col in range(X.shape[col_axis]):
                idx[col_axis] = col
                x = X[tuple(idx)]
                tmp += (x.conj() * x).real
        res[k] = np.sqrt(tmp)
    return res
Conclusion#
In this talk I have gone through some details of what is needed to connect data science users to egglog:
Overall, the idea is that if we can get egglog into more users' hands, in particular for data-intensive workloads where the up-front cost of pre-computation is worth it, then this can help drive exciting future research directions and also build meaningful, useful tools for the scientific open source ecosystem in Python.
If you are building DSLs in Python, or more generally want to play with e-graphs, try out egglog-python!
I'm around on the e-graphs Zulip for any questions.