{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "```{post} 2023-10-26\n", "\n", "```\n" ] }, { "cell_type": "markdown", "id": "34e1966a", "metadata": {}, "source": [ "# Optimizing Scikit-Learn with Array API and Numba\n", "\n", "In this tutorial, we will walk through increasing the performance of your scikit-learn code using `egglog` and [`numba`](https://numba.readthedocs.io/en/stable/user/5minguide.html).\n", "\n", "One of the goals of `egglog` is to be used by other scientific computing libraries to create flexible APIs,\n", "which conform to existing user expectations but allow greater flexibility in how they are executed.\n", "\n", "To work towards that goal, we have built a prototype of an [Array API standard](https://data-apis.org/array-api/2022.12/index.html) conformant API\n", "that can be used with [Scikit-Learn's experimental Array API support](https://scikit-learn.org/stable/modules/array_api.html)\n", "in order to optimize it with Numba.\n", "\n", "## Normal execution\n", "\n", "We can create a test dataset and fit a `LinearDiscriminantAnalysis` (LDA) model to it. Then we can use the fitted model to\n", "transform the same data, projecting each sample onto the discriminant axis:\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "6b130384", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 0.64233002],\n", " [ 0.63661245],\n", " [-1.603293 ],\n", " ...,\n", " [-1.1506433 ],\n", " [ 0.71687176],\n", " [-1.51119579]])" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn import config_context\n", "from sklearn.datasets import make_classification\n", "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", "\n", "\n", "def run_lda(x, y):\n", "    with config_context(array_api_dispatch=True):\n", "        lda = LinearDiscriminantAnalysis()\n", "        return lda.fit(x, y).transform(x)\n", "\n", "\n", "X_np, y_np = make_classification(random_state=0, n_samples=1000000)\n", "run_lda(X_np, y_np)" ] }, { "cell_type": "markdown", "id": "df1da938", "metadata": {}, "source": [ "## Building our inputs\n", "\n", "Now, we can try executing it with `egglog` instead. 
In this mode, we don't pass in any concrete\n", "NDArray; instead, we use symbolic variables to represent the `X` and `y` values.\n", "\n", "These are defined as typed values in the [`egglog.exp.array_api` module](https://github.com/egraphs-good/egglog-python/blob/main/python/egglog/exp/array_api.py):\n", "\n", "```python\n", "@array_api_module.class_\n", "class NDArray(Expr):\n", "    @array_api_module.method(cost=200)\n", "    @classmethod\n", "    def var(cls, name: StringLike) -> NDArray: ...\n", "\n", "    @property\n", "    def shape(self) -> TupleInt: ...\n", "\n", "    ...\n", "\n", "@array_api_module.function(mutates_first_arg=True)\n", "def assume_shape(x: NDArray, shape: TupleInt) -> None: ...\n", "```\n", "\n", "We can use these `assume_*` functions to provide some metadata about the arguments as well:\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "1c216ca4", "metadata": {}, "outputs": [], "source": [ "from copy import copy\n", "\n", "from egglog.exp.array_api import (\n", "    NDArray,\n", "    assume_dtype,\n", "    assume_shape,\n", "    assume_isfinite,\n", "    assume_value_one_of,\n", ")\n", "\n", "X_arr = NDArray.var(\"X\")\n", "X_orig = copy(X_arr)\n", "\n", "assume_dtype(X_arr, X_np.dtype)\n", "assume_shape(X_arr, X_np.shape)\n", "assume_isfinite(X_arr)\n", "\n", "y_arr = NDArray.var(\"y\")\n", "y_orig = copy(y_arr)\n", "\n", "assume_dtype(y_arr, y_np.dtype)\n", "assume_shape(y_arr, y_np.shape)\n", "assume_value_one_of(y_arr, (0, 1))" ] }, { "cell_type": "markdown", "id": "a184f6e6", "metadata": {}, "source": [ "While most of the execution can be deferred, every time sklearn triggers Python control flow (`if`, `for`, etc.), we\n", "need to execute eagerly and be able to produce a definite value. 
For example, scikit-learn checks that the\n", "number of samples we pass in is not equal to the number of unique classes:\n", "\n", "```python\n", "class LinearDiscriminantAnalysis(...):\n", "    ...\n", "    def fit(self, X, y):\n", "        ...\n", "        self.classes_ = unique_labels(y)\n", "        n_samples, _ = X.shape\n", "        n_classes = self.classes_.shape[0]\n", "\n", "        if n_samples == n_classes:\n", "            raise ValueError(\n", "                \"The number of samples must be more than the number of classes.\"\n", "            )\n", "        ...\n", "```\n", "\n", "Without the assumptions above, we wouldn't know whether the conditional is true or false, so we provide just enough information\n", "for sklearn to finish executing and give us a result.\n", "\n", "## Getting a result\n", "\n", "We can now run our `run_lda` function on these inputs, which carry the constraints we saved about them, and display the graph, which\n", "shows all of the intermediate results we had to compute to get our answer:\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "3d7291d8", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
_NDArray_1 = NDArray.var("X")\n",
              "assume_dtype(_NDArray_1, DType.float64)\n",
              "assume_shape(_NDArray_1, TupleInt(Int(1000000)) + TupleInt(Int(20)))\n",
              "assume_isfinite(_NDArray_1)\n",
              "_NDArray_2 = NDArray.var("y")\n",
              "assume_dtype(_NDArray_2, DType.int64)\n",
              "assume_shape(_NDArray_2, TupleInt(Int(1000000)))\n",
              "assume_value_one_of(_NDArray_2, TupleValue(Value.int(Int(0))) + TupleValue(Value.int(Int(1))))\n",
              "_NDArray_3 = asarray(reshape(asarray(_NDArray_2), TupleInt(Int(-1))))\n",
              "_NDArray_4 = astype(unique_counts(_NDArray_3)[Int(1)], asarray(_NDArray_1).dtype) / NDArray.scalar(Value.float(Float(1000000.0)))\n",
              "_NDArray_5 = zeros(\n",
              "    TupleInt(unique_inverse(_NDArray_3)[Int(0)].shape[Int(0)]) + TupleInt(asarray(_NDArray_1).shape[Int(1)]),\n",
              "    OptionalDType.some(asarray(_NDArray_1).dtype),\n",
              "    OptionalDevice.some(asarray(_NDArray_1).device),\n",
              ")\n",
              "_MultiAxisIndexKey_1 = MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice()))\n",
              "_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(0))) + _MultiAxisIndexKey_1)\n",
              "_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0)))\n",
              "_NDArray_5[_IndexKey_1] = mean(asarray(_NDArray_1)[ndarray_index(unique_inverse(_NDArray_3)[Int(1)] == NDArray.scalar(Value.int(Int(0))))], _OptionalIntOrTuple_1)\n",
              "_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(1))) + _MultiAxisIndexKey_1)\n",
              "_NDArray_5[_IndexKey_2] = mean(asarray(_NDArray_1)[ndarray_index(unique_inverse(_NDArray_3)[Int(1)] == NDArray.scalar(Value.int(Int(1))))], _OptionalIntOrTuple_1)\n",
              "_NDArray_6 = unique_values(concat(TupleNDArray(unique_values(asarray(_NDArray_3)))))\n",
              "_NDArray_7 = concat(\n",
              "    TupleNDArray(asarray(_NDArray_1)[ndarray_index(_NDArray_3 == _NDArray_6[IndexKey.int(Int(0))])] - _NDArray_5[_IndexKey_1])\n",
              "    + TupleNDArray(asarray(_NDArray_1)[ndarray_index(_NDArray_3 == _NDArray_6[IndexKey.int(Int(1))])] - _NDArray_5[_IndexKey_2]),\n",
              "    OptionalInt.some(Int(0)),\n",
              ")\n",
              "_NDArray_8 = std(_NDArray_7, _OptionalIntOrTuple_1)\n",
              "_NDArray_8[ndarray_index(std(_NDArray_7, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))\n",
              "_TupleNDArray_1 = svd(\n",
              "    sqrt(asarray(NDArray.scalar(Value.float(Float(1.0) / Float.from_int(asarray(_NDArray_1).shape[Int(0)] - _NDArray_6.shape[Int(0)]))))) * (_NDArray_7 / _NDArray_8), FALSE\n",
              ")\n",
              "_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))\n",
              "_NDArray_9 = (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(_Slice_1)) + _MultiAxisIndexKey_1)] / _NDArray_8).T / _TupleNDArray_1[\n",
              "    Int(1)\n",
              "][IndexKey.slice(_Slice_1)]\n",
              "_TupleNDArray_2 = svd(\n",
              "    (\n",
              "        sqrt(\n",
              "            (NDArray.scalar(Value.int(asarray(_NDArray_1).shape[Int(0)])) * _NDArray_4)\n",
              "            * NDArray.scalar(Value.float(Float(1.0) / Float.from_int(_NDArray_6.shape[Int(0)] - Int(1))))\n",
              "        )\n",
              "        * (_NDArray_5 - (_NDArray_4 @ _NDArray_5)).T\n",
              "    ).T\n",
              "    @ _NDArray_9,\n",
              "    FALSE,\n",
              ")\n",
              "(\n",
              "    (asarray(_NDArray_1) - (_NDArray_4 @ _NDArray_5))\n",
              "    @ (\n",
              "        _NDArray_9\n",
              "        @ _TupleNDArray_2[Int(2)].T[\n",
              "            IndexKey.multi_axis(\n",
              "                _MultiAxisIndexKey_1\n",
              "                + MultiAxisIndexKey(\n",
              "                    MultiAxisIndexKeyItem.slice(\n",
              "                        Slice(\n",
              "                            OptionalInt.none,\n",
              "                            OptionalInt.some(\n",
              "                                sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32))\n",
              "                                .to_value()\n",
              "                                .to_int\n",
              "                            ),\n",
              "                        )\n",
              "                    )\n",
              "                )\n",
              "            )\n",
              "        ]\n",
              "    )\n",
              ")[IndexKey.multi_axis(_MultiAxisIndexKey_1 + MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(_NDArray_6.shape[Int(0)] - Int(1))))))]\n",
              "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}1} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{var}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{X}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", "\\PY{n}{assume\\PYZus{}dtype}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{)}\n", "\\PY{n}{assume\\PYZus{}shape}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{,} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1000000}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{20}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{assume\\PYZus{}isfinite}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{var}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{y}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", "\\PY{n}{assume\\PYZus{}dtype}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int64}\\PY{p}{)}\n", "\\PY{n}{assume\\PYZus{}shape}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1000000}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{assume\\PYZus{}value\\PYZus{}one\\PYZus{}of}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{TupleValue}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleValue}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{=} \\PY{n}{asarray}\\PY{p}{(}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{)}\\PY{p}{,} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", 
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}4} \\PY{o}{=} \\PY{n}{astype}\\PY{p}{(}\\PY{n}{unique\\PYZus{}counts}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{)}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{,} \\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{o}{.}\\PY{n}{dtype}\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1000000.0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}5} \\PY{o}{=} \\PY{n}{zeros}\\PY{p}{(}\n", " \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{unique\\PYZus{}inverse}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{)}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{,}\n", " \\PY{n}{OptionalDType}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{o}{.}\\PY{n}{dtype}\\PY{p}{)}\\PY{p}{,}\n", " \\PY{n}{OptionalDevice}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{o}{.}\\PY{n}{device}\\PY{p}{)}\\PY{p}{,}\n", "\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1} \\PY{o}{=} \\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{Slice}\\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} 
\\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1} \\PY{o}{=} \\PY{n}{OptionalIntOrTuple}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{IntOrTuple}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1}\\PY{p}{]} \\PY{o}{=} \\PY{n}{mean}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{unique\\PYZus{}inverse}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{)}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2}\\PY{p}{]} \\PY{o}{=} \\PY{n}{mean}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{unique\\PYZus{}inverse}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{)}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)}\n", 
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}6} \\PY{o}{=} \\PY{n}{unique\\PYZus{}values}\\PY{p}{(}\\PY{n}{concat}\\PY{p}{(}\\PY{n}{TupleNDArray}\\PY{p}{(}\\PY{n}{unique\\PYZus{}values}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}7} \\PY{o}{=} \\PY{n}{concat}\\PY{p}{(}\n", " \\PY{n}{TupleNDArray}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{==} \\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{o}{+} \\PY{n}{TupleNDArray}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{==} \\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2}\\PY{p}{]}\\PY{p}{)}\\PY{p}{,}\n", " \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,}\n", "\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}8} \\PY{o}{=} \\PY{n}{std}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}8}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{std}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)} \\PY{o}{==} 
\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1} \\PY{o}{=} \\PY{n}{svd}\\PY{p}{(}\n", " \\PY{n}{sqrt}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{)} \\PY{o}{/} \\PY{n}{Float}\\PY{o}{.}\\PY{n}{from\\PYZus{}int}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7} \\PY{o}{/} \\PY{n}{\\PYZus{}NDArray\\PYZus{}8}\\PY{p}{)}\\PY{p}{,} \\PY{n}{FALSE}\n", "\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}Slice\\PYZus{}1} \\PY{o}{=} \\PY{n}{Slice}\\PY{p}{(}\\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,} \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{astype}\\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZgt{}} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{0.0001}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int32}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}int}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}9} \\PY{o}{=} 
\\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{\\PYZus{}Slice\\PYZus{}1}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{/} \\PY{n}{\\PYZus{}NDArray\\PYZus{}8}\\PY{p}{)}\\PY{o}{.}\\PY{n}{T} \\PY{o}{/} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\n", " \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\n", "\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{\\PYZus{}Slice\\PYZus{}1}\\PY{p}{)}\\PY{p}{]}\n", "\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2} \\PY{o}{=} \\PY{n}{svd}\\PY{p}{(}\n", " \\PY{p}{(}\n", " \\PY{n}{sqrt}\\PY{p}{(}\n", " \\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{)}\n", " \\PY{o}{*} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{)} \\PY{o}{/} \\PY{n}{Float}\\PY{o}{.}\\PY{n}{from\\PYZus{}int}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", " \\PY{p}{)}\n", " \\PY{o}{*} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5} \\PY{o}{\\PYZhy{}} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}4} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{T}\n", " \\PY{p}{)}\\PY{o}{.}\\PY{n}{T}\n", " \\PY{o}{@} 
\\PY{n}{\\PYZus{}NDArray\\PYZus{}9}\\PY{p}{,}\n", " \\PY{n}{FALSE}\\PY{p}{,}\n", "\\PY{p}{)}\n", "\\PY{p}{(}\n", " \\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)} \\PY{o}{\\PYZhy{}} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}4} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{)}\\PY{p}{)}\n", " \\PY{o}{@} \\PY{p}{(}\n", " \\PY{n}{\\PYZus{}NDArray\\PYZus{}9}\n", " \\PY{o}{@} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{]}\\PY{o}{.}\\PY{n}{T}\\PY{p}{[}\n", " \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\n", " \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\n", " \\PY{o}{+} \\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\n", " \\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\n", " \\PY{n}{Slice}\\PY{p}{(}\n", " \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,}\n", " \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\n", " \\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{astype}\\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZgt{}} \\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{0.0001}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int32}\\PY{p}{)}\\PY{p}{)}\n", " \\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{o}{.}\\PY{n}{to\\PYZus{}int}\n", " \\PY{p}{)}\\PY{p}{,}\n", " \\PY{p}{)}\n", " \\PY{p}{)}\n", " \\PY{p}{)}\n", " \\PY{p}{)}\n", " \\PY{p}{]}\n", " \\PY{p}{)}\n", 
"\\PY{p}{)}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1} \\PY{o}{+} \\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{Slice}\\PY{p}{(}\\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,} \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\n", "\\end{Verbatim}\n" ], "text/plain": [ "_NDArray_1 = NDArray.var(\"X\")\n", "assume_dtype(_NDArray_1, DType.float64)\n", "assume_shape(_NDArray_1, TupleInt(Int(1000000)) + TupleInt(Int(20)))\n", "assume_isfinite(_NDArray_1)\n", "_NDArray_2 = NDArray.var(\"y\")\n", "assume_dtype(_NDArray_2, DType.int64)\n", "assume_shape(_NDArray_2, TupleInt(Int(1000000)))\n", "assume_value_one_of(_NDArray_2, TupleValue(Value.int(Int(0))) + TupleValue(Value.int(Int(1))))\n", "_NDArray_3 = asarray(reshape(asarray(_NDArray_2), TupleInt(Int(-1))))\n", "_NDArray_4 = astype(unique_counts(_NDArray_3)[Int(1)], asarray(_NDArray_1).dtype) / NDArray.scalar(Value.float(Float(1000000.0)))\n", "_NDArray_5 = zeros(\n", " TupleInt(unique_inverse(_NDArray_3)[Int(0)].shape[Int(0)]) + TupleInt(asarray(_NDArray_1).shape[Int(1)]),\n", " OptionalDType.some(asarray(_NDArray_1).dtype),\n", " OptionalDevice.some(asarray(_NDArray_1).device),\n", ")\n", "_MultiAxisIndexKey_1 = MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice()))\n", "_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(0))) + _MultiAxisIndexKey_1)\n", "_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0)))\n", "_NDArray_5[_IndexKey_1] = mean(asarray(_NDArray_1)[ndarray_index(unique_inverse(_NDArray_3)[Int(1)] == NDArray.scalar(Value.int(Int(0))))], _OptionalIntOrTuple_1)\n", 
"_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(1))) + _MultiAxisIndexKey_1)\n", "_NDArray_5[_IndexKey_2] = mean(asarray(_NDArray_1)[ndarray_index(unique_inverse(_NDArray_3)[Int(1)] == NDArray.scalar(Value.int(Int(1))))], _OptionalIntOrTuple_1)\n", "_NDArray_6 = unique_values(concat(TupleNDArray(unique_values(asarray(_NDArray_3)))))\n", "_NDArray_7 = concat(\n", " TupleNDArray(asarray(_NDArray_1)[ndarray_index(_NDArray_3 == _NDArray_6[IndexKey.int(Int(0))])] - _NDArray_5[_IndexKey_1])\n", " + TupleNDArray(asarray(_NDArray_1)[ndarray_index(_NDArray_3 == _NDArray_6[IndexKey.int(Int(1))])] - _NDArray_5[_IndexKey_2]),\n", " OptionalInt.some(Int(0)),\n", ")\n", "_NDArray_8 = std(_NDArray_7, _OptionalIntOrTuple_1)\n", "_NDArray_8[ndarray_index(std(_NDArray_7, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))\n", "_TupleNDArray_1 = svd(\n", " sqrt(asarray(NDArray.scalar(Value.float(Float(1.0) / Float.from_int(asarray(_NDArray_1).shape[Int(0)] - _NDArray_6.shape[Int(0)]))))) * (_NDArray_7 / _NDArray_8), FALSE\n", ")\n", "_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))\n", "_NDArray_9 = (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(_Slice_1)) + _MultiAxisIndexKey_1)] / _NDArray_8).T / _TupleNDArray_1[\n", " Int(1)\n", "][IndexKey.slice(_Slice_1)]\n", "_TupleNDArray_2 = svd(\n", " (\n", " sqrt(\n", " (NDArray.scalar(Value.int(asarray(_NDArray_1).shape[Int(0)])) * _NDArray_4)\n", " * NDArray.scalar(Value.float(Float(1.0) / Float.from_int(_NDArray_6.shape[Int(0)] - Int(1))))\n", " )\n", " * (_NDArray_5 - (_NDArray_4 @ _NDArray_5)).T\n", " ).T\n", " @ _NDArray_9,\n", " FALSE,\n", ")\n", "(\n", " (asarray(_NDArray_1) - (_NDArray_4 @ _NDArray_5))\n", " @ (\n", " _NDArray_9\n", " @ _TupleNDArray_2[Int(2)].T[\n", " 
IndexKey.multi_axis(\n", " _MultiAxisIndexKey_1\n", " + MultiAxisIndexKey(\n", " MultiAxisIndexKeyItem.slice(\n", " Slice(\n", " OptionalInt.none,\n", " OptionalInt.some(\n", " sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32))\n", " .to_value()\n", " .to_int\n", " ),\n", " )\n", " )\n", " )\n", " )\n", " ]\n", " )\n", ")[IndexKey.multi_axis(_MultiAxisIndexKey_1 + MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(_NDArray_6.shape[Int(0)] - Int(1))))))]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from egglog import EGraph\n", "from egglog.exp.array_api import array_api_module\n", "\n", "with EGraph([array_api_module]) as egraph:\n", "    X_r2 = run_lda(X_arr, y_arr)\n", "    egraph.display(n_inline_leaves=3, split_primitive_outputs=True)\n", "X_r2" ] }, { "cell_type": "markdown", "id": "580da17b", "metadata": {}, "source": [ "We have now extracted a program that is semantically equivalent to the original call! One thing you might notice\n", "is that the expression uses more types than typical NumPy code. Every object is lifted into a strongly typed `egglog`\n", "class, so that when we run optimizations, we know the types of all the objects. It is still compatible with\n", "normal Python objects, which are [converted](type-promotion) when they are passed as arguments.\n", "\n", "## Optimizing our result\n", "\n", "Now that we have an expression, we can run our rewrite rules to \"optimize\" it, and afterward extract the lowest-cost\n", "(smallest) expression:\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "4d3cd4f3", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
_NDArray_1 = NDArray.var("X")\n",
              "assume_dtype(_NDArray_1, DType.float64)\n",
              "assume_shape(_NDArray_1, TupleInt(Int(1000000)) + TupleInt(Int(20)))\n",
              "assume_isfinite(_NDArray_1)\n",
              "_NDArray_2 = NDArray.var("y")\n",
              "assume_dtype(_NDArray_2, DType.int64)\n",
              "assume_shape(_NDArray_2, TupleInt(Int(1000000)))\n",
              "assume_value_one_of(_NDArray_2, TupleValue(Value.int(Int(0))) + TupleValue(Value.int(Int(1))))\n",
              "_NDArray_3 = astype(unique_counts(_NDArray_2)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float(1000000.0)))\n",
              "_NDArray_4 = zeros(TupleInt(Int(2)) + TupleInt(Int(20)), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device))\n",
              "_MultiAxisIndexKey_1 = MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice()))\n",
              "_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(0))) + _MultiAxisIndexKey_1)\n",
              "_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0)))\n",
              "_NDArray_4[_IndexKey_1] = mean(_NDArray_1[ndarray_index(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(0))))], _OptionalIntOrTuple_1)\n",
              "_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(1))) + _MultiAxisIndexKey_1)\n",
              "_NDArray_4[_IndexKey_2] = mean(_NDArray_1[ndarray_index(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(1))))], _OptionalIntOrTuple_1)\n",
              "_NDArray_5 = concat(\n",
              "    TupleNDArray(_NDArray_1[ndarray_index(_NDArray_2 == NDArray.scalar(Value.int(Int(0))))] - _NDArray_4[_IndexKey_1])\n",
              "    + TupleNDArray(_NDArray_1[ndarray_index(_NDArray_2 == NDArray.scalar(Value.int(Int(1))))] - _NDArray_4[_IndexKey_2]),\n",
              "    OptionalInt.some(Int(0)),\n",
              ")\n",
              "_NDArray_6 = std(_NDArray_5, _OptionalIntOrTuple_1)\n",
              "_NDArray_6[ndarray_index(std(_NDArray_5, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))\n",
              "_TupleNDArray_1 = svd(sqrt(NDArray.scalar(Value.float(Float(1.0) / Float.from_int(Int(999998))))) * (_NDArray_5 / _NDArray_6), FALSE)\n",
              "_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))\n",
              "_NDArray_7 = (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(_Slice_1)) + _MultiAxisIndexKey_1)] / _NDArray_6).T / _TupleNDArray_1[\n",
              "    Int(1)\n",
              "][IndexKey.slice(_Slice_1)]\n",
              "_TupleNDArray_2 = svd(\n",
              "    (sqrt((NDArray.scalar(Value.int(Int(1000000))) * _NDArray_3) * NDArray.scalar(Value.float(Float(1.0)))) * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T).T @ _NDArray_7, FALSE\n",
              ")\n",
              "(\n",
              "    (_NDArray_1 - (_NDArray_3 @ _NDArray_4))\n",
              "    @ (\n",
              "        _NDArray_7\n",
              "        @ _TupleNDArray_2[Int(2)].T[\n",
              "            IndexKey.multi_axis(\n",
              "                _MultiAxisIndexKey_1\n",
              "                + MultiAxisIndexKey(\n",
              "                    MultiAxisIndexKeyItem.slice(\n",
              "                        Slice(\n",
              "                            OptionalInt.none,\n",
              "                            OptionalInt.some(\n",
              "                                sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32))\n",
              "                                .to_value()\n",
              "                                .to_int\n",
              "                            ),\n",
              "                        )\n",
              "                    )\n",
              "                )\n",
              "            )\n",
              "        ]\n",
              "    )\n",
              ")[IndexKey.multi_axis(_MultiAxisIndexKey_1 + MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(Int(1))))))]\n",
              "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}1} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{var}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{X}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", "\\PY{n}{assume\\PYZus{}dtype}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{)}\n", "\\PY{n}{assume\\PYZus{}shape}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{,} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1000000}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{20}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{assume\\PYZus{}isfinite}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{var}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{y}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", "\\PY{n}{assume\\PYZus{}dtype}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int64}\\PY{p}{)}\n", "\\PY{n}{assume\\PYZus{}shape}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1000000}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{assume\\PYZus{}value\\PYZus{}one\\PYZus{}of}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{TupleValue}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleValue}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{=} \\PY{n}{astype}\\PY{p}{(}\\PY{n}{unique\\PYZus{}counts}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{)}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{)} \\PY{o}{/} 
