Optimizing Scikit-Learn with Array API and Numba

In this tutorial, we will walk through how to speed up your scikit-learn code using egglog and Numba.

One of the goals of egglog is to help other scientific computing libraries create flexible APIs that conform to existing user expectations while allowing greater flexibility in how they perform execution.

To work toward that goal, we have built a prototype of an Array API standard-conformant API that can be used with scikit-learn’s experimental Array API support, so that we can optimize it with Numba.

Normal execution

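We can create a test dataset and apply Linear Discriminant Analysis (LDA) to it. We fit the model on the dataset and then transform the data. This returns each sample’s projection onto the discriminant axis, which is a single column because there are two classes: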

from sklearn import config_context
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis


def run_lda(x, y):
    with config_context(array_api_dispatch=True):
        lda = LinearDiscriminantAnalysis()
        return lda.fit(x, y).transform(x)


X_np, y_np = make_classification(random_state=0, n_samples=1000000)
run_lda(X_np, y_np)
array([[ 0.64233002],
       [ 0.63661245],
       [-1.603293  ],
       ...,
       [-1.1506433 ],
       [ 0.71687176],
       [-1.51119579]])

Building our inputs

Now we can try executing it with egglog instead. In this mode, we don’t pass in any particular NDArray. Instead, we use variables to represent the X and y values.

These are defined in the egglog.exp.array_api module, as typed values:

@array_api_module.class_
class NDArray(Expr):
    @array_api_module.method(cost=200)
    @classmethod
    def var(cls, name: StringLike) -> NDArray: ...

    @property
    def shape(self) -> TupleInt: ...

    ...

@array_api_module.function(mutates_first_arg=True)
def assume_shape(x: NDArray, shape: TupleInt) -> None: ...

We can also use these functions to provide some metadata about the arguments:

from copy import copy

from egglog.exp.array_api import (
    NDArray,
    assume_dtype,
    assume_shape,
    assume_isfinite,
    assume_value_one_of,
)

X_arr = NDArray.var("X")
X_orig = copy(X_arr)

assume_dtype(X_arr, X_np.dtype)
assume_shape(X_arr, X_np.shape)
assume_isfinite(X_arr)

y_arr = NDArray.var("y")
y_orig = copy(y_arr)

assume_dtype(y_arr, y_np.dtype)
assume_shape(y_arr, y_np.shape)
assume_value_one_of(y_arr, (0, 1))
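The assume_* calls attach facts to a symbolic value that has no data behind it. The toy sketch below is only a rough mental model of that idea, not egglog’s implementation. Every name in it (ToySymbolicArray, toy_assume_dtype, and so on) is hypothetical:

```python
from dataclasses import dataclass
from typing import FrozenSet, Optional, Tuple


@dataclass
class ToySymbolicArray:
    """Hypothetical stand-in for a symbolic array: a name plus asserted facts, no data."""
    name: str
    dtype: Optional[str] = None
    shape: Optional[Tuple[int, ...]] = None
    finite: bool = False
    possible_values: Optional[FrozenSet[int]] = None


# Each helper mutates its first argument, mirroring mutates_first_arg=True above.
def toy_assume_dtype(x: ToySymbolicArray, dtype: str) -> None:
    x.dtype = dtype


def toy_assume_shape(x: ToySymbolicArray, shape: Tuple[int, ...]) -> None:
    x.shape = tuple(shape)


def toy_assume_isfinite(x: ToySymbolicArray) -> None:
    x.finite = True


def toy_assume_value_one_of(x: ToySymbolicArray, values: Tuple[int, ...]) -> None:
    x.possible_values = frozenset(values)


X = ToySymbolicArray("X")
toy_assume_dtype(X, "float64")
toy_assume_shape(X, (1000000, 20))
toy_assume_isfinite(X)

y = ToySymbolicArray("y")
toy_assume_dtype(y, "int64")
toy_assume_shape(y, (1000000,))
toy_assume_value_one_of(y, (0, 1))

# These questions have concrete answers even though no data exists.
n_samples, n_features = X.shape
# Assumes every listed value actually occurs in y.
n_classes_assumed = len(y.possible_values)
```

With only these facts, shape questions can be answered eagerly. Anything that depends on the actual array contents stays symbolic.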

While most of the execution can be deferred, every time scikit-learn triggers Python control flow (if, for, etc.), we need to execute eagerly and be able to give a definite value. For example, scikit-learn checks that the number of samples we pass in is greater than the number of unique classes:

class LinearDiscriminantAnalysis(...):
    ...
    def fit(self, X, y):
        ...
        self.classes_ = unique_labels(y)
        n_samples, _ = X.shape
        n_classes = self.classes_.shape[0]

        if n_samples == n_classes:
            raise ValueError(
                "The number of samples must be more than the number of classes."
            )
        ...

Without the assumptions above, we wouldn’t know if the conditional is true or false. So we provide just enough information for sklearn to finish executing and give us a result.
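To see why this matters, note that Python evaluates `if n_samples == n_classes:` by calling `__bool__` on the result of the comparison. The toy sketch below uses hypothetical ToyInt and ToyBool classes, not egglog’s API. It shows a symbolic comparison that can answer only when both sides have known values:

```python
from typing import Optional


class UnknownValueError(Exception):
    """Raised when Python control flow needs a value that is still symbolic."""


class ToyBool:
    def __init__(self, expr: str, value: Optional[bool]) -> None:
        self.expr = expr
        self.value = value

    def __bool__(self) -> bool:
        # `if`, `while`, `and`, `or` and `not` all end up calling __bool__.
        if self.value is None:
            raise UnknownValueError(f"cannot decide {self.expr!r} without more assumptions")
        return self.value


class ToyInt:
    def __init__(self, expr: str, value: Optional[int] = None) -> None:
        self.expr = expr
        self.value = value  # None means "not known yet"

    def __eq__(self, other: "ToyInt") -> ToyBool:  # type: ignore[override]
        known = self.value is not None and other.value is not None
        return ToyBool(f"{self.expr} == {other.expr}", self.value == other.value if known else None)


def check_samples(n_samples: ToyInt, n_classes: ToyInt) -> str:
    # Same shape of check as in LinearDiscriminantAnalysis.fit above.
    if n_samples == n_classes:
        raise ValueError("The number of samples must be more than the number of classes.")
    return "ok"


# With values supplied by assumptions, the branch can be decided.
known_result = check_samples(ToyInt("X.shape[0]", 1000000), ToyInt("classes.shape[0]", 2))

# Without them, the comparison stays symbolic and the `if` cannot proceed.
try:
    unknown_result = check_samples(ToyInt("X.shape[0]"), ToyInt("classes.shape[0]"))
except UnknownValueError:
    unknown_result = "undecidable"
```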

Getting a result

We can now run our run_lda function on these inputs, which carry the constraints we saved, and display the graph of all the intermediate results we had to compute to get our answer:

from egglog import EGraph
from egglog.exp.array_api import array_api_module

with EGraph([array_api_module]) as egraph:
    X_r2 = run_lda(X_arr, y_arr)
    egraph.display(n_inline_leaves=3, split_primitive_outputs=True)
X_r2
(Figure: the e-graph built while tracing run_lda, rendered by egraph.display.)
_NDArray_1 = NDArray.var("X")
assume_dtype(_NDArray_1, DType.float64)
assume_shape(_NDArray_1, TupleInt(Int(1000000)) + TupleInt(Int(20)))
assume_isfinite(_NDArray_1)
_NDArray_2 = NDArray.var("y")
assume_dtype(_NDArray_2, DType.int64)
assume_shape(_NDArray_2, TupleInt(Int(1000000)))
assume_value_one_of(_NDArray_2, TupleValue(Value.int(Int(0))) + TupleValue(Value.int(Int(1))))
_NDArray_3 = asarray(reshape(asarray(_NDArray_2), TupleInt(Int(-1))))
_NDArray_4 = astype(unique_counts(_NDArray_3)[Int(1)], asarray(_NDArray_1).dtype) / NDArray.scalar(Value.float(Float(1000000.0)))
_NDArray_5 = zeros(
    TupleInt(unique_inverse(_NDArray_3)[Int(0)].shape[Int(0)]) + TupleInt(asarray(_NDArray_1).shape[Int(1)]),
    OptionalDType.some(asarray(_NDArray_1).dtype),
    OptionalDevice.some(asarray(_NDArray_1).device),
)
_MultiAxisIndexKey_1 = MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice()))
_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(0))) + _MultiAxisIndexKey_1)
_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0)))
_NDArray_5[_IndexKey_1] = mean(asarray(_NDArray_1)[ndarray_index(unique_inverse(_NDArray_3)[Int(1)] == NDArray.scalar(Value.int(Int(0))))], _OptionalIntOrTuple_1)
_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(1))) + _MultiAxisIndexKey_1)
_NDArray_5[_IndexKey_2] = mean(asarray(_NDArray_1)[ndarray_index(unique_inverse(_NDArray_3)[Int(1)] == NDArray.scalar(Value.int(Int(1))))], _OptionalIntOrTuple_1)
_NDArray_6 = unique_values(concat(TupleNDArray(unique_values(asarray(_NDArray_3)))))
_NDArray_7 = concat(
    TupleNDArray(asarray(_NDArray_1)[ndarray_index(_NDArray_3 == _NDArray_6[IndexKey.int(Int(0))])] - _NDArray_5[_IndexKey_1])
    + TupleNDArray(asarray(_NDArray_1)[ndarray_index(_NDArray_3 == _NDArray_6[IndexKey.int(Int(1))])] - _NDArray_5[_IndexKey_2]),
    OptionalInt.some(Int(0)),
)
_NDArray_8 = std(_NDArray_7, _OptionalIntOrTuple_1)
_NDArray_8[ndarray_index(std(_NDArray_7, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))
_TupleNDArray_1 = svd(
    sqrt(asarray(NDArray.scalar(Value.float(Float(1.0) / Float.from_int(asarray(_NDArray_1).shape[Int(0)] - _NDArray_6.shape[Int(0)]))))) * (_NDArray_7 / _NDArray_8), FALSE
)
_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))
_NDArray_9 = (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(_Slice_1)) + _MultiAxisIndexKey_1)] / _NDArray_8).T / _TupleNDArray_1[
    Int(1)
][IndexKey.slice(_Slice_1)]
_TupleNDArray_2 = svd(
    (
        sqrt(
            (NDArray.scalar(Value.int(asarray(_NDArray_1).shape[Int(0)])) * _NDArray_4)
            * NDArray.scalar(Value.float(Float(1.0) / Float.from_int(_NDArray_6.shape[Int(0)] - Int(1))))
        )
        * (_NDArray_5 - (_NDArray_4 @ _NDArray_5)).T
    ).T
    @ _NDArray_9,
    FALSE,
)
(
    (asarray(_NDArray_1) - (_NDArray_4 @ _NDArray_5))
    @ (
        _NDArray_9
        @ _TupleNDArray_2[Int(2)].T[
            IndexKey.multi_axis(
                _MultiAxisIndexKey_1
                + MultiAxisIndexKey(
                    MultiAxisIndexKeyItem.slice(
                        Slice(
                            OptionalInt.none,
                            OptionalInt.some(
                                sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32))
                                .to_value()
                                .to_int
                            ),
                        )
                    )
                )
            )
        ]
    )
)[IndexKey.multi_axis(_MultiAxisIndexKey_1 + MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(_NDArray_6.shape[Int(0)] - Int(1))))))]

We have now extracted a program that is semantically equivalent to the original call! You might notice that the expression has more types than customary NumPy code. Every object is lifted into a strongly typed egglog class, so that when we run optimizations we know the types of all the objects. It is still compatible with normal Python objects, which are converted when they are passed as arguments.
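The conversion of plain Python values can be pictured with the toy sketch below. For example, the -1 passed to reshape shows up as TupleInt(Int(-1)) in the output above. TypedInt, TypedFloat, lift and toy_reshape are hypothetical names. egglog performs this conversion through its own type system, not through code like this:

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class TypedInt:
    value: int


@dataclass(frozen=True)
class TypedFloat:
    value: float


Typed = Union[TypedInt, TypedFloat]


def lift(x: Union[int, float, TypedInt, TypedFloat]) -> Typed:
    """Wrap plain Python numbers in typed nodes; pass typed nodes through unchanged."""
    if isinstance(x, (TypedInt, TypedFloat)):
        return x
    if isinstance(x, bool):
        # bool is a subclass of int in Python; refuse it rather than guess.
        raise TypeError("bool is not supported in this sketch")
    if isinstance(x, int):
        return TypedInt(x)
    if isinstance(x, float):
        return TypedFloat(x)
    raise TypeError(f"cannot lift {type(x).__name__}")


@dataclass(frozen=True)
class ToyReshape:
    shape: Tuple[Typed, ...]


def toy_reshape(shape) -> ToyReshape:
    # Arguments are lifted at the call boundary, so callers may pass plain ints.
    return ToyReshape(tuple(lift(s) for s in shape))


node = toy_reshape((-1,))
```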

Optimizing our result

Now that we have an expression, we can run our rewrite rules to “optimize” it, and then extract the lowest-cost (smallest) expression:

egraph = EGraph([array_api_module])
egraph.register(X_r2)
egraph.run(10000)
X_r2_optimized = egraph.extract(X_r2)
X_r2_optimized
_NDArray_1 = NDArray.var("X")
assume_dtype(_NDArray_1, DType.float64)
assume_shape(_NDArray_1, TupleInt(Int(1000000)) + TupleInt(Int(20)))
assume_isfinite(_NDArray_1)
_NDArray_2 = NDArray.var("y")
assume_dtype(_NDArray_2, DType.int64)
assume_shape(_NDArray_2, TupleInt(Int(1000000)))
assume_value_one_of(_NDArray_2, TupleValue(Value.int(Int(0))) + TupleValue(Value.int(Int(1))))
_NDArray_3 = astype(unique_counts(_NDArray_2)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float(1000000.0)))
_NDArray_4 = zeros(TupleInt(Int(2)) + TupleInt(Int(20)), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device))
_MultiAxisIndexKey_1 = MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice()))
_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(0))) + _MultiAxisIndexKey_1)
_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0)))
_NDArray_4[_IndexKey_1] = mean(_NDArray_1[ndarray_index(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(0))))], _OptionalIntOrTuple_1)
_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(1))) + _MultiAxisIndexKey_1)
_NDArray_4[_IndexKey_2] = mean(_NDArray_1[ndarray_index(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(1))))], _OptionalIntOrTuple_1)
_NDArray_5 = concat(
    TupleNDArray(_NDArray_1[ndarray_index(_NDArray_2 == NDArray.scalar(Value.int(Int(0))))] - _NDArray_4[_IndexKey_1])
    + TupleNDArray(_NDArray_1[ndarray_index(_NDArray_2 == NDArray.scalar(Value.int(Int(1))))] - _NDArray_4[_IndexKey_2]),
    OptionalInt.some(Int(0)),
)
_NDArray_6 = std(_NDArray_5, _OptionalIntOrTuple_1)
_NDArray_6[ndarray_index(std(_NDArray_5, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))
_TupleNDArray_1 = svd(sqrt(NDArray.scalar(Value.float(Float(1.0) / Float.from_int(Int(999998))))) * (_NDArray_5 / _NDArray_6), FALSE)
_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))
_NDArray_7 = (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(_Slice_1)) + _MultiAxisIndexKey_1)] / _NDArray_6).T / _TupleNDArray_1[
    Int(1)
][IndexKey.slice(_Slice_1)]
_TupleNDArray_2 = svd(
    (sqrt((NDArray.scalar(Value.int(Int(1000000))) * _NDArray_3) * NDArray.scalar(Value.float(Float(1.0)))) * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T).T @ _NDArray_7, FALSE
)
(
    (_NDArray_1 - (_NDArray_3 @ _NDArray_4))
    @ (
        _NDArray_7
        @ _TupleNDArray_2[Int(2)].T[
            IndexKey.multi_axis(
                _MultiAxisIndexKey_1
                + MultiAxisIndexKey(
                    MultiAxisIndexKeyItem.slice(
                        Slice(
                            OptionalInt.none,
                            OptionalInt.some(
                                sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32))
                                .to_value()
                                .to_int
                            ),
                        )
                    )
                )
            )
        ]
    )
)[IndexKey.multi_axis(_MultiAxisIndexKey_1 + MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(Int(1))))))]

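We see that expressions that referenced the shape of our input arrays have been resolved to their concrete values. For example, Float.from_int(asarray(_NDArray_1).shape[Int(0)] - _NDArray_6.shape[Int(0)]) has become Float.from_int(Int(999998)).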
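Roughly, the assumptions let shape lookups be replaced by constants, and arithmetic on those constants then folds away. egglog reaches this result through equality saturation over its rewrite rules. The sketch below instead uses a direct recursive simplifier over hypothetical node classes (Var, ShapeAt, Const, Sub), just to show the effect:

```python
from dataclasses import dataclass
from typing import Dict, Tuple, Union


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class ShapeAt:
    """x.shape[axis] for a named array x."""
    array: Var
    axis: int


@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Sub:
    left: "Expr"
    right: "Expr"


Expr = Union[Var, ShapeAt, Const, Sub]


def simplify(e: Expr, shapes: Dict[str, Tuple[int, ...]]) -> Expr:
    """Replace known shape lookups with constants, then fold constant subtraction."""
    if isinstance(e, ShapeAt) and e.array.name in shapes:
        return Const(shapes[e.array.name][e.axis])
    if isinstance(e, Sub):
        left, right = simplify(e.left, shapes), simplify(e.right, shapes)
        if isinstance(left, Const) and isinstance(right, Const):
            return Const(left.value - right.value)
        return Sub(left, right)
    return e


# X's shape comes from assume_shape; the class count from assume_value_one_of.
shapes = {"X": (1000000, 20), "classes": (2,)}
n_minus_k = Sub(ShapeAt(Var("X"), 0), ShapeAt(Var("classes"), 0))
folded = simplify(n_minus_k, shapes)
unknown = simplify(ShapeAt(Var("Z"), 0), shapes)  # no assumption: left alone
```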

We can also look at the e-graph itself. Even though it’s quite large, we can see that equivalent expressions show up in the same group, called an “e-class”:

egraph.display(n_inline_leaves=3, split_primitive_outputs=True)
(Figure: the e-graph after running the rewrite rules, rendered by egraph.display.)
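An e-class can be pictured as a set of terms that are known to be equal, and a union-find structure is the usual way to maintain such sets. The minimal sketch below groups a few terms from this tutorial by hand. It is illustrative only. A real e-graph, including egglog’s, stores e-nodes whose children are e-classes and restores congruence after each merge, which this sketch omits:

```python
from typing import Dict, Hashable, Set


class UnionFind:
    """Minimal union-find over hashable items, with path compression."""

    def __init__(self) -> None:
        self.parent: Dict[Hashable, Hashable] = {}

    def add(self, x: Hashable) -> None:
        self.parent.setdefault(x, x)

    def find(self, x: Hashable) -> Hashable:
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Point every item on the path directly at the root.
        while self.parent[x] != root:
            nxt = self.parent[x]
            self.parent[x] = root
            x = nxt
        return root

    def union(self, a: Hashable, b: Hashable) -> Hashable:
        ra, rb = self.find(a), self.find(b)
        if ra != rb:
            self.parent[rb] = ra
        return ra


terms = [
    "asarray(X).shape[1]",
    "X.shape[1]",
    "20",
    "std(x, 0)",
    "sqrt(mean(square(abs(x - mean(x, 0, keepdims))), 0))",
]
uf = UnionFind()
for t in terms:
    uf.add(t)

uf.union("asarray(X).shape[1]", "X.shape[1]")  # asarray(X) was simplified to X above
uf.union("X.shape[1]", "20")                    # from assume_shape
uf.union("std(x, 0)", terms[4])                 # the std rewrite from the next section

# Group terms by their representative: each group plays the role of an e-class.
eclasses: Dict[Hashable, Set[str]] = {}
for t in terms:
    eclasses.setdefault(uf.find(t), set()).add(t)
```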

Translating for Numba

We are getting closer to a form we could translate to Numba, but we have to make a few changes. Numba doesn’t support the axis keyword for mean or std, but it does support it for sum. So we rewrite std in terms of mean, and mean in terms of sum, with rules like this one (defined in egglog.exp.array_api_numba):

axis = OptionalIntOrTuple.some(IntOrTuple.int(i))
rewrite(std(x, axis)).to(sqrt(mean(square(abs(x - mean(x, axis, keepdims=TRUE))), axis)))
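The rule is sound because the population (ddof=0) standard deviation is by definition the square root of the mean squared deviation from the mean. The output further down also replaces each mean with a sum divided by a count. Here is a quick pure-Python check on a small hypothetical matrix, using statistics.pstdev as the reference:

```python
import math
import statistics

# A small 3x2 matrix stored as a list of rows (hypothetical test data).
rows = [[1.0, 2.0], [3.0, 5.0], [4.0, 11.0]]


def mean_axis0(matrix):
    # mean(x, axis=0) rewritten as sum(x, axis=0) / x.shape[0]
    n = len(matrix)
    return [sum(col) / n for col in zip(*matrix)]


def std_axis0_rewritten(matrix):
    # sqrt(mean(square(abs(x - mean(x, axis=0, keepdims=True))), axis=0))
    mu = mean_axis0(matrix)  # reused for every row, like keepdims=True broadcasting
    squared = [[abs(v - m) ** 2 for v, m in zip(row, mu)] for row in matrix]
    return [math.sqrt(v) for v in mean_axis0(squared)]


reference = [statistics.pstdev(col) for col in zip(*rows)]
rewritten = std_axis0_rewritten(rows)
```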

We can now run those additional rewrites to get a new extracted version:

from egglog.exp.array_api_numba import array_api_numba_module

egraph = EGraph([array_api_numba_module])
egraph.register(X_r2_optimized)
egraph.run(10000)
X_r2_numba = egraph.extract(X_r2_optimized)
X_r2_numba
_NDArray_1 = NDArray.var("X")
assume_dtype(_NDArray_1, DType.float64)
assume_shape(_NDArray_1, TupleInt(Int(1000000)) + TupleInt(Int(20)))
assume_isfinite(_NDArray_1)
_NDArray_2 = NDArray.var("y")
assume_dtype(_NDArray_2, DType.int64)
assume_shape(_NDArray_2, TupleInt(Int(1000000)))
assume_value_one_of(_NDArray_2, TupleValue(Value.int(Int(0))) + TupleValue(Value.int(Int(1))))
_NDArray_3 = astype(
    NDArray.vector(TupleValue(sum(_NDArray_2 == NDArray.scalar(Value.int(Int(0)))).to_value()) + TupleValue(sum(_NDArray_2 == NDArray.scalar(Value.int(Int(1)))).to_value())),
    DType.float64,
) / NDArray.scalar(Value.float(Float(1000000.0)))
_NDArray_4 = zeros(TupleInt(Int(2)) + TupleInt(Int(20)), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device))
_MultiAxisIndexKey_1 = MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice()))
_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(0))) + _MultiAxisIndexKey_1)
_NDArray_5 = _NDArray_1[ndarray_index(_NDArray_2 == NDArray.scalar(Value.int(Int(0))))]
_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0)))
_NDArray_4[_IndexKey_1] = sum(_NDArray_5, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_5.shape[Int(0)]))
_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(1))) + _MultiAxisIndexKey_1)
_NDArray_6 = _NDArray_1[ndarray_index(_NDArray_2 == NDArray.scalar(Value.int(Int(1))))]
_NDArray_4[_IndexKey_2] = sum(_NDArray_6, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_6.shape[Int(0)]))
_NDArray_7 = concat(TupleNDArray(_NDArray_5 - _NDArray_4[_IndexKey_1]) + TupleNDArray(_NDArray_6 - _NDArray_4[_IndexKey_2]), OptionalInt.some(Int(0)))
_NDArray_8 = square(_NDArray_7 - expand_dims(sum(_NDArray_7, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_7.shape[Int(0)]))))
_NDArray_9 = sqrt(sum(_NDArray_8, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_8.shape[Int(0)])))
_NDArray_10 = copy(_NDArray_9)
_NDArray_10[ndarray_index(_NDArray_9 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))
_TupleNDArray_1 = svd(sqrt(NDArray.scalar(Value.float(Float(1.0) / Float.from_int(Int(999998))))) * (_NDArray_7 / _NDArray_10), FALSE)
_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))
_NDArray_11 = (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(_Slice_1)) + _MultiAxisIndexKey_1)] / _NDArray_10).T / _TupleNDArray_1[
    Int(1)
][IndexKey.slice(_Slice_1)]
_TupleNDArray_2 = svd(
    (sqrt((NDArray.scalar(Value.int(Int(1000000))) * _NDArray_3) * NDArray.scalar(Value.float(Float(1.0)))) * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T).T @ _NDArray_11, FALSE
)
(
    (_NDArray_1 - (_NDArray_3 @ _NDArray_4))
    @ (
        _NDArray_11
        @ _TupleNDArray_2[Int(2)].T[
            IndexKey.multi_axis(
                _MultiAxisIndexKey_1
                + MultiAxisIndexKey(
                    MultiAxisIndexKeyItem.slice(
                        Slice(
                            OptionalInt.none,
                            OptionalInt.some(
                                sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32))
                                .to_value()
                                .to_int
                            ),
                        )
                    )
                )
            )
        ]
    )
)[IndexKey.multi_axis(_MultiAxisIndexKey_1 + MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(Int(1))))))]

Compiling back to Python source

Now we finally have a version that Numba could run! However, it isn’t NumPy code yet. Numba needs a function that uses NumPy, not our typed dialect.

So we use another module that translates all of our methods into Python source strings. The rules in it look like this:

# the sqrt of an array should use the `np.sqrt` function and be assigned to its own variable, so it can be reused
rewrite(ndarray_program(sqrt(x))).to((Program("np.sqrt(") + ndarray_program(x) + ")").assign())

# To compile a setitem call, we first compile the source, assign it to a variable, then add an assignment statement
mod_x = copy(x)
mod_x[idx] = y
assigned_x = ndarray_program(x).assign()
yield rewrite(ndarray_program(mod_x)).to(
    assigned_x.statement(assigned_x + "[" + index_key_program(idx) + "] = " + ndarray_program(y))
)

We pull in all those rewrite rules from the egglog.exp.array_api_program_gen module. They depend on another module, egglog.exp.program_gen, which provides generic translations of expressions and statements into strings.
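The key idea in these rules is that each subexpression is compiled once, assigned to a variable, and referred to by name afterward. The toy sketch below shows that idea using a hypothetical emit helper over tuple-encoded expressions. It ignores setitem and is not egglog’s program_gen API:

```python
from types import SimpleNamespace
from typing import Dict, List, Tuple, Union

# An expression is either a parameter name (str) or a tuple (function_name, *args).
Expr = Union[str, Tuple]


def emit(expr: Expr, lines: List[str], names: Dict[Tuple, str]) -> str:
    """Return source text for `expr`, hoisting each call into a numbered variable."""
    if isinstance(expr, str):
        return expr
    if expr in names:  # already emitted: reuse the variable (CSE)
        return names[expr]
    func, *args = expr
    arg_src = ", ".join(emit(a, lines, names) for a in args)
    var = f"_{len(names)}"
    lines.append(f"    {var} = {func}({arg_src})")
    names[expr] = var
    return var


def compile_source(expr: Expr, params: List[str]) -> str:
    lines: List[str] = []
    result = emit(expr, lines, {})
    return "\n".join([f"def __fn({', '.join(params)}):", *lines, f"    return {result}"])


diff_sq = ("np.square", ("np.subtract", "x", "y"))
source = compile_source(("np.add", diff_sq, diff_sq), ["x", "y"])
print(source)

# Run it against a tiny stand-in for NumPy so the sketch has no dependencies.
fake_np = SimpleNamespace(
    subtract=lambda a, b: a - b,
    square=lambda a: a * a,
    add=lambda a, b: a + b,
)
namespace = {"np": fake_np}
exec(source, namespace)
value = namespace["__fn"](5, 2)
```

Because `diff_sq` appears twice but is emitted once, the generated function computes the square a single time and reuses `_1`.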

We can run these rules to get out a Python function object:

from egglog.exp.array_api_program_gen import (
    ndarray_function_two,
    array_api_module_string,
)

egraph = EGraph([array_api_module_string])
fn_program = ndarray_function_two(X_r2_numba, X_orig, y_orig)
egraph.register(fn_program)
egraph.run(10000)
fn = egraph.load_object(egraph.extract(fn_program.py_object))
import inspect

print(inspect.getsource(fn))
def __fn(X, y):
    assert X.dtype == np.dtype(np.float64)
    assert X.shape == (1000000, 20,)
    assert np.all(np.isfinite(X))
    assert y.dtype == np.dtype(np.int64)
    assert y.shape == (1000000,)
    assert set(np.unique(y)) == set((0, 1,))
    _0 = y == np.array(0)
    _1 = np.sum(_0)
    _2 = y == np.array(1)
    _3 = np.sum(_2)
    _4 = np.array((_1, _3,)).astype(np.dtype(np.float64))
    _5 = _4 / np.array(1000000.0)
    _6 = np.zeros((2, 20,), dtype=np.dtype(np.float64))
    _7 = np.sum(X[_0], axis=0)
    _8 = _7 / np.array(X[_0].shape[0])
    _6[0, :] = _8
    _9 = np.sum(X[_2], axis=0)
    _10 = _9 / np.array(X[_2].shape[0])
    _6[1, :] = _10
    _11 = _5 @ _6
    _12 = X - _11
    _13 = np.sqrt(np.array((1.0 / 999998)))
    _14 = X[_0] - _6[0, :]
    _15 = X[_2] - _6[1, :]
    _16 = np.concatenate((_14, _15,), axis=0)
    _17 = np.sum(_16, axis=0)
    _18 = _17 / np.array(_16.shape[0])
    _19 = np.expand_dims(_18, 0)
    _20 = _16 - _19
    _21 = np.square(_20)
    _22 = np.sum(_21, axis=0)
    _23 = _22 / np.array(_21.shape[0])
    _24 = np.sqrt(_23)
    _25 = _24 == np.array(0)
    _24[_25] = np.array(1.0)
    _26 = _16 / _24
    _27 = _13 * _26
    _28 = np.linalg.svd(_27, full_matrices=False)
    _29 = _28[1] > np.array(0.0001)
    _30 = _29.astype(np.dtype(np.int32))
    _31 = np.sum(_30)
    _32 = _28[2][:_31, :] / _24
    _33 = _32.T / _28[1][:_31]
    _34 = np.array(1000000) * _5
    _35 = _34 * np.array(1.0)
    _36 = np.sqrt(_35)
    _37 = _6 - _11
    _38 = _36 * _37.T
    _39 = _38.T @ _33
    _40 = np.linalg.svd(_39, full_matrices=False)
    _41 = np.array(0.0001) * _40[1][0]
    _42 = _40[1] > _41
    _43 = _42.astype(np.dtype(np.int32))
    _44 = np.sum(_43)
    _45 = _33 @ _40[2].T[:, :_44]
    _46 = _12 @ _45
    return _46[:, :1]
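This tutorial does not show how egglog turns the generated string into a function object. However, the temporary egglog-*.py paths in the Numba warning and profiler output below suggest that the source is written to a real file. That would let inspect.getsource and line_profiler find it. Below is a stdlib-only sketch of that general approach. load_function and the "generated-" prefix are hypothetical:

```python
import importlib.util
import inspect
import os
import tempfile

SOURCE = """\
def __fn(x, y):
    _0 = x - y
    _1 = _0 * _0
    return _1 + _1
"""


def load_function(source: str, name: str):
    """Write `source` to a temporary .py file, import it, and return `name` from it."""
    fd, path = tempfile.mkstemp(prefix="generated-", suffix=".py")
    with os.fdopen(fd, "w") as f:
        f.write(source)
    spec = importlib.util.spec_from_file_location("generated_module", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    # The file is intentionally kept so tools can read the source later.
    return getattr(module, name)


loaded = load_function(SOURCE, "__fn")
print(inspect.getsource(loaded))
```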

We can verify that the function gives the same result:

import numpy as np

assert np.allclose(run_lda(X_np, y_np), fn(X_np, y_np))

Although it isn’t the prettiest, we can see that it emits each expression only once (common subexpression elimination) and preserves the “imperative” semantics of setitem.

Compiling to Numba

Now we finally have a function we can run with Numba:

import numba

fn_numba = numba.njit(fastmath=True)(fn)
assert np.allclose(run_lda(X_np, y_np), fn_numba(X_np, y_np))
/var/folders/xn/05ktz3056kqd9n8frgd6236h0000gn/T/egglog-9e61d62c-d17d-495b-b8db-f1eb3b38dcbb.py:56: NumbaPerformanceWarning: '@' is faster on contiguous arrays, called on (Array(float64, 2, 'C', False, aligned=True), Array(float64, 2, 'A', False, aligned=True))
  _45 = _33 @ _40[2].T[:, :_44]

Evaluating performance

Let’s see if it actually made anything quicker! We run a number of trials for the original function, our extracted version, and the Numba-compiled extracted version:

import timeit
import pandas as pd

stmts = {
    "original": "run_lda(X_np, y_np)",
    "extracted": "fn(X_np, y_np)",
    "extracted numba": "fn_numba(X_np, y_np)",
}
df = pd.DataFrame.from_dict(
    {name: timeit.repeat(stmt, globals=globals(), number=1, repeat=10) for name, stmt in stmts.items()}
)

df
   original  extracted  extracted numba
0  1.482975   1.609354         1.086486
1  1.498656   1.504704         1.145331
2  1.500998   1.557253         1.090356
3  1.519732   1.548800         1.122623
4  1.500420   1.501195         1.113089
5  1.587211   1.522518         1.176842
6  1.499479   1.526887         1.095296
7  1.639910   1.500859         1.086477
8  1.525145   1.559202         1.103662
9  1.535601   1.474299         1.074152
import seaborn as sns

df_melt = pd.melt(df, var_name="function", value_name="time")
_ = sns.catplot(data=df_melt, x="function", y="time", kind="swarm")
(Figure: swarm plot of time per trial for each function: original, extracted, and extracted numba.)
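To put numbers on the plot, we can summarize the ten trials for each function. This stdlib-only sketch re-enters the values printed in the table above rather than re-running anything:

```python
import statistics

# Per-trial timings in seconds, copied from the DataFrame shown above.
timings = {
    "original": [1.482975, 1.498656, 1.500998, 1.519732, 1.500420,
                 1.587211, 1.499479, 1.639910, 1.525145, 1.535601],
    "extracted": [1.609354, 1.504704, 1.557253, 1.548800, 1.501195,
                  1.522518, 1.526887, 1.500859, 1.559202, 1.474299],
    "extracted numba": [1.086486, 1.145331, 1.090356, 1.122623, 1.113089,
                        1.176842, 1.095296, 1.086477, 1.103662, 1.074152],
}

# Medians are less sensitive than means to an occasional slow trial.
medians = {name: statistics.median(times) for name, times in timings.items()}
speedup = medians["original"] / medians["extracted numba"]

for name, median in medians.items():
    print(f"{name:>16}: median {median:.3f} s")
print(f"Numba speedup over original (medians): {speedup:.2f}x")
```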

We see that the Numba version is in fact faster, and the other two are about the same. It isn’t dramatically faster, though, so we might want to run a profiler on the extracted function to see where most of the time is spent:

%load_ext line_profiler
%lprun -f fn fn(X_np, y_np)
Timer unit: 1e-09 s

Total time: 1.41607 s
File: /var/folders/xn/05ktz3056kqd9n8frgd6236h0000gn/T/egglog-9e61d62c-d17d-495b-b8db-f1eb3b38dcbb.py
Function: __fn at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     1                                           def __fn(X, y):
     2         1      13000.0  13000.0      0.0      assert X.dtype == np.dtype(np.float64)
     3         1       2000.0   2000.0      0.0      assert X.shape == (1000000, 20,)
     4         1   23813000.0    2e+07      1.7      assert np.all(np.isfinite(X))
     5         1      11000.0  11000.0      0.0      assert y.dtype == np.dtype(np.int64)
     6         1      14000.0  14000.0      0.0      assert y.shape == (1000000,)
     7         1   23226000.0    2e+07      1.6      assert set(np.unique(y)) == set((0, 1,))
     8         1     542000.0 542000.0      0.0      _0 = y == np.array(0)
     9         1     488000.0 488000.0      0.0      _1 = np.sum(_0)
    10         1     493000.0 493000.0      0.0      _2 = y == np.array(1)
    11         1     454000.0 454000.0      0.0      _3 = np.sum(_2)
    12         1      14000.0  14000.0      0.0      _4 = np.array((_1, _3,)).astype(np.dtype(np.float64))
    13         1       9000.0   9000.0      0.0      _5 = _4 / np.array(1000000.0)
    14         1       4000.0   4000.0      0.0      _6 = np.zeros((2, 20,), dtype=np.dtype(np.float64))
    15         1   98376000.0    1e+08      6.9      _7 = np.sum(X[_0], axis=0)
    16         1   38374000.0    4e+07      2.7      _8 = _7 / np.array(X[_0].shape[0])
    17         1       6000.0   6000.0      0.0      _6[0, :] = _8
    18         1   45697000.0    5e+07      3.2      _9 = np.sum(X[_2], axis=0)
    19         1   35522000.0    4e+07      2.5      _10 = _9 / np.array(X[_2].shape[0])
    20         1       6000.0   6000.0      0.0      _6[1, :] = _10
    21         1      13000.0  13000.0      0.0      _11 = _5 @ _6
    22         1   33768000.0    3e+07      2.4      _12 = X - _11
    23         1      18000.0  18000.0      0.0      _13 = np.sqrt(np.array((1.0 / 999998)))
    24         1   50544000.0    5e+07      3.6      _14 = X[_0] - _6[0, :]
    25         1   55966000.0    6e+07      4.0      _15 = X[_2] - _6[1, :]
    26         1   26138000.0    3e+07      1.8      _16 = np.concatenate((_14, _15,), axis=0)
    27         1   23667000.0    2e+07      1.7      _17 = np.sum(_16, axis=0)
    28         1      26000.0  26000.0      0.0      _18 = _17 / np.array(_16.shape[0])
    29         1      45000.0  45000.0      0.0      _19 = np.expand_dims(_18, 0)
    30         1   33604000.0    3e+07      2.4      _20 = _16 - _19
    31         1   24774000.0    2e+07      1.7      _21 = np.square(_20)
    32         1   21671000.0    2e+07      1.5      _22 = np.sum(_21, axis=0)
    33         1      31000.0  31000.0      0.0      _23 = _22 / np.array(_21.shape[0])
    34         1       4000.0   4000.0      0.0      _24 = np.sqrt(_23)
    35         1       7000.0   7000.0      0.0      _25 = _24 == np.array(0)
    36         1       3000.0   3000.0      0.0      _24[_25] = np.array(1.0)
    37         1   32910000.0    3e+07      2.3      _26 = _16 / _24
    38         1   24105000.0    2e+07      1.7      _27 = _13 * _26
    39         1  814200000.0    8e+08     57.5      _28 = np.linalg.svd(_27, full_matrices=False)
    40         1      23000.0  23000.0      0.0      _29 = _28[1] > np.array(0.0001)
    41         1      10000.0  10000.0      0.0      _30 = _29.astype(np.dtype(np.int32))
    42         1      63000.0  63000.0      0.0      _31 = np.sum(_30)
    43         1      14000.0  14000.0      0.0      _32 = _28[2][:_31, :] / _24
    44         1       7000.0   7000.0      0.0      _33 = _32.T / _28[1][:_31]
    45         1       9000.0   9000.0      0.0      _34 = np.array(1000000) * _5
    46         1       4000.0   4000.0      0.0      _35 = _34 * np.array(1.0)
    47         1       3000.0   3000.0      0.0      _36 = np.sqrt(_35)
    48         1       5000.0   5000.0      0.0      _37 = _6 - _11
    49         1       4000.0   4000.0      0.0      _38 = _36 * _37.T
    50         1      11000.0  11000.0      0.0      _39 = _38.T @ _33
    51         1      70000.0  70000.0      0.0      _40 = np.linalg.svd(_39, full_matrices=False)
    52         1       6000.0   6000.0      0.0      _41 = np.array(0.0001) * _40[1][0]
    53         1       3000.0   3000.0      0.0      _42 = _40[1] > _41
    54         1       4000.0   4000.0      0.0      _43 = _42.astype(np.dtype(np.int32))
    55         1      18000.0  18000.0      0.0      _44 = np.sum(_43)
    56         1       8000.0   8000.0      0.0      _45 = _33 @ _40[2].T[:, :_44]
    57         1    7242000.0    7e+06      0.5      _46 = _12 @ _45
    58         1       7000.0   7000.0      0.0      return _46[:, :1]

We see that most of the time is spent in the SVD function. Numba wouldn’t improve that much, since it calls out to LAPACK, just like NumPy. The only savings would come from the other parts of the program, which can be inlined into …
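Amdahl’s law makes this concrete: the fraction of run time that is left untouched caps the overall speedup. The sketch below uses only numbers from the profiler output above. It assumes the SVD costs the same under Numba as under NumPy:

```python
# Numbers taken from the line_profiler output above (seconds).
total_time = 1.41607
svd_time = 0.8142  # line 39: np.linalg.svd(_27, full_matrices=False)

svd_fraction = svd_time / total_time
rest_fraction = 1.0 - svd_fraction


def amdahl_speedup(p: float, s: float) -> float:
    """Overall speedup when a fraction p of the run time is made s times faster."""
    return 1.0 / ((1.0 - p) + p / s)


# Even if everything except the SVD became free, speedup is capped at 1 / svd_fraction.
best_case = amdahl_speedup(rest_fraction, float("inf"))
print(f"SVD share of run time: {svd_fraction:.1%}")
print(f"Upper bound on overall speedup: {best_case:.2f}x")
```

The cap is well under 2x, which is consistent with the modest gain in the timing table.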

Conclusion

To recap, in this tutorial we:

  1. Tried using a normal scikit-learn LDA function on some test data.

  2. Built abstract input arrays and called the function with those instead.

  3. Optimized the resulting expression and translated it to work with Numba.

  4. Compiled it to a standalone Python function, which we then compiled with Numba.

  5. Verified that this improved our performance on this test data.

The implementation of the Array API provided here is experimental and incomplete, but it shows that it is possible to build an API like this with egglog.