\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1000000.0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}4} \\PY{o}{=} \\PY{n}{zeros}\\PY{p}{(}\\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{20}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{OptionalDType}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{)}\\PY{p}{,} \\PY{n}{OptionalDevice}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{o}{.}\\PY{n}{device}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1} \\PY{o}{=} \\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{Slice}\\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1} \\PY{o}{=} \\PY{n}{OptionalIntOrTuple}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{IntOrTuple}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1}\\PY{p}{]} \\PY{o}{=} \\PY{n}{mean}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{unique\\PYZus{}inverse}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{)}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{==} 
\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2}\\PY{p}{]} \\PY{o}{=} \\PY{n}{mean}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{unique\\PYZus{}inverse}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{)}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}5} \\PY{o}{=} \\PY{n}{concat}\\PY{p}{(}\n", " \\PY{n}{TupleNDArray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{o}{+} \\PY{n}{TupleNDArray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} 
\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2}\\PY{p}{]}\\PY{p}{)}\\PY{p}{,}\n", " \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,}\n", "\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}6} \\PY{o}{=} \\PY{n}{std}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{std}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1} \\PY{o}{=} \\PY{n}{svd}\\PY{p}{(}\\PY{n}{sqrt}\\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{)} \\PY{o}{/} \\PY{n}{Float}\\PY{o}{.}\\PY{n}{from\\PYZus{}int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{999998}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5} \\PY{o}{/} \\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{p}{)}\\PY{p}{,} \\PY{n}{FALSE}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}Slice\\PYZus{}1} \\PY{o}{=} \\PY{n}{Slice}\\PY{p}{(}\\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,} 
\\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{astype}\\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZgt{}} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{0.0001}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int32}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}int}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}7} \\PY{o}{=} \\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{\\PYZus{}Slice\\PYZus{}1}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{/} \\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{p}{)}\\PY{o}{.}\\PY{n}{T} \\PY{o}{/} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\n", " \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\n", "\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{\\PYZus{}Slice\\PYZus{}1}\\PY{p}{)}\\PY{p}{]}\n", "\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2} \\PY{o}{=} \\PY{n}{svd}\\PY{p}{(}\n", " \\PY{p}{(}\\PY{n}{sqrt}\\PY{p}{(}\\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1000000}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{)} \\PY{o}{*} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}4} 
\\PY{o}{\\PYZhy{}} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{T}\\PY{p}{)}\\PY{o}{.}\\PY{n}{T} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}7}\\PY{p}{,} \\PY{n}{FALSE}\n", "\\PY{p}{)}\n", "\\PY{p}{(}\n", " \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1} \\PY{o}{\\PYZhy{}} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{)}\\PY{p}{)}\n", " \\PY{o}{@} \\PY{p}{(}\n", " \\PY{n}{\\PYZus{}NDArray\\PYZus{}7}\n", " \\PY{o}{@} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{]}\\PY{o}{.}\\PY{n}{T}\\PY{p}{[}\n", " \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\n", " \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\n", " \\PY{o}{+} \\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\n", " \\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\n", " \\PY{n}{Slice}\\PY{p}{(}\n", " \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,}\n", " \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\n", " \\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{astype}\\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZgt{}} \\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{0.0001}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int32}\\PY{p}{)}\\PY{p}{)}\n", " \\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{o}{.}\\PY{n}{to\\PYZus{}int}\n", " \\PY{p}{)}\\PY{p}{,}\n", " \\PY{p}{)}\n", " \\PY{p}{)}\n", " \\PY{p}{)}\n", " \\PY{p}{)}\n", " 
\\PY{p}{]}\n", " \\PY{p}{)}\n", "\\PY{p}{)}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1} \\PY{o}{+} \\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{Slice}\\PY{p}{(}\\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,} \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\n", "\\end{Verbatim}\n" ], "text/plain": [ "_NDArray_1 = NDArray.var(\"X\")\n", "assume_dtype(_NDArray_1, DType.float64)\n", "assume_shape(_NDArray_1, TupleInt(Int(1000000)) + TupleInt(Int(20)))\n", "assume_isfinite(_NDArray_1)\n", "_NDArray_2 = NDArray.var(\"y\")\n", "assume_dtype(_NDArray_2, DType.int64)\n", "assume_shape(_NDArray_2, TupleInt(Int(1000000)))\n", "assume_value_one_of(_NDArray_2, TupleValue(Value.int(Int(0))) + TupleValue(Value.int(Int(1))))\n", "_NDArray_3 = astype(unique_counts(_NDArray_2)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float(1000000.0)))\n", "_NDArray_4 = zeros(TupleInt(Int(2)) + TupleInt(Int(20)), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device))\n", "_MultiAxisIndexKey_1 = MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice()))\n", "_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(0))) + _MultiAxisIndexKey_1)\n", "_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0)))\n", "_NDArray_4[_IndexKey_1] = mean(_NDArray_1[ndarray_index(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(0))))], _OptionalIntOrTuple_1)\n", "_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(1))) + _MultiAxisIndexKey_1)\n", "_NDArray_4[_IndexKey_2] = mean(_NDArray_1[ndarray_index(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(1))))], _OptionalIntOrTuple_1)\n", "_NDArray_5 = concat(\n", " 
TupleNDArray(_NDArray_1[ndarray_index(_NDArray_2 == NDArray.scalar(Value.int(Int(0))))] - _NDArray_4[_IndexKey_1])\n", " + TupleNDArray(_NDArray_1[ndarray_index(_NDArray_2 == NDArray.scalar(Value.int(Int(1))))] - _NDArray_4[_IndexKey_2]),\n", " OptionalInt.some(Int(0)),\n", ")\n", "_NDArray_6 = std(_NDArray_5, _OptionalIntOrTuple_1)\n", "_NDArray_6[ndarray_index(std(_NDArray_5, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))\n", "_TupleNDArray_1 = svd(sqrt(NDArray.scalar(Value.float(Float(1.0) / Float.from_int(Int(999998))))) * (_NDArray_5 / _NDArray_6), FALSE)\n", "_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))\n", "_NDArray_7 = (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(_Slice_1)) + _MultiAxisIndexKey_1)] / _NDArray_6).T / _TupleNDArray_1[\n", " Int(1)\n", "][IndexKey.slice(_Slice_1)]\n", "_TupleNDArray_2 = svd(\n", " (sqrt((NDArray.scalar(Value.int(Int(1000000))) * _NDArray_3) * NDArray.scalar(Value.float(Float(1.0)))) * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T).T @ _NDArray_7, FALSE\n", ")\n", "(\n", " (_NDArray_1 - (_NDArray_3 @ _NDArray_4))\n", " @ (\n", " _NDArray_7\n", " @ _TupleNDArray_2[Int(2)].T[\n", " IndexKey.multi_axis(\n", " _MultiAxisIndexKey_1\n", " + MultiAxisIndexKey(\n", " MultiAxisIndexKeyItem.slice(\n", " Slice(\n", " OptionalInt.none,\n", " OptionalInt.some(\n", " sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32))\n", " .to_value()\n", " .to_int\n", " ),\n", " )\n", " )\n", " )\n", " )\n", " ]\n", " )\n", ")[IndexKey.multi_axis(_MultiAxisIndexKey_1 + MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(Int(1))))))]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "egraph = 
EGraph([array_api_module])\n", "egraph.register(X_r2)\n", "egraph.run(10000)\n", "X_r2_optimized = egraph.extract(X_r2)\n", "X_r2_optimized" ] }, { "cell_type": "markdown", "id": "30ea4ea4", "metadata": {}, "source": [ "We can see that expressions that referenced the shapes of our input arrays have been resolved to concrete values. For example, the\n", "`1 / (n_samples - n_classes)` scaling factor now appears as `Float(1.0) / Float.from_int(Int(999998))`.\n", "\n", "We can also look at the e-graph itself. Although it is quite large, it shows that equivalent\n", "expressions are grouped together in the same \"e-class\":\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "6417b9e5", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", 
"Int___sub__-11477953740632672431:s->Int___add__-17495654355659155035\n", "\n", "\n", "\n", "\n", "\n", "ndarray_index-9457253364840142751:s->NDArray___eq__-5948126446311695931\n", "\n", "\n", "\n", "\n", "\n", "NDArray___eq__-5948126446311695931:s->asarray-17776165865978447989\n", "\n", "\n", "\n", "\n", "\n", "NDArray___eq__-5948126446311695931:s->NDArray_scalar-3845340500482103568\n", "\n", "\n", "\n", "\n", "\n", "ndarray_index-1091269196223507527:s->NDArray___eq__-14314110614928331155\n", "\n", "\n", "\n", "\n", "\n", "NDArray___eq__-14314110614928331155:s->NDArray_scalar-14757501459592564217\n", "\n", "\n", "\n", "\n", "\n", "NDArray___eq__-14314110614928331155:s->assume_value_one_of-5323778840018127892\n", "\n", "\n", "\n", "\n", "\n", "IndexKey_multi_axis-5456168980075999428:s->MultiAxisIndexKey___add__-8188473634840644445\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKey___add__-8188473634840644445:s->MultiAxisIndexKey___init__-9353306107957757443\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKey___add__-8188473634840644445:s->MultiAxisIndexKey___init__-12159351040657546138\n", "\n", "\n", "\n", "\n", "\n", "NDArray_shape-7742477628363861583:s->reshape-4112525690760736104\n", "\n", "\n", "\n", "\n", "\n", "TupleValue_length-51973628441192654:s->TupleValue___init__-14757501459592564217\n", "\n", "\n", "\n", "\n", "\n", "TupleValue___init__-14757501459592564217:s->TupleValue___getitem__-4148863126349750477\n", "\n", "\n", "\n", "\n", "\n", "TupleValue_length-883374682458736911:s->TupleValue___init__-3845340500482103568\n", "\n", "\n", "\n", "\n", "\n", "TupleValue___init__-3845340500482103568:s->TupleValue___getitem__-7786309113067083429\n", "\n", "\n", "\n", "\n", "\n", "Int___add__-17495654355659155035:s->Int___sub__-11477953740632672431\n", "\n", "\n", "\n", "\n", "\n", "Int___add__-17495654355659155035:s->Int___init__-5871781006564002453\n", "\n", "\n", "\n", "\n", "\n", "Value_to_int-7118971088111087942:s->NDArray_to_value-1247190081547085489\n", 
"\n", "\n", "\n", "\n", "\n", "NDArray_to_value-1247190081547085489:s->sum-1681433789052220133\n", "\n", "\n", "\n", "\n", "\n", "Value_to_int-5352221723193614120:s->NDArray_to_value-17927184790339163283\n", "\n", "\n", "\n", "\n", "\n", "NDArray_to_value-17927184790339163283:s->sum-1955564354691009820\n", "\n", "\n", "\n", "\n", "\n", "TupleValue_length-467762655970733886:s->TupleValue___add__-15259460202689358531\n", "\n", "\n", "\n", "\n", "\n", "TupleValue___add__-15259460202689358531:s->TupleValue___init__-14757501459592564217\n", "\n", "\n", "\n", "\n", "\n", "TupleValue___add__-15259460202689358531:s->TupleValue___init__-3845340500482103568\n", "\n", "\n", "\n", "\n", "\n", "TupleValue_length-18083105675662741245:s->possible_values-12211324669098738792\n", "\n", "\n", "\n", "\n", "\n", "possible_values-12211324669098738792:s->NDArray_index-12579319251068649370\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___getitem__-14078601210367663714:s->NDArray_shape-12782857580910319779\n", "\n", "\n", "\n", "\n", "\n", "NDArray_shape-12782857580910319779:s->unique_values-12782857580910319779\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___getitem__-11605336705429392564:s->NDArray_shape-10080759905092916392\n", "\n", "\n", "\n", "\n", "\n", "NDArray_shape-10080759905092916392:s->assume_shape-14591484260056516843\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___getitem__-12686509587440430679:s->TupleInt___init__-103947256882385308\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___init__-103947256882385308:s->Int___init__-6755155689022739364\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___getitem__-7967890718712059612:s->TupleValue_length-51973628441192654\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___getitem__-7967890718712059612:s->TupleInt___add__-13243224121832505654\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___add__-13243224121832505654:s->TupleInt___init__-103947256882385308\n", "\n", "\n", "\n", "\n", "\n", 
"TupleInt___add__-13243224121832505654:s->NDArray_shape-1714775736476281168\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKey___init__-12159351040657546138:s->MultiAxisIndexKeyItem_slice-6287570034093543685\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKey___init__-9665147878604913367:s->MultiAxisIndexKeyItem_slice-3793366872040910914\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKeyItem_slice-6287570034093543685:s->Slice___init__-14445438978175812750\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKey___init__-17771263905015585321:s->MultiAxisIndexKeyItem_int-12938778466233897741\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKeyItem_int-12938778466233897741:s->TupleValue_length-883374682458736911\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKey___init__-10392601675740072316:s->MultiAxisIndexKeyItem_slice-4520820669176069863\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKeyItem_slice-4520820669176069863:s->Slice___init__-15501507093852132239\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKeyItem_slice-3793366872040910914:s->Slice___init__-1162291712589082458\n", "\n", "\n", "\n", "\n", "\n", "Slice___init__-14445438978175812750:s->OptionalInt_some-12990752094675090395\n", "\n", "\n", "\n", "\n", "\n", "Slice___init__-1162291712589082458:s->OptionalInt_some-12938778466233897741\n", "\n", "\n", "\n", "\n", "\n", "NDArray_scalar-11120055472875231265:s->Value_float-5248274466311228812\n", "\n", "\n", "\n", "\n", "\n", "Value_float-5248274466311228812:s->Float_rational-17615343019692007359\n", "\n", "\n", "\n", "\n", "\n", "asarray-7902703286805427734:s->asarray-7902703286805427734\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-17483047916985507424:s->IndexKey_multi_axis-2650124047376210733\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-17483047916985507424:s->NDArray___setitem__-18325169333216085054\n", "\n", "\n", "\n", "\n", "\n", "NDArray___setitem__-18325169333216085054:s->IndexKey_multi_axis-11068081844434038611\n", "\n", 
"\n", "\n", "\n", "\n", "NDArray___setitem__-18325169333216085054:s->NDArray___setitem__-7453141863274628760\n", "\n", "\n", "\n", "\n", "\n", "NDArray___setitem__-18325169333216085054:s->mean-3476503888447580293\n", "\n", "\n", "\n", "\n", "\n", "mean-9206860573968271485:s->NDArray___getitem__-16307929054953181812\n", "\n", "\n", "\n", "\n", "\n", "mean-9206860573968271485:s->OptionalIntOrTuple_some-6859102945905124672\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-16307929054953181812:s->asarray-9510298863856844727\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-16307929054953181812:s->ndarray_index-7690503999922668929\n", "\n", "\n", "\n", "\n", "\n", "concat-9071020324919791953:s->TupleNDArray___add__-17612194977553982959\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___add__-17612194977553982959:s->TupleNDArray___init__-14497633317386600947\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___add__-17612194977553982959:s->TupleNDArray___init__-6131649148769965723\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-18135092377765138894:s->TupleValue_length-51973628441192654\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-18135092377765138894:s->svd-7253966389981509278\n", "\n", "\n", "\n", "\n", "\n", "svd-7253966389981509278:s->NDArray___mul__-8455018010728142919\n", "\n", "\n", "\n", "\n", "\n", "NDArray_T-2858018561140981349:s->NDArray___truediv__-11279504549742320031\n", "\n", "\n", "\n", "\n", "\n", "NDArray___truediv__-11279504549742320031:s->NDArray___setitem__-5767087113385015795\n", "\n", "\n", "\n", "\n", "\n", "NDArray___truediv__-11279504549742320031:s->NDArray___getitem__-9914932780259612220\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-18178625676753040942:s->assume_isfinite-10080759905092916392\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-18178625676753040942:s->ndarray_index-1091269196223507527\n", "\n", "\n", "\n", "\n", "\n", 
"NDArray___getitem__-11026489642259430172:s->IndexKey_multi_axis-2961965818023366657\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-11026489642259430172:s->NDArray___matmul__-7132500556515696557\n", "\n", "\n", "\n", "\n", "\n", "NDArray___matmul__-7132500556515696557:s->NDArray___sub__-8877293197236476153\n", "\n", "\n", "\n", "\n", "\n", "NDArray___matmul__-7132500556515696557:s->NDArray___matmul__-10968585808826125111\n", "\n", "\n", "\n", "\n", "\n", "NDArray___truediv__-15656725660214344740:s->astype-6261542238027864055\n", "\n", "\n", "\n", "\n", "\n", "NDArray___truediv__-15656725660214344740:s->NDArray_scalar-2598150418935018079\n", "\n", "\n", "\n", "\n", "\n", "astype-6261542238027864055:s->NDArray_dtype-11743562013128004906\n", "\n", "\n", "\n", "\n", "\n", "astype-6261542238027864055:s->TupleNDArray___getitem__-15957548086918070248\n", "\n", "\n", "\n", "\n", "\n", "NDArray_scalar-2598150418935018079:s->Value_float-15173113486080567242\n", "\n", "\n", "\n", "\n", "\n", "NDArray___matmul__-9557034512502171054:s->NDArray___setitem__-18325169333216085054\n", "\n", "\n", "\n", "\n", "\n", "NDArray___matmul__-9557034512502171054:s->NDArray___truediv__-15656725660214344740\n", "\n", "\n", "\n", "\n", "\n", "NDArray___truediv__-9788377807842481490:s->concat-9071020324919791953\n", "\n", "\n", "\n", "\n", "\n", "NDArray___truediv__-9788377807842481490:s->NDArray___setitem__-5767087113385015795\n", "\n", "\n", "\n", "\n", "\n", "NDArray___setitem__-5767087113385015795:s->ndarray_index-4468847040734877209\n", "\n", "\n", "\n", "\n", "\n", "NDArray___setitem__-5767087113385015795:s->NDArray_scalar-12107377412216353484\n", "\n", "\n", "\n", "\n", "\n", "NDArray___setitem__-5767087113385015795:s->std-4851945112178408602\n", "\n", "\n", "\n", "\n", "\n", "NDArray_T-10444575304181264970:s->NDArray___mul__-7696624279617524538\n", "\n", "\n", "\n", "\n", "\n", "NDArray___mul__-7696624279617524538:s->NDArray_T-17147757364762811680\n", "\n", "\n", "\n", "\n", 
"\n", "NDArray___mul__-7696624279617524538:s->ndarray-sqrt-5404195351634806774\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-9914932780259612220:s->IndexKey_multi_axis-3689419615158525606\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-9914932780259612220:s->TupleNDArray___getitem__-1818913068061409678\n", "\n", "\n", "\n", "\n", "\n", "NDArray_scalar-14757501459592564217:s->TupleValue___getitem__-9658389681233211557\n", "\n", "\n", "\n", "\n", "\n", "TupleValue___getitem__-9658389681233211557:s->TupleValue___init__-14757501459592564217\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-16424482750509214731:s->IndexKey_int-12938778466233897741\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-16424482750509214731:s->concat-430064524623572644\n", "\n", "\n", "\n", "\n", "\n", "concat-430064524623572644:s->TupleNDArray___init__-12782857580910319779\n", "\n", "\n", "\n", "\n", "\n", "NDArray___truediv__-3215265837560371319:s->NDArray_T-2858018561140981349\n", "\n", "\n", "\n", "\n", "\n", "NDArray___truediv__-3215265837560371319:s->NDArray___getitem__-11494903289568215254\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-11494903289568215254:s->IndexKey_slice-4520820669176069863\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-11494903289568215254:s->TupleNDArray___getitem__-18135092377765138894\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-1818913068061409678:s->TupleValue_length-467762655970733886\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-1818913068061409678:s->svd-7253966389981509278\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-2205987174022554874:s->IndexKey_multi_axis-11068081844434038611\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-2205987174022554874:s->NDArray___setitem__-18325169333216085054\n", "\n", "\n", "\n", "\n", "\n", "NDArray___gt__-15651908559655936539:s->TupleNDArray___getitem__-18135092377765138894\n", "\n", "\n", "\n", "\n", "\n", 
"NDArray___gt__-15651908559655936539:s->NDArray_scalar-1143242824664700181\n", "\n", "\n", "\n", "\n", "\n", "NDArray___gt__-8664676620264668937:s->TupleNDArray___getitem__-17539377729349800285\n", "\n", "\n", "\n", "\n", "\n", "NDArray___gt__-8664676620264668937:s->NDArray___mul__-8440009558605893705\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-17539377729349800285:s->TupleValue_length-883374682458736911\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-17539377729349800285:s->svd-2189404700831293460\n", "\n", "\n", "\n", "\n", "\n", "NDArray___mul__-8440009558605893705:s->NDArray_scalar-1143242824664700181\n", "\n", "\n", "\n", "\n", "\n", "NDArray___mul__-8440009558605893705:s->NDArray___getitem__-17758114586016463110\n", "\n", "\n", "\n", "\n", "\n", "asarray-17776165865978447989:s->reshape-4112525690760736104\n", "\n", "\n", "\n", "\n", "\n", "NDArray_scalar-3845340500482103568:s->TupleValue___getitem__-1353837537593392198\n", "\n", "\n", "\n", "\n", "\n", "sum-1955564354691009820:s->astype-14592420363448682842\n", "\n", "\n", "\n", "\n", "\n", "astype-14592420363448682842:s->NDArray___gt__-15651908559655936539\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-13476223401931994896:s->IndexKey_multi_axis-5456168980075999428\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-13476223401931994896:s->NDArray_T-15484955256727723166\n", "\n", "\n", "\n", "\n", "\n", "NDArray_T-15484955256727723166:s->TupleNDArray___getitem__-10680274783444675613\n", "\n", "\n", "\n", "\n", "\n", "NDArray_shape-15121139857639374588:s->assume_isfinite-10080759905092916392\n", "\n", "\n", "\n", "\n", "\n", "NDArray___setitem__-7453141863274628760:s->IndexKey_multi_axis-2650124047376210733\n", "\n", "\n", "\n", "\n", "\n", "NDArray___setitem__-7453141863274628760:s->mean-9206860573968271485\n", "\n", "\n", "\n", "\n", "\n", "NDArray___setitem__-7453141863274628760:s->zeros-16505489609336576318\n", "\n", "\n", "\n", "\n", "\n", 
"mean-3476503888447580293:s->OptionalIntOrTuple_some-6859102945905124672\n", "\n", "\n", "\n", "\n", "\n", "mean-3476503888447580293:s->NDArray___getitem__-3836913244690017957\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-10680274783444675613:s->TupleValue_length-18083105675662741245\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-10680274783444675613:s->svd-2189404700831293460\n", "\n", "\n", "\n", "\n", "\n", "astype-12468708834165933853:s->NDArray___gt__-8664676620264668937\n", "\n", "\n", "\n", "\n", "\n", "NDArray_scalar-15588902513610108474:s->NDArray_index-1182067134106770624\n", "\n", "\n", "\n", "\n", "\n", "NDArray_index-1182067134106770624:s->TupleNDArray___getitem__-17539377729349800285\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-17758114586016463110:s->TupleNDArray___getitem__-17539377729349800285\n", "\n", "\n", "\n", "\n", "\n", "Value_float-15173113486080567242:s->Float_rational-5871781006564002453\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-3836913244690017957:s->assume_isfinite-10080759905092916392\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-3836913244690017957:s->ndarray_index-10236680790416494354\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-10045558824545728354:s->TupleInt_length-11379923615081194535\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-10045558824545728354:s->unique_inverse-7742477628363861583\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-6343722845416298339:s->NDArray_vector-467762655970733886\n", "\n", "\n", "\n", "\n", "\n", "NDArray_scalar-12107377412216353484:s->Value_float-6235596405652351031\n", "\n", "\n", "\n", "\n", "\n", "Value_float-6235596405652351031:s->Float___init__-10858178701590265856\n", "\n", "\n", "\n", "\n", "\n", "sum-1681433789052220133:s->astype-12468708834165933853\n", "\n", "\n", "\n", "\n", "\n", "NDArray___matmul__-13392291772433010205:s->NDArray_T-10444575304181264970\n", "\n", "\n", "\n", "\n", "\n", 
"NDArray___matmul__-13392291772433010205:s->NDArray___truediv__-3215265837560371319\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-15957548086918070248:s->Int___sub__-11477953740632672431\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-15957548086918070248:s->unique_counts-7742477628363861583\n", "\n", "\n", "\n", "\n", "\n", "reshape-4112525690760736104:s->reshape-4112525690760736104\n", "\n", "\n", "\n", "\n", "\n", "assume_value_one_of-5323778840018127892:s->TupleValue___add__-15259460202689358531\n", "\n", "\n", "\n", "\n", "\n", "assume_value_one_of-5323778840018127892:s->assume_shape-8316602628326787375\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "assume_shape-8316602628326787375:s->TupleInt___init__-1870696621799859130\n", "\n", "\n", "\n", "\n", "\n", "std-4851945112178408602:s->OptionalIntOrTuple_some-6859102945905124672\n", "\n", "\n", "\n", "\n", "\n", "std-4851945112178408602:s->concat-9071020324919791953\n", "\n", "\n", "\n", "\n", "\n", "svd-2189404700831293460:s->NDArray___matmul__-13392291772433010205\n", "\n", "\n", "\n", "\n", "\n", "unique_counts-7742477628363861583:s->reshape-4112525690760736104\n", "\n", "\n", "\n", "\n", "\n", "NDArray___mul__-8455018010728142919:s->NDArray___truediv__-9788377807842481490\n", "\n", "\n", "\n", "\n", "\n", "NDArray___mul__-8455018010728142919:s->ndarray-sqrt-4416873412293684555\n", "\n", "\n", "\n", "\n", "\n", "ndarray-sqrt-4416873412293684555:s->NDArray_scalar-11120055472875231265\n", "\n", "\n", "\n", "\n", "\n", "zeros-16505489609336576318:s->TupleInt___add__-10752996994297486686\n", "\n", "\n", "\n", "\n", "\n", "zeros-16505489609336576318:s->OptionalDType_some-3429551472952562336\n", "\n", "\n", "\n", "\n", "\n", "zeros-16505489609336576318:s->OptionalDevice_some-5144327209428843504\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___add__-10752996994297486686:s->TupleInt___init__-103947256882385308\n", "\n", "\n", "\n", "\n", "\n", 
"TupleInt___add__-10752996994297486686:s->TupleInt___init__-6079675520328773069\n", "\n", "\n", "\n", "\n", "\n", "OptionalDType_some-3429551472952562336:s->NDArray_dtype-11743562013128004906\n", "\n", "\n", "\n", "\n", "\n", "OptionalDevice_some-5144327209428843504:s->NDArray_device-15121139857639374588\n", "\n", "\n", "\n", "\n", "\n", "NDArray_scalar-13770179520251441998:s->Value_int-1870696621799859130\n", "\n", "\n", "\n", "\n", "\n", "Value_int-1870696621799859130:s->Int___init__-16347205588787662656\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-9812641508136405718:s->asarray-9510298863856844727\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-9812641508136405718:s->ndarray_index-9457253364840142751\n", "\n", "\n", "\n", "\n", "\n", "NDArray___mul__-4099386548708531027:s->NDArray___truediv__-15656725660214344740\n", "\n", "\n", "\n", "\n", "\n", "NDArray___mul__-4099386548708531027:s->NDArray_scalar-13770179520251441998\n", "\n", "\n", "\n", "\n", "\n", "NDArray_T-17147757364762811680:s->NDArray___sub__-1374586120005010617\n", "\n", "\n", "\n", "\n", "\n", "NDArray___sub__-1374586120005010617:s->NDArray___setitem__-18325169333216085054\n", "\n", "\n", "\n", "\n", "\n", "NDArray___sub__-1374586120005010617:s->NDArray___matmul__-9557034512502171054\n", "\n", "\n", "\n", "\n", "\n", "NDArray___sub__-8877293197236476153:s->assume_isfinite-10080759905092916392\n", "\n", "\n", "\n", "\n", "\n", "NDArray___sub__-8877293197236476153:s->NDArray___matmul__-9557034512502171054\n", "\n", "\n", "\n", "\n", "\n", "NDArray___sub__-9558798608273926456:s->NDArray___getitem__-18178625676753040942\n", "\n", "\n", "\n", "\n", "\n", "NDArray___sub__-9558798608273926456:s->NDArray___getitem__-2205987174022554874\n", "\n", "\n", "\n", "\n", "\n", "NDArray___matmul__-10968585808826125111:s->NDArray___truediv__-3215265837560371319\n", "\n", "\n", "\n", "\n", "\n", "NDArray___matmul__-10968585808826125111:s->NDArray___getitem__-13476223401931994896\n", "\n", "\n", "\n", 
"\n", "\n", "unique_inverse-7742477628363861583:s->assume_value_one_of-5323778840018127892\n", "\n", "\n", "\n", "\n", "\n", "ndarray-sqrt-5404195351634806774:s->NDArray___mul__-3756686807776082277\n", "\n", "\n", "\n", "\n", "\n", "NDArray___mul__-3756686807776082277:s->NDArray_scalar-12107377412216353484\n", "\n", "\n", "\n", "\n", "\n", "NDArray___mul__-3756686807776082277:s->NDArray___mul__-4099386548708531027\n", "\n", "\n", "\n", "\n", "\n", "NDArray___sub__-10430407918099810154:s->NDArray___getitem__-17483047916985507424\n", "\n", "\n", "\n", "\n", "\n", "NDArray___sub__-10430407918099810154:s->NDArray___getitem__-9812641508136405718\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___init__-12782857580910319779:s->unique_values-12782857580910319779\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-13683004811263061306:s->unique_inverse-7742477628363861583\n", "\n", "\n", "\n", "\n", "\n", "unique_values-12782857580910319779:s->NDArray_vector-18083105675662741245\n", "\n", "\n", "\n", "\n", "\n", "NDArray_vector-18083105675662741245:s->possible_values-12211324669098738792\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "unique_values-7742477628363861583:s->asarray-17776165865978447989\n", "\n", "\n", "\n", "\n", "\n", "NDArray_vector-467762655970733886:s->possible_values-13042725723116283049\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "possible_values-13042725723116283049:s->NDArray_index-17067340853146132798\n", "\n", "\n", "\n", "\n", "\n", "TupleValue___getitem__-1353837537593392198:s->possible_values-12211324669098738792\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___init__-1870696621799859130:s->TupleInt___getitem__-11605336705429392564\n", "\n", "\n", "\n", "\n", "\n", "OptionalInt_some-11224002729757616573:s->Value_to_int-5352221723193614120\n", "\n", "\n", "\n", "\n", "\n", "OptionalInt_some-12938778466233897741:s->Int___init__-5871781006564002453\n", "\n", "\n", "\n", "\n", "\n", 
"OptionalInt_some-12990752094675090395:s->Value_to_int-7118971088111087942\n", "\n", "\n", "\n", "\n", "\n", "Int_to_py-11951456526892775522:s->Int___sub__-2601583573127157282\n", "\n", "\n", "\n", "\n", "\n", "Int_to_py-6079675520328773069:s->TupleValue_length-18083105675662741245\n", "\n", "\n", "\n", "\n", "\n", "Int_to_py-103947256882385308:s->TupleInt___getitem__-12686509587440430679\n", "\n", "\n", "\n", "\n", "\n", "Int_to_py-1870696621799859130:s->TupleInt___getitem__-11605336705429392564\n", "\n", "\n", "\n", "\n", "\n", "Int_to_py-12938778466233897741:s->TupleValue_length-883374682458736911\n", "\n", "\n", "\n", "\n", "\n", "NDArray_shape-1714775736476281168:s->assume_shape-8316602628326787375\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___init__-6079675520328773069:s->TupleValue_length-467762655970733886\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___init__-12938778466233897741:s->TupleInt_length-11379923615081194535\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___init__-14497633317386600947:s->NDArray___sub__-10430407918099810154\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___init__-6131649148769965723:s->NDArray___sub__-9558798608273926456\n", "\n", "\n", "\n", "\n", "\n", "TupleValue___getitem__-7786309113067083429:s->TupleValue___add__-15259460202689358531\n", "\n", "\n", "\n", "\n", "\n", "NDArray_index-12579319251068649370:s->concat-430064524623572644\n", "\n", "\n", "\n", "\n", "\n", "NDArray_index-17067340853146132798:s->assume_shape-8316602628326787375\n", "\n", "\n", "\n", "\n", "\n", "TupleValue___getitem__-4148863126349750477:s->TupleInt_length-11379923615081194535\n", "\n", "\n", "\n", "\n", "\n", "TupleValue___getitem__-4148863126349750477:s->possible_values-13042725723116283049\n", "\n", "\n", "\n", "\n", "\n", "greater_zero-13770179520251441998:s->Value_int-1870696621799859130\n", "\n", "\n", "\n", "\n", "\n", "greater_zero-14757501459592564217:s->TupleValue___getitem__-14448359888109329694\n", "\n", "\n", "\n", "\n", "\n", 
"TupleValue___getitem__-14448359888109329694:s->Int___sub__-11477953740632672431\n", "\n", "\n", "\n", "\n", "\n", "TupleValue___getitem__-14448359888109329694:s->possible_values-12211324669098738792\n", "\n", "\n", "\n", "\n", "\n", "greater_zero-2598150418935018079:s->Value_float-15173113486080567242\n", "\n", "\n", "\n", "\n", "\n", "greater_zero-12107377412216353484:s->Value_float-6235596405652351031\n", "\n", "\n", "\n", "\n", "\n", "NDArray_index-3712217405396014230:s->sum-1955564354691009820\n", "\n", "\n", "\n", "\n", "\n", "Value_int-12938778466233897741:s->Int___init__-5871781006564002453\n", "\n", "\n", "\n", "\n", "\n", "NDArray_index-10864543514592368202:s->NDArray_vector-467762655970733886\n", "\n", "\n", "\n", "\n", "\n", "NDArray_index-10864543514592368202:s->TupleInt___init__-12938778466233897741\n", "\n", "\n", "\n", "\n", "\n", "NDArray_index-15769018209198649053:s->asarray-17776165865978447989\n", "\n", "\n", "\n", "\n", "\n", "NDArray_index-6690955771313385503:s->sum-1681433789052220133\n", "\n", "\n", "\n", "\n", "\n", "NDArray_index-16788298149597563309:s->NDArray_vector-18083105675662741245\n", "\n", "\n", "\n", "\n", "\n", "NDArray_dtype-15121139857639374588\n", "\n", "\n", "NDArray_dtype\n", "\n", "\n", "\n", "\n", "\n", "\n", "assume_isfinite-10080759905092916392\n", "\n", "\n", "assume_isfinite\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_dtype-10080759905092916392\n", "\n", "\n", "NDArray_dtype\n", "\n", "\n", "\n", "\n", "\n", "\n", "assume_shape-14591484260056516843\n", "\n", "\n", "assume_shape\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_dtype-11743562013128004906\n", "\n", "\n", "NDArray_dtype\n", "\n", "\n", "\n", "\n", "\n", "\n", "assume_dtype-3429551472952562336\n", "\n", "\n", "assume_dtype(NDArray_var("X"), ·)\n", "\n", "\n", "\n", "\n", "\n", "\n", "DType_float64-0\n", "\n", "\n", "DType_float64\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_device-15121139857639374588\n", "\n", "\n", "NDArray_device\n", "\n", 
"\n", "\n", "\n", "\n", "\n", "asarray-9510298863856844727\n", "\n", "\n", "asarray(·, OptionalDType_none, OptionalBool_none)\n", "\n", "\n", "\n", "\n", "\n", "\n", "Float___truediv__-12808993487988576005\n", "\n", "\n", "Float___truediv__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Float_rational-0\n", "\n", "\n", "Float_rational((rational 1 1))\n", "\n", "\n", "\n", "\n", "\n", "\n", "Float_from_int-11951456526892775522\n", "\n", "\n", "Float_from_int\n", "\n", "\n", "\n", "\n", "\n", "\n", "Float_rational-17615343019692007359\n", "\n", "\n", "Float_rational((rational 1 999998))\n", "\n", "\n", "\n", "\n", "\n", "\n", "Float_from_int-12938778466233897741\n", "\n", "\n", "Float_from_int\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleInt_length-11379923615081194535\n", "\n", "\n", "TupleInt_length\n", "\n", "\n", "\n", "\n", "\n", "\n", "Float___truediv__-5949890542083451333\n", "\n", "\n", "Float___truediv__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Float___init__-10858178701590265856\n", "\n", "\n", "Float___init__(1.0)\n", "\n", "\n", "\n", "\n", "\n", "\n", "Float___init__-15726603433882419200\n", "\n", "\n", "Float___init__(1000000.0)\n", "\n", "\n", "\n", "\n", "\n", "\n", "Float_rational-5871781006564002453\n", "\n", "\n", "Float_rational((rational 1000000 1))\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int___sub__-2601583573127157282\n", "\n", "\n", "Int___sub__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Float_rational-11743562013128004906\n", "\n", "\n", "Float_rational((rational 999998 1))\n", "\n", "\n", "\n", "\n", "\n", "\n", "IndexKey_multi_axis-11068081844434038611\n", "\n", "\n", "IndexKey_multi_axis\n", "\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKey___add__-7546443524583315781\n", "\n", "\n", "MultiAxisIndexKey___add__\n", "\n", "\n", "\n", "\n", "\n", "\n", "IndexKey_multi_axis-2961965818023366657\n", "\n", "\n", "IndexKey_multi_axis\n", "\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKey___add__-9019874688858188702\n", "\n", "\n", 
"MultiAxisIndexKey___add__\n", "\n", "\n", "\n", "\n", "\n", "\n", "IndexKey_slice-4520820669176069863\n", "\n", "\n", "IndexKey_slice\n", "\n", "\n", "\n", "\n", "\n", "\n", "Slice___init__-15501507093852132239\n", "\n", "\n", "Slice___init__(OptionalInt_none, ·, OptionalInt_none)\n", "\n", "\n", "\n", "\n", "\n", "\n", "IndexKey_multi_axis-2650124047376210733\n", "\n", "\n", "IndexKey_multi_axis\n", "\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKey___add__-4155431249018709085\n", "\n", "\n", "MultiAxisIndexKey___add__\n", "\n", "\n", "\n", "\n", "\n", "\n", "ndarray_index-7690503999922668929\n", "\n", "\n", "ndarray_index\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___eq__-17968234112188297122\n", "\n", "\n", "NDArray___eq__\n", "\n", "\n", "\n", "\n", "\n", "\n", "IndexKey_multi_axis-3689419615158525606\n", "\n", "\n", "IndexKey_multi_axis\n", "\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKey___add__-10696952293987308628\n", "\n", "\n", "MultiAxisIndexKey___add__\n", "\n", "\n", "\n", "\n", "\n", "\n", "ndarray_index-10236680790416494354\n", "\n", "\n", "ndarray_index\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___eq__-7887474207095380730\n", "\n", "\n", "NDArray___eq__\n", "\n", "\n", "\n", "\n", "\n", "\n", "ndarray_index-4468847040734877209\n", "\n", "\n", "ndarray_index\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___eq__-3677844317228415595\n", "\n", "\n", "NDArray___eq__\n", "\n", "\n", "\n", "\n", "\n", "\n", "IndexKey_int-12938778466233897741\n", "\n", "\n", "IndexKey_int\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int___sub__-11477953740632672431\n", "\n", "\n", "Int___sub__\n", "\n", "\n", "\n", "\n", "\n", "\n", "ndarray_index-9457253364840142751\n", "\n", "\n", "ndarray_index\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___eq__-5948126446311695931\n", "\n", "\n", "NDArray___eq__\n", "\n", "\n", "\n", "\n", "\n", "\n", "ndarray_index-1091269196223507527\n", "\n", "\n", "ndarray_index\n", "\n", "\n", "\n", "\n", "\n", "\n", 
"NDArray___eq__-14314110614928331155\n", "\n", "\n", "NDArray___eq__\n", "\n", "\n", "\n", "\n", "\n", "\n", "IndexKey_multi_axis-5456168980075999428\n", "\n", "\n", "IndexKey_multi_axis\n", "\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKey___add__-8188473634840644445\n", "\n", "\n", "MultiAxisIndexKey___add__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int___init__-16347205588787662656\n", "\n", "\n", "Int___init__(1000000)\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int___init__-11743562013128004906\n", "\n", "\n", "Int___init__(2)\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int___init__-4603643575659657750\n", "\n", "\n", "Int___init__(999998)\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_shape-7742477628363861583\n", "\n", "\n", "NDArray_shape\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleValue_length-51973628441192654\n", "\n", "\n", "TupleValue_length\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleValue___init__-14757501459592564217\n", "\n", "\n", "TupleValue___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleValue_length-883374682458736911\n", "\n", "\n", "TupleValue_length\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleValue___init__-3845340500482103568\n", "\n", "\n", "TupleValue___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int___add__-17495654355659155035\n", "\n", "\n", "Int___add__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int___init__-5871781006564002453\n", "\n", "\n", "Int___init__(1)\n", "\n", "\n", "\n", "\n", "\n", "\n", "Value_to_int-7118971088111087942\n", "\n", "\n", "Value_to_int\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_to_value-1247190081547085489\n", "\n", "\n", "NDArray_to_value\n", "\n", "\n", "\n", "\n", "\n", "\n", "Value_to_int-5352221723193614120\n", "\n", "\n", "Value_to_int\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_to_value-17927184790339163283\n", "\n", "\n", "NDArray_to_value\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleValue_length-467762655970733886\n", "\n", "\n", "TupleValue_length\n", "\n", "\n", "\n", 
"\n", "\n", "\n", "TupleValue___add__-15259460202689358531\n", "\n", "\n", "TupleValue___add__\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleValue_length-18083105675662741245\n", "\n", "\n", "TupleValue_length\n", "\n", "\n", "\n", "\n", "\n", "\n", "possible_values-12211324669098738792\n", "\n", "\n", "possible_values\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___getitem__-14078601210367663714\n", "\n", "\n", "TupleInt___getitem__(·, Int___init__(0))\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_shape-12782857580910319779\n", "\n", "\n", "NDArray_shape\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___getitem__-11605336705429392564\n", "\n", "\n", "TupleInt___getitem__(·, Int___init__(0))\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_shape-10080759905092916392\n", "\n", "\n", "NDArray_shape\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___getitem__-12686509587440430679\n", "\n", "\n", "TupleInt___getitem__(·, Int___init__(0))\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___init__-103947256882385308\n", "\n", "\n", "TupleInt___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___getitem__-7967890718712059612\n", "\n", "\n", "TupleInt___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___add__-13243224121832505654\n", "\n", "\n", "TupleInt___add__\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int___init__-6755155689022739364\n", "\n", "\n", "Int___init__(20)\n", "\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKey___init__-9353306107957757443\n", "\n", "\n", "MultiAxisIndexKey___init__(MultiAxisIndexKeyItem_slice(Slice___init__(OptionalInt_none, OptionalInt_none, OptionalInt_none)))\n", "\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKey___init__-12159351040657546138\n", "\n", "\n", "MultiAxisIndexKey___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKey___init__-9665147878604913367\n", "\n", "\n", "MultiAxisIndexKey___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", 
"MultiAxisIndexKeyItem_slice-6287570034093543685\n", "\n", "\n", "MultiAxisIndexKeyItem_slice\n", "\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKey___init__-17771263905015585321\n", "\n", "\n", "MultiAxisIndexKey___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKeyItem_int-12938778466233897741\n", "\n", "\n", "MultiAxisIndexKeyItem_int\n", "\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKey___init__-4312926155411299247\n", "\n", "\n", "MultiAxisIndexKey___init__(MultiAxisIndexKeyItem_int(Int___init__(0)))\n", "\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKey___init__-10392601675740072316\n", "\n", "\n", "MultiAxisIndexKey___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKeyItem_slice-4520820669176069863\n", "\n", "\n", "MultiAxisIndexKeyItem_slice\n", "\n", "\n", "\n", "\n", "\n", "\n", "MultiAxisIndexKeyItem_slice-3793366872040910914\n", "\n", "\n", "MultiAxisIndexKeyItem_slice\n", "\n", "\n", "\n", "\n", "\n", "\n", "Slice___init__-14445438978175812750\n", "\n", "\n", "Slice___init__(OptionalInt_none, ·, OptionalInt_none)\n", "\n", "\n", "\n", "\n", "\n", "\n", "Slice___init__-1162291712589082458\n", "\n", "\n", "Slice___init__(OptionalInt_none, ·, OptionalInt_none)\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_scalar-11120055472875231265\n", "\n", "\n", "NDArray_scalar\n", "\n", "\n", "\n", "\n", "\n", "\n", "Value_float-5248274466311228812\n", "\n", "\n", "Value_float\n", "\n", "\n", "\n", "\n", "\n", "\n", "asarray-7902703286805427734\n", "\n", "\n", "asarray(·, OptionalDType_none, OptionalBool_none)\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-17483047916985507424\n", "\n", "\n", "NDArray___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___setitem__-18325169333216085054\n", "\n", "\n", "NDArray___setitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "mean-9206860573968271485\n", "\n", "\n", "mean(·, ·, FALSE)\n", "\n", "\n", "\n", "\n", "\n", "\n", 
"NDArray___getitem__-16307929054953181812\n", "\n", "\n", "NDArray___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "OptionalIntOrTuple_some-6859102945905124672\n", "\n", "\n", "OptionalIntOrTuple_some(IntOrTuple_int(Int___init__(0)))\n", "\n", "\n", "\n", "\n", "\n", "\n", "concat-9071020324919791953\n", "\n", "\n", "concat(·, OptionalInt_some(Int___init__(0)))\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___add__-17612194977553982959\n", "\n", "\n", "TupleNDArray___add__\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-18135092377765138894\n", "\n", "\n", "TupleNDArray___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "svd-7253966389981509278\n", "\n", "\n", "svd(·, FALSE)\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_T-2858018561140981349\n", "\n", "\n", "NDArray_T\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___truediv__-11279504549742320031\n", "\n", "\n", "NDArray___truediv__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-18178625676753040942\n", "\n", "\n", "NDArray___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-11026489642259430172\n", "\n", "\n", "NDArray___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___matmul__-7132500556515696557\n", "\n", "\n", "NDArray___matmul__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___truediv__-15656725660214344740\n", "\n", "\n", "NDArray___truediv__\n", "\n", "\n", "\n", "\n", "\n", "\n", "astype-6261542238027864055\n", "\n", "\n", "astype\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_scalar-2598150418935018079\n", "\n", "\n", "NDArray_scalar\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___matmul__-9557034512502171054\n", "\n", "\n", "NDArray___matmul__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___truediv__-9788377807842481490\n", "\n", "\n", "NDArray___truediv__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___setitem__-5767087113385015795\n", "\n", "\n", "NDArray___setitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", 
"NDArray_T-10444575304181264970\n", "\n", "\n", "NDArray_T\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___mul__-7696624279617524538\n", "\n", "\n", "NDArray___mul__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-9914932780259612220\n", "\n", "\n", "NDArray___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_scalar-14757501459592564217\n", "\n", "\n", "NDArray_scalar\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleValue___getitem__-9658389681233211557\n", "\n", "\n", "TupleValue___getitem__(·, Int___init__(0))\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-16424482750509214731\n", "\n", "\n", "NDArray___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "concat-430064524623572644\n", "\n", "\n", "concat(·, OptionalInt_none)\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___truediv__-3215265837560371319\n", "\n", "\n", "NDArray___truediv__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-11494903289568215254\n", "\n", "\n", "NDArray___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-1818913068061409678\n", "\n", "\n", "TupleNDArray___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-2205987174022554874\n", "\n", "\n", "NDArray___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___gt__-15651908559655936539\n", "\n", "\n", "NDArray___gt__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_scalar-1143242824664700181\n", "\n", "\n", "NDArray_scalar(Value_float(Float___init__(0.0001)))\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___gt__-8664676620264668937\n", "\n", "\n", "NDArray___gt__\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-17539377729349800285\n", "\n", "\n", "TupleNDArray___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___mul__-8440009558605893705\n", "\n", "\n", "NDArray___mul__\n", "\n", "\n", "\n", "\n", "\n", "\n", "asarray-17776165865978447989\n", "\n", "\n", "asarray(·, OptionalDType_none, OptionalBool_none)\n", 
"\n", "\n", "\n", "\n", "\n", "\n", "NDArray_scalar-3845340500482103568\n", "\n", "\n", "NDArray_scalar\n", "\n", "\n", "\n", "\n", "\n", "\n", "sum-1955564354691009820\n", "\n", "\n", "sum(·, OptionalIntOrTuple_none)\n", "\n", "\n", "\n", "\n", "\n", "\n", "astype-14592420363448682842\n", "\n", "\n", "astype(·, DType_int32)\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-13476223401931994896\n", "\n", "\n", "NDArray___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_T-15484955256727723166\n", "\n", "\n", "NDArray_T\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_shape-15121139857639374588\n", "\n", "\n", "NDArray_shape\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___setitem__-7453141863274628760\n", "\n", "\n", "NDArray___setitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "mean-3476503888447580293\n", "\n", "\n", "mean(·, ·, FALSE)\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-10680274783444675613\n", "\n", "\n", "TupleNDArray___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "astype-12468708834165933853\n", "\n", "\n", "astype(·, DType_int32)\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_scalar-15588902513610108474\n", "\n", "\n", "NDArray_scalar\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_index-1182067134106770624\n", "\n", "\n", "NDArray_index(·, TupleInt___init__(Int___init__(0)))\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-17758114586016463110\n", "\n", "\n", "NDArray___getitem__(·, IndexKey_int(Int___init__(0)))\n", "\n", "\n", "\n", "\n", "\n", "\n", "Value_float-15173113486080567242\n", "\n", "\n", "Value_float\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-3836913244690017957\n", "\n", "\n", "NDArray___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-10045558824545728354\n", "\n", "\n", "TupleNDArray___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-6343722845416298339\n", "\n", "\n", "NDArray___getitem__(·, 
IndexKey_int(Int___init__(0)))\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_scalar-12107377412216353484\n", "\n", "\n", "NDArray_scalar\n", "\n", "\n", "\n", "\n", "\n", "\n", "Value_float-6235596405652351031\n", "\n", "\n", "Value_float\n", "\n", "\n", "\n", "\n", "\n", "\n", "sum-1681433789052220133\n", "\n", "\n", "sum(·, OptionalIntOrTuple_none)\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___matmul__-13392291772433010205\n", "\n", "\n", "NDArray___matmul__\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-15957548086918070248\n", "\n", "\n", "TupleNDArray___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "reshape-4112525690760736104\n", "\n", "\n", "reshape(·, TupleInt___init__(Int___init__(-1)), OptionalBool_none)\n", "\n", "\n", "\n", "\n", "\n", "\n", "assume_value_one_of-5323778840018127892\n", "\n", "\n", "assume_value_one_of\n", "\n", "\n", "\n", "\n", "\n", "\n", "assume_shape-8316602628326787375\n", "\n", "\n", "assume_shape(assume_dtype(NDArray_var("y"), DType_int64), ·)\n", "\n", "\n", "\n", "\n", "\n", "\n", "std-4851945112178408602\n", "\n", "\n", "std\n", "\n", "\n", "\n", "\n", "\n", "\n", "svd-2189404700831293460\n", "\n", "\n", "svd(·, FALSE)\n", "\n", "\n", "\n", "\n", "\n", "\n", "unique_counts-7742477628363861583\n", "\n", "\n", "unique_counts\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___mul__-8455018010728142919\n", "\n", "\n", "NDArray___mul__\n", "\n", "\n", "\n", "\n", "\n", "\n", "ndarray-sqrt-4416873412293684555\n", "\n", "\n", "ndarray-sqrt\n", "\n", "\n", "\n", "\n", "\n", "\n", "zeros-16505489609336576318\n", "\n", "\n", "zeros\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___add__-10752996994297486686\n", "\n", "\n", "TupleInt___add__\n", "\n", "\n", "\n", "\n", "\n", "\n", "OptionalDType_some-3429551472952562336\n", "\n", "\n", "OptionalDType_some\n", "\n", "\n", "\n", "\n", "\n", "\n", "OptionalDevice_some-5144327209428843504\n", "\n", "\n", "OptionalDevice_some\n", "\n", "\n", "\n", "\n", "\n", 
"\n", "NDArray_scalar-13770179520251441998\n", "\n", "\n", "NDArray_scalar\n", "\n", "\n", "\n", "\n", "\n", "\n", "Value_int-1870696621799859130\n", "\n", "\n", "Value_int\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___getitem__-9812641508136405718\n", "\n", "\n", "NDArray___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___mul__-4099386548708531027\n", "\n", "\n", "NDArray___mul__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_T-17147757364762811680\n", "\n", "\n", "NDArray_T\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___sub__-1374586120005010617\n", "\n", "\n", "NDArray___sub__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___sub__-8877293197236476153\n", "\n", "\n", "NDArray___sub__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___sub__-9558798608273926456\n", "\n", "\n", "NDArray___sub__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___matmul__-10968585808826125111\n", "\n", "\n", "NDArray___matmul__\n", "\n", "\n", "\n", "\n", "\n", "\n", "unique_inverse-7742477628363861583\n", "\n", "\n", "unique_inverse\n", "\n", "\n", "\n", "\n", "\n", "\n", "ndarray-sqrt-5404195351634806774\n", "\n", "\n", "ndarray-sqrt\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___mul__-3756686807776082277\n", "\n", "\n", "NDArray___mul__\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray___sub__-10430407918099810154\n", "\n", "\n", "NDArray___sub__\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___init__-12782857580910319779\n", "\n", "\n", "TupleNDArray___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___getitem__-13683004811263061306\n", "\n", "\n", "TupleNDArray___getitem__(·, Int___init__(0))\n", "\n", "\n", "\n", "\n", "\n", "\n", "unique_values-12782857580910319779\n", "\n", "\n", "unique_values\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_vector-18083105675662741245\n", "\n", "\n", "NDArray_vector\n", "\n", "\n", "\n", "\n", "\n", "\n", "unique_values-7742477628363861583\n", "\n", "\n", "unique_values\n", "\n", "\n", "\n", 
"\n", "\n", "\n", "NDArray_vector-467762655970733886\n", "\n", "\n", "NDArray_vector\n", "\n", "\n", "\n", "\n", "\n", "\n", "possible_values-13042725723116283049\n", "\n", "\n", "possible_values\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleValue___getitem__-1353837537593392198\n", "\n", "\n", "TupleValue___getitem__(·, Int___init__(0))\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___init__-1870696621799859130\n", "\n", "\n", "TupleInt___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "OptionalInt_some-11224002729757616573\n", "\n", "\n", "OptionalInt_some\n", "\n", "\n", "\n", "\n", "\n", "\n", "OptionalInt_some-12938778466233897741\n", "\n", "\n", "OptionalInt_some\n", "\n", "\n", "\n", "\n", "\n", "\n", "OptionalInt_some-12990752094675090395\n", "\n", "\n", "OptionalInt_some\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int_to_py-7586556743040283621-value\n", "\n", "\n", "(py-object -9223372036570011657 0)\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int_to_py-7586556743040283621\n", "\n", "\n", "Int_to_py(Int___init__(0))\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int_to_py-11951456526892775522\n", "\n", "\n", "Int_to_py\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int_to_py-11951456526892775522-value\n", "\n", "\n", "(py-object -9223372036570011657 999998)\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int_to_py-6079675520328773069\n", "\n", "\n", "Int_to_py\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int_to_py-6079675520328773069-value\n", "\n", "\n", "(py-object -9223372036570011657 2)\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int_to_py-103947256882385308\n", "\n", "\n", "Int_to_py\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int_to_py-103947256882385308-value\n", "\n", "\n", "(py-object -9223372036570011657 20)\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int_to_py-5092353580987650850-value\n", "\n", "\n", "(py-object -9223372036570011657 -2)\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int_to_py-5092353580987650850\n", "\n", "\n", "Int_to_py(Int___init__(-1))\n", "\n", "\n", "\n", "\n", 
"\n", "\n", "Int_to_py-1870696621799859130\n", "\n", "\n", "Int_to_py\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int_to_py-1870696621799859130-value\n", "\n", "\n", "(py-object -9223372036570011657 1000000)\n", "\n", "\n", "\n", "\n", "\n", "\n", "Boolean_to_py-155920885323577962-value\n", "\n", "\n", "(py-object 284764003 0)\n", "\n", "\n", "\n", "\n", "\n", "\n", "Boolean_to_py-155920885323577962\n", "\n", "\n", "Boolean_to_py(FALSE)\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int_to_py-12938778466233897741\n", "\n", "\n", "Int_to_py\n", "\n", "\n", "\n", "\n", "\n", "\n", "Int_to_py-12938778466233897741-value\n", "\n", "\n", "(py-object -9223372036570011657 1)\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_shape-1714775736476281168\n", "\n", "\n", "NDArray_shape\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___init__-6079675520328773069\n", "\n", "\n", "TupleInt___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleInt___init__-12938778466233897741\n", "\n", "\n", "TupleInt___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___init__-14497633317386600947\n", "\n", "\n", "TupleNDArray___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleNDArray___init__-6131649148769965723\n", "\n", "\n", "TupleNDArray___init__\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleValue___getitem__-7786309113067083429\n", "\n", "\n", "TupleValue___getitem__(·, Int___init__(0))\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_index-12579319251068649370\n", "\n", "\n", "NDArray_index(·, ALL_INDICES)\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_index-17067340853146132798\n", "\n", "\n", "NDArray_index(·, ALL_INDICES)\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleValue___getitem__-4148863126349750477\n", "\n", "\n", "TupleValue___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "greater_zero-1143242824664700181-value\n", "\n", "\n", "()\n", "\n", "\n", "\n", "\n", "\n", "\n", "greater_zero-1143242824664700181\n", "\n", "\n", 
"greater_zero(Value_float(Float___init__(0.0001)))\n", "\n", "\n", "\n", "\n", "\n", "\n", "greater_zero-13770179520251441998\n", "\n", "\n", "greater_zero\n", "\n", "\n", "\n", "\n", "\n", "\n", "greater_zero-13770179520251441998-value\n", "\n", "\n", "()\n", "\n", "\n", "\n", "\n", "\n", "\n", "greater_zero-14757501459592564217\n", "\n", "\n", "greater_zero\n", "\n", "\n", "\n", "\n", "\n", "\n", "TupleValue___getitem__-14448359888109329694\n", "\n", "\n", "TupleValue___getitem__\n", "\n", "\n", "\n", "\n", "\n", "\n", "greater_zero-14757501459592564217-value\n", "\n", "\n", "()\n", "\n", "\n", "\n", "\n", "\n", "\n", "greater_zero-2598150418935018079\n", "\n", "\n", "greater_zero\n", "\n", "\n", "\n", "\n", "\n", "\n", "greater_zero-2598150418935018079-value\n", "\n", "\n", "()\n", "\n", "\n", "\n", "\n", "\n", "\n", "greater_zero-12107377412216353484\n", "\n", "\n", "greater_zero\n", "\n", "\n", "\n", "\n", "\n", "\n", "greater_zero-12107377412216353484-value\n", "\n", "\n", "()\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_index-3712217405396014230\n", "\n", "\n", "NDArray_index(·, TupleInt_EMPTY)\n", "\n", "\n", "\n", "\n", "\n", "\n", "Value_int-12938778466233897741\n", "\n", "\n", "Value_int\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_index-10864543514592368202\n", "\n", "\n", "NDArray_index\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_index-15769018209198649053\n", "\n", "\n", "NDArray_index(·, ALL_INDICES)\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_index-5822399466274154604\n", "\n", "\n", "NDArray_index(assume_dtype(NDArray_var("y"), DType_int64), ALL_INDICES)\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_index-7547271516962905184\n", "\n", "\n", "NDArray_index(NDArray_var("y"), ALL_INDICES)\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_index-6690955771313385503\n", "\n", "\n", "NDArray_index(·, TupleInt_EMPTY)\n", "\n", "\n", "\n", "\n", "\n", "\n", "NDArray_index-16788298149597563309\n", "\n", "\n", "NDArray_index(·, 
TupleInt___init__(Int___init__(0)))\n", "\n", "\n", "\n", "\n", "\n", "\n", "Value_int-7586556743040283621\n", "\n", "\n", "Value_int(Int___init__(0))\n", "\n", "\n", "\n", "\n", "\n", "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "egraph.display(n_inline_leaves=3, split_primitive_outputs=True)" ] }, { "cell_type": "markdown", "id": "21e4ee3a", "metadata": {}, "source": [ "## Translating for Numba\n", "\n", "We are getting closer to a form we could translate back to Numba, but we have to make a few changes. Numba doesn't\n", "support the `axis` keyword for `mean` or `std`, but it does support it for `sum`, so we have to translate all forms\n", "from one to the other, with a rule like this (defined in [`egglog.exp.array_api_numba`](https://github.com/egraphs-good/egglog-python/blob/main/python/egglog/exp/array_api_numba.py)):\n", "\n", "```python\n", "axis = OptionalIntOrTuple.some(IntOrTuple.int(i))\n", "rewrite(std(x, axis)).to(sqrt(mean(square(abs(x - mean(x, axis, keepdims=TRUE))), axis)))\n", "```\n", "\n", "We can run those additional rewrites now to get a new extracted version\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "9e79f88e", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
_NDArray_1 = NDArray.var("X")\n",
              "assume_dtype(_NDArray_1, DType.float64)\n",
              "assume_shape(_NDArray_1, TupleInt(Int(1000000)) + TupleInt(Int(20)))\n",
              "assume_isfinite(_NDArray_1)\n",
              "_NDArray_2 = NDArray.var("y")\n",
              "assume_dtype(_NDArray_2, DType.int64)\n",
              "assume_shape(_NDArray_2, TupleInt(Int(1000000)))\n",
              "assume_value_one_of(_NDArray_2, TupleValue(Value.int(Int(0))) + TupleValue(Value.int(Int(1))))\n",
              "_NDArray_3 = astype(\n",
              "    NDArray.vector(TupleValue(sum(_NDArray_2 == NDArray.scalar(Value.int(Int(0)))).to_value()) + TupleValue(sum(_NDArray_2 == NDArray.scalar(Value.int(Int(1)))).to_value())),\n",
              "    DType.float64,\n",
              ") / NDArray.scalar(Value.float(Float(1000000.0)))\n",
              "_NDArray_4 = zeros(TupleInt(Int(2)) + TupleInt(Int(20)), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device))\n",
              "_MultiAxisIndexKey_1 = MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice()))\n",
              "_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(0))) + _MultiAxisIndexKey_1)\n",
              "_NDArray_5 = _NDArray_1[ndarray_index(_NDArray_2 == NDArray.scalar(Value.int(Int(0))))]\n",
              "_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0)))\n",
              "_NDArray_4[_IndexKey_1] = sum(_NDArray_5, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_5.shape[Int(0)]))\n",
              "_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(1))) + _MultiAxisIndexKey_1)\n",
              "_NDArray_6 = _NDArray_1[ndarray_index(_NDArray_2 == NDArray.scalar(Value.int(Int(1))))]\n",
              "_NDArray_4[_IndexKey_2] = sum(_NDArray_6, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_6.shape[Int(0)]))\n",
              "_NDArray_7 = concat(TupleNDArray(_NDArray_5 - _NDArray_4[_IndexKey_1]) + TupleNDArray(_NDArray_6 - _NDArray_4[_IndexKey_2]), OptionalInt.some(Int(0)))\n",
              "_NDArray_8 = square(_NDArray_7 - expand_dims(sum(_NDArray_7, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_7.shape[Int(0)]))))\n",
              "_NDArray_9 = sqrt(sum(_NDArray_8, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_8.shape[Int(0)])))\n",
              "_NDArray_10 = copy(_NDArray_9)\n",
              "_NDArray_10[ndarray_index(_NDArray_9 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))\n",
              "_TupleNDArray_1 = svd(sqrt(NDArray.scalar(Value.float(Float(1.0) / Float.from_int(Int(999998))))) * (_NDArray_7 / _NDArray_10), FALSE)\n",
              "_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))\n",
              "_NDArray_11 = (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(_Slice_1)) + _MultiAxisIndexKey_1)] / _NDArray_10).T / _TupleNDArray_1[\n",
              "    Int(1)\n",
              "][IndexKey.slice(_Slice_1)]\n",
              "_TupleNDArray_2 = svd(\n",
              "    (sqrt((NDArray.scalar(Value.int(Int(1000000))) * _NDArray_3) * NDArray.scalar(Value.float(Float(1.0)))) * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T).T @ _NDArray_11, FALSE\n",
              ")\n",
              "(\n",
              "    (_NDArray_1 - (_NDArray_3 @ _NDArray_4))\n",
              "    @ (\n",
              "        _NDArray_11\n",
              "        @ _TupleNDArray_2[Int(2)].T[\n",
              "            IndexKey.multi_axis(\n",
              "                _MultiAxisIndexKey_1\n",
              "                + MultiAxisIndexKey(\n",
              "                    MultiAxisIndexKeyItem.slice(\n",
              "                        Slice(\n",
              "                            OptionalInt.none,\n",
              "                            OptionalInt.some(\n",
              "                                sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32))\n",
              "                                .to_value()\n",
              "                                .to_int\n",
              "                            ),\n",
              "                        )\n",
              "                    )\n",
              "                )\n",
              "            )\n",
              "        ]\n",
              "    )\n",
              ")[IndexKey.multi_axis(_MultiAxisIndexKey_1 + MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(Int(1))))))]\n",
              "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}1} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{var}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{X}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", "\\PY{n}{assume\\PYZus{}dtype}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{)}\n", "\\PY{n}{assume\\PYZus{}shape}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{,} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1000000}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{20}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{assume\\PYZus{}isfinite}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{var}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{y}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", "\\PY{n}{assume\\PYZus{}dtype}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int64}\\PY{p}{)}\n", "\\PY{n}{assume\\PYZus{}shape}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1000000}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{assume\\PYZus{}value\\PYZus{}one\\PYZus{}of}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{TupleValue}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleValue}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{=} \\PY{n}{astype}\\PY{p}{(}\n", " \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{vector}\\PY{p}{(}\\PY{n}{TupleValue}\\PY{p}{(}\\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} 
\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleValue}\\PY{p}{(}\\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,}\n", " \\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{,}\n", "\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1000000.0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}4} \\PY{o}{=} \\PY{n}{zeros}\\PY{p}{(}\\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{20}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{OptionalDType}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{)}\\PY{p}{,} \\PY{n}{OptionalDevice}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{o}{.}\\PY{n}{device}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1} \\PY{o}{=} \\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{Slice}\\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\\PY{p}{)}\n", 
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}5} \\PY{o}{=} \\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\n", "\\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1} \\PY{o}{=} \\PY{n}{OptionalIntOrTuple}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{IntOrTuple}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1}\\PY{p}{]} \\PY{o}{=} \\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}6} \\PY{o}{=} \\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2}\\PY{p}{]} \\PY{o}{=} \\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{p}{,} 
\\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}7} \\PY{o}{=} \\PY{n}{concat}\\PY{p}{(}\\PY{n}{TupleNDArray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1}\\PY{p}{]}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleNDArray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}6} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2}\\PY{p}{]}\\PY{p}{)}\\PY{p}{,} \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}8} \\PY{o}{=} \\PY{n}{square}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7} \\PY{o}{\\PYZhy{}} \\PY{n}{expand\\PYZus{}dims}\\PY{p}{(}\\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}9} \\PY{o}{=} \\PY{n}{sqrt}\\PY{p}{(}\\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}8}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}8}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}10} 
\\PY{o}{=} \\PY{n}{copy}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}9}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}10}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}9} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1} \\PY{o}{=} \\PY{n}{svd}\\PY{p}{(}\\PY{n}{sqrt}\\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{)} \\PY{o}{/} \\PY{n}{Float}\\PY{o}{.}\\PY{n}{from\\PYZus{}int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{999998}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7} \\PY{o}{/} \\PY{n}{\\PYZus{}NDArray\\PYZus{}10}\\PY{p}{)}\\PY{p}{,} \\PY{n}{FALSE}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}Slice\\PYZus{}1} \\PY{o}{=} \\PY{n}{Slice}\\PY{p}{(}\\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,} \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{astype}\\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZgt{}} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{0.0001}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int32}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}int}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{\\PYZus{}NDArray\\PYZus{}11} \\PY{o}{=} 
\\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{\\PYZus{}Slice\\PYZus{}1}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{/} \\PY{n}{\\PYZus{}NDArray\\PYZus{}10}\\PY{p}{)}\\PY{o}{.}\\PY{n}{T} \\PY{o}{/} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\n", " \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\n", "\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{\\PYZus{}Slice\\PYZus{}1}\\PY{p}{)}\\PY{p}{]}\n", "\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2} \\PY{o}{=} \\PY{n}{svd}\\PY{p}{(}\n", " \\PY{p}{(}\\PY{n}{sqrt}\\PY{p}{(}\\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1000000}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{)} \\PY{o}{*} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}4} \\PY{o}{\\PYZhy{}} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{T}\\PY{p}{)}\\PY{o}{.}\\PY{n}{T} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}11}\\PY{p}{,} \\PY{n}{FALSE}\n", "\\PY{p}{)}\n", "\\PY{p}{(}\n", " \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1} \\PY{o}{\\PYZhy{}} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{)}\\PY{p}{)}\n", " \\PY{o}{@} \\PY{p}{(}\n", " \\PY{n}{\\PYZus{}NDArray\\PYZus{}11}\n", " \\PY{o}{@} 
\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{]}\\PY{o}{.}\\PY{n}{T}\\PY{p}{[}\n", " \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\n", " \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\n", " \\PY{o}{+} \\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\n", " \\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\n", " \\PY{n}{Slice}\\PY{p}{(}\n", " \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,}\n", " \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\n", " \\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{astype}\\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZgt{}} \\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{0.0001}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int32}\\PY{p}{)}\\PY{p}{)}\n", " \\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{o}{.}\\PY{n}{to\\PYZus{}int}\n", " \\PY{p}{)}\\PY{p}{,}\n", " \\PY{p}{)}\n", " \\PY{p}{)}\n", " \\PY{p}{)}\n", " \\PY{p}{)}\n", " \\PY{p}{]}\n", " \\PY{p}{)}\n", "\\PY{p}{)}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1} \\PY{o}{+} \\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{Slice}\\PY{p}{(}\\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,} \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\n", "\\end{Verbatim}\n" ], "text/plain": [ "_NDArray_1 
= NDArray.var(\"X\")\n", "assume_dtype(_NDArray_1, DType.float64)\n", "assume_shape(_NDArray_1, TupleInt(Int(1000000)) + TupleInt(Int(20)))\n", "assume_isfinite(_NDArray_1)\n", "_NDArray_2 = NDArray.var(\"y\")\n", "assume_dtype(_NDArray_2, DType.int64)\n", "assume_shape(_NDArray_2, TupleInt(Int(1000000)))\n", "assume_value_one_of(_NDArray_2, TupleValue(Value.int(Int(0))) + TupleValue(Value.int(Int(1))))\n", "_NDArray_3 = astype(\n", " NDArray.vector(TupleValue(sum(_NDArray_2 == NDArray.scalar(Value.int(Int(0)))).to_value()) + TupleValue(sum(_NDArray_2 == NDArray.scalar(Value.int(Int(1)))).to_value())),\n", " DType.float64,\n", ") / NDArray.scalar(Value.float(Float(1000000.0)))\n", "_NDArray_4 = zeros(TupleInt(Int(2)) + TupleInt(Int(20)), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device))\n", "_MultiAxisIndexKey_1 = MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice()))\n", "_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(0))) + _MultiAxisIndexKey_1)\n", "_NDArray_5 = _NDArray_1[ndarray_index(_NDArray_2 == NDArray.scalar(Value.int(Int(0))))]\n", "_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0)))\n", "_NDArray_4[_IndexKey_1] = sum(_NDArray_5, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_5.shape[Int(0)]))\n", "_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(1))) + _MultiAxisIndexKey_1)\n", "_NDArray_6 = _NDArray_1[ndarray_index(_NDArray_2 == NDArray.scalar(Value.int(Int(1))))]\n", "_NDArray_4[_IndexKey_2] = sum(_NDArray_6, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_6.shape[Int(0)]))\n", "_NDArray_7 = concat(TupleNDArray(_NDArray_5 - _NDArray_4[_IndexKey_1]) + TupleNDArray(_NDArray_6 - _NDArray_4[_IndexKey_2]), OptionalInt.some(Int(0)))\n", "_NDArray_8 = square(_NDArray_7 - expand_dims(sum(_NDArray_7, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_7.shape[Int(0)]))))\n", "_NDArray_9 = sqrt(sum(_NDArray_8, 
_OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_8.shape[Int(0)])))\n", "_NDArray_10 = copy(_NDArray_9)\n", "_NDArray_10[ndarray_index(_NDArray_9 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))\n", "_TupleNDArray_1 = svd(sqrt(NDArray.scalar(Value.float(Float(1.0) / Float.from_int(Int(999998))))) * (_NDArray_7 / _NDArray_10), FALSE)\n", "_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))\n", "_NDArray_11 = (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(_Slice_1)) + _MultiAxisIndexKey_1)] / _NDArray_10).T / _TupleNDArray_1[\n", " Int(1)\n", "][IndexKey.slice(_Slice_1)]\n", "_TupleNDArray_2 = svd(\n", " (sqrt((NDArray.scalar(Value.int(Int(1000000))) * _NDArray_3) * NDArray.scalar(Value.float(Float(1.0)))) * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T).T @ _NDArray_11, FALSE\n", ")\n", "(\n", " (_NDArray_1 - (_NDArray_3 @ _NDArray_4))\n", " @ (\n", " _NDArray_11\n", " @ _TupleNDArray_2[Int(2)].T[\n", " IndexKey.multi_axis(\n", " _MultiAxisIndexKey_1\n", " + MultiAxisIndexKey(\n", " MultiAxisIndexKeyItem.slice(\n", " Slice(\n", " OptionalInt.none,\n", " OptionalInt.some(\n", " sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32))\n", " .to_value()\n", " .to_int\n", " ),\n", " )\n", " )\n", " )\n", " )\n", " ]\n", " )\n", ")[IndexKey.multi_axis(_MultiAxisIndexKey_1 + MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(Int(1))))))]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from egglog.exp.array_api_numba import array_api_numba_module\n", "\n", "egraph = EGraph([array_api_numba_module])\n", "egraph.register(X_r2_optimized)\n", "egraph.run(10000)\n", "X_r2_numba = egraph.extract(X_r2_optimized)\n", "X_r2_numba" ] }, { 
"cell_type": "markdown", "id": "969490bb", "metadata": {}, "source": [ "## Compiling back to Python source\n", "\n", "Now we finally have a version that we could run with Numba! However, it isn't NumPy code yet: Numba needs\n", "a function written with `numpy`, not with our typed dialect.\n", "\n", "So we use another module that translates each of our methods into a Python source string. The rules in it look like this:\n", "\n", "```python\n", "# The sqrt of an array should use the `np.sqrt` function and be assigned to its own variable, so it can be reused\n", "rewrite(ndarray_program(sqrt(x))).to((Program(\"np.sqrt(\") + ndarray_program(x) + \")\").assign())\n", "\n", "# To compile a setitem call, we first compile the source, assign it to a variable, then add an assignment statement\n", "mod_x = copy(x)\n", "mod_x[idx] = y\n", "assigned_x = ndarray_program(x).assign()\n", "yield rewrite(ndarray_program(mod_x)).to(\n", " assigned_x.statement(assigned_x + \"[\" + index_key_program(idx) + \"] = \" + ndarray_program(y))\n", ")\n", "```\n", "\n", "We pull in all those rewrite rules from the [`egglog.exp.array_api_program_gen` module](https://github.com/egraphs-good/egglog-python/blob/main/python/egglog/exp/array_api_program_gen.py).\n", "They depend on another module, the [`egglog.exp.program_gen` module](https://github.com/egraphs-good/egglog-python/blob/main/python/egglog/exp/program_gen.py), which provides generic translations\n", "from expressions and statements into strings.\n", "\n", "We can run these rules to get a Python function object:\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "3aeae673", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "def __fn(X, y):\n", " assert X.dtype == np.dtype(np.float64)\n", " assert X.shape == (1000000, 20,)\n", " assert np.all(np.isfinite(X))\n", " assert y.dtype == np.dtype(np.int64)\n", " assert y.shape == (1000000,)\n", " assert set(np.unique(y)) == set((0, 1,))\n", " _0 = 
y == np.array(0)\n", " _1 = np.sum(_0)\n", " _2 = y == np.array(1)\n", " _3 = np.sum(_2)\n", " _4 = np.array((_1, _3,)).astype(np.dtype(np.float64))\n", " _5 = _4 / np.array(1000000.0)\n", " _6 = np.zeros((2, 20,), dtype=np.dtype(np.float64))\n", " _7 = np.sum(X[_0], axis=0)\n", " _8 = _7 / np.array(X[_0].shape[0])\n", " _6[0, :] = _8\n", " _9 = np.sum(X[_2], axis=0)\n", " _10 = _9 / np.array(X[_2].shape[0])\n", " _6[1, :] = _10\n", " _11 = _5 @ _6\n", " _12 = X - _11\n", " _13 = np.sqrt(np.array((1.0 / 999998)))\n", " _14 = X[_0] - _6[0, :]\n", " _15 = X[_2] - _6[1, :]\n", " _16 = np.concatenate((_14, _15,), axis=0)\n", " _17 = np.sum(_16, axis=0)\n", " _18 = _17 / np.array(_16.shape[0])\n", " _19 = np.expand_dims(_18, 0)\n", " _20 = _16 - _19\n", " _21 = np.square(_20)\n", " _22 = np.sum(_21, axis=0)\n", " _23 = _22 / np.array(_21.shape[0])\n", " _24 = np.sqrt(_23)\n", " _25 = _24 == np.array(0)\n", " _24[_25] = np.array(1.0)\n", " _26 = _16 / _24\n", " _27 = _13 * _26\n", " _28 = np.linalg.svd(_27, full_matrices=False)\n", " _29 = _28[1] > np.array(0.0001)\n", " _30 = _29.astype(np.dtype(np.int32))\n", " _31 = np.sum(_30)\n", " _32 = _28[2][:_31, :] / _24\n", " _33 = _32.T / _28[1][:_31]\n", " _34 = np.array(1000000) * _5\n", " _35 = _34 * np.array(1.0)\n", " _36 = np.sqrt(_35)\n", " _37 = _6 - _11\n", " _38 = _36 * _37.T\n", " _39 = _38.T @ _33\n", " _40 = np.linalg.svd(_39, full_matrices=False)\n", " _41 = np.array(0.0001) * _40[1][0]\n", " _42 = _40[1] > _41\n", " _43 = _42.astype(np.dtype(np.int32))\n", " _44 = np.sum(_43)\n", " _45 = _33 @ _40[2].T[:, :_44]\n", " _46 = _12 @ _45\n", " return _46[:, :1]\n", "\n" ] } ], "source": [ "from egglog.exp.array_api_program_gen import (\n", " ndarray_function_two,\n", " array_api_module_string,\n", ")\n", "\n", "egraph = EGraph([array_api_module_string])\n", "fn_program = ndarray_function_two(X_r2_numba, X_orig, y_orig)\n", "egraph.register(fn_program)\n", "egraph.run(10000)\n", "fn = 
egraph.load_object(egraph.extract(fn_program.py_object))\n", "import inspect\n", "\n", "print(inspect.getsource(fn))" ] }, { "cell_type": "markdown", "id": "6e0405c8", "metadata": {}, "source": [ "We can verify that the function gives the same result as the original:\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "a807d66c", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "assert np.allclose(run_lda(X_np, y_np), fn(X_np, y_np))" ] }, { "cell_type": "markdown", "id": "b2a3f1ed", "metadata": {}, "source": [ "Although it isn't the prettiest, we can see that it emits each expression only once (common subexpression\n", "elimination) and preserves the \"imperative\" semantics of `setitem`.\n", "\n", "## Compiling to Numba\n", "\n", "Now we have a function that we can compile with Numba:\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "39a69f23", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/xn/05ktz3056kqd9n8frgd6236h0000gn/T/egglog-9e61d62c-d17d-495b-b8db-f1eb3b38dcbb.py:56: NumbaPerformanceWarning: '@' is faster on contiguous arrays, called on (Array(float64, 2, 'C', False, aligned=True), Array(float64, 2, 'A', False, aligned=True))\n", " _45 = _33 @ _40[2].T[:, :_44]\n" ] } ], "source": [ "import numba\n", "\n", "fn_numba = numba.njit(fastmath=True)(fn)\n", "assert np.allclose(run_lda(X_np, y_np), fn_numba(X_np, y_np))" ] }, { "cell_type": "markdown", "id": "078d41b3", "metadata": {}, "source": [ "## Evaluating performance\n", "\n", "Let's see if it actually made anything quicker! We run ten trials each of the original function, our\n", "extracted version, and the Numba-compiled extracted version:\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "27a0cafc", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " original extracted extracted numba\n", "0 1.482975 1.609354 1.086486\n", "1 1.498656 1.504704 1.145331\n", "2 1.500998 1.557253 1.090356\n", "3 1.519732 1.548800 1.122623\n", "4 1.500420 1.501195 1.113089\n", "5 1.587211 1.522518 1.176842\n", "6 1.499479 1.526887 1.095296\n", "7 1.639910 1.500859 1.086477\n", "8 1.525145 1.559202 1.103662\n", "9 1.535601 1.474299 1.074152" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import timeit\n", "import pandas as pd\n", "\n", "stmts = {\n", " \"original\": \"run_lda(X_np, y_np)\",\n", " \"extracted\": \"fn(X_np, y_np)\",\n", " \"extracted numba\": \"fn_numba(X_np, y_np)\",\n", "}\n", "df = pd.DataFrame.from_dict(\n", " {name: timeit.repeat(stmt, globals=globals(), number=1, repeat=10) for name, stmt in stmts.items()}\n", ")\n", "\n", "df" ] }, { "cell_type": "code", "execution_count": 11, "id": "9488c513", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import seaborn as sns\n", "\n", "df_melt = pd.melt(df, var_name=\"function\", value_name=\"time\")\n", "_ = sns.catplot(data=df_melt, x=\"function\", y=\"time\", kind=\"swarm\")" ] }, { "cell_type": "markdown", "id": "83eab582", "metadata": {}, "source": [ "We see that the Numba version is in fact faster (roughly 1.1 s per run versus about 1.5 s), and the other two are about the same. It isn't dramatically faster though,\n", "so we might want to run a profiler on the extracted function to see where most of the time is spent:\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "06d7777a", "metadata": {}, "outputs": [], "source": [ "%load_ext line_profiler" ] }, { "cell_type": "code", "execution_count": 13, "id": "f88942d6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Timer unit: 1e-09 s\n", "\n", "Total time: 1.41607 s\n", "File: /var/folders/xn/05ktz3056kqd9n8frgd6236h0000gn/T/egglog-9e61d62c-d17d-495b-b8db-f1eb3b38dcbb.py\n", "Function: __fn at line 1\n", "\n", "Line # Hits Time Per Hit % Time Line Contents\n", "==============================================================\n", " 1 def __fn(X, y):\n", " 2 1 13000.0 13000.0 0.0 assert X.dtype == np.dtype(np.float64)\n", " 3 1 2000.0 2000.0 0.0 assert X.shape == (1000000, 20,)\n", " 4 1 23813000.0 2e+07 1.7 assert np.all(np.isfinite(X))\n", " 5 1 11000.0 11000.0 0.0 assert y.dtype == np.dtype(np.int64)\n", " 6 1 14000.0 14000.0 0.0 assert y.shape == (1000000,)\n", " 7 1 23226000.0 2e+07 1.6 assert set(np.unique(y)) == set((0, 1,))\n", " 8 1 542000.0 542000.0 0.0 _0 = y == np.array(0)\n", " 9 1 488000.0 488000.0 0.0 _1 = np.sum(_0)\n", " 10 1 493000.0 493000.0 0.0 _2 = y == np.array(1)\n", " 11 1 454000.0 454000.0 0.0 _3 = np.sum(_2)\n", " 12 1 14000.0 14000.0 0.0 _4 = np.array((_1, _3,)).astype(np.dtype(np.float64))\n", " 13 1 9000.0 9000.0 0.0 _5 = _4 / np.array(1000000.0)\n", " 14 1 4000.0 4000.0 0.0 _6 = np.zeros((2, 20,), 
dtype=np.dtype(np.float64))\n", " 15 1 98376000.0 1e+08 6.9 _7 = np.sum(X[_0], axis=0)\n", " 16 1 38374000.0 4e+07 2.7 _8 = _7 / np.array(X[_0].shape[0])\n", " 17 1 6000.0 6000.0 0.0 _6[0, :] = _8\n", " 18 1 45697000.0 5e+07 3.2 _9 = np.sum(X[_2], axis=0)\n", " 19 1 35522000.0 4e+07 2.5 _10 = _9 / np.array(X[_2].shape[0])\n", " 20 1 6000.0 6000.0 0.0 _6[1, :] = _10\n", " 21 1 13000.0 13000.0 0.0 _11 = _5 @ _6\n", " 22 1 33768000.0 3e+07 2.4 _12 = X - _11\n", " 23 1 18000.0 18000.0 0.0 _13 = np.sqrt(np.array((1.0 / 999998)))\n", " 24 1 50544000.0 5e+07 3.6 _14 = X[_0] - _6[0, :]\n", " 25 1 55966000.0 6e+07 4.0 _15 = X[_2] - _6[1, :]\n", " 26 1 26138000.0 3e+07 1.8 _16 = np.concatenate((_14, _15,), axis=0)\n", " 27 1 23667000.0 2e+07 1.7 _17 = np.sum(_16, axis=0)\n", " 28 1 26000.0 26000.0 0.0 _18 = _17 / np.array(_16.shape[0])\n", " 29 1 45000.0 45000.0 0.0 _19 = np.expand_dims(_18, 0)\n", " 30 1 33604000.0 3e+07 2.4 _20 = _16 - _19\n", " 31 1 24774000.0 2e+07 1.7 _21 = np.square(_20)\n", " 32 1 21671000.0 2e+07 1.5 _22 = np.sum(_21, axis=0)\n", " 33 1 31000.0 31000.0 0.0 _23 = _22 / np.array(_21.shape[0])\n", " 34 1 4000.0 4000.0 0.0 _24 = np.sqrt(_23)\n", " 35 1 7000.0 7000.0 0.0 _25 = _24 == np.array(0)\n", " 36 1 3000.0 3000.0 0.0 _24[_25] = np.array(1.0)\n", " 37 1 32910000.0 3e+07 2.3 _26 = _16 / _24\n", " 38 1 24105000.0 2e+07 1.7 _27 = _13 * _26\n", " 39 1 814200000.0 8e+08 57.5 _28 = np.linalg.svd(_27, full_matrices=False)\n", " 40 1 23000.0 23000.0 0.0 _29 = _28[1] > np.array(0.0001)\n", " 41 1 10000.0 10000.0 0.0 _30 = _29.astype(np.dtype(np.int32))\n", " 42 1 63000.0 63000.0 0.0 _31 = np.sum(_30)\n", " 43 1 14000.0 14000.0 0.0 _32 = _28[2][:_31, :] / _24\n", " 44 1 7000.0 7000.0 0.0 _33 = _32.T / _28[1][:_31]\n", " 45 1 9000.0 9000.0 0.0 _34 = np.array(1000000) * _5\n", " 46 1 4000.0 4000.0 0.0 _35 = _34 * np.array(1.0)\n", " 47 1 3000.0 3000.0 0.0 _36 = np.sqrt(_35)\n", " 48 1 5000.0 5000.0 0.0 _37 = _6 - _11\n", " 49 1 4000.0 4000.0 0.0 _38 = _36 * 
_37.T\n", " 50 1 11000.0 11000.0 0.0 _39 = _38.T @ _33\n", " 51 1 70000.0 70000.0 0.0 _40 = np.linalg.svd(_39, full_matrices=False)\n", " 52 1 6000.0 6000.0 0.0 _41 = np.array(0.0001) * _40[1][0]\n", " 53 1 3000.0 3000.0 0.0 _42 = _40[1] > _41\n", " 54 1 4000.0 4000.0 0.0 _43 = _42.astype(np.dtype(np.int32))\n", " 55 1 18000.0 18000.0 0.0 _44 = np.sum(_43)\n", " 56 1 8000.0 8000.0 0.0 _45 = _33 @ _40[2].T[:, :_44]\n", " 57 1 7242000.0 7e+06 0.5 _46 = _12 @ _45\n", " 58 1 7000.0 7000.0 0.0 return _46[:, :1]" ] } ], "source": [ "%lprun -f fn fn(X_np, y_np)" ] }, { "cell_type": "markdown", "id": "7bf27fb6", "metadata": {}, "source": [ "We see that most of the time (about 57%) is spent in the first SVD call, which [wouldn't be improved much by Numba](https://github.com/numba/numba/issues/2423)\n", "since it will call out to LAPACK, just like NumPy. The only savings would come from the other parts of the program,\n", "which Numba can inline and compile to native code.\n", "\n", "## Conclusion\n", "\n", "To recap, in this tutorial we:\n", "\n", "1. Tried using a normal scikit-learn LDA function on some test data.\n", "2. Built up an abstract array and called it with that instead.\n", "3. Optimized it and translated it to work with Numba.\n", "4. Generated a standalone Python function from it and compiled that with Numba.\n", "5. 
Verified that this improved performance on this test data.\n", "\n", "The implementation of the Array API provided here is experimental and incomplete, but it at least shows that it is\n", "possible to build an API like this with `egglog`.\n" ] } ], "metadata": { "file_format": "mystnb", "kernelspec": { "display_name": "egglog-python", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" }, "mystnb": { "execution_mode": "off" } }, "nbformat": 4, "nbformat_minor": 5 }