_NDArray_1 = NDArray . var ( "X" ) \n",
"assume_dtype ( _NDArray_1 , DType . float64 ) \n",
"assume_shape ( _NDArray_1 , TupleInt ( Int ( 1000000 )) + TupleInt ( Int ( 20 ))) \n",
"assume_isfinite ( _NDArray_1 ) \n",
"_NDArray_2 = NDArray . var ( "y" ) \n",
"assume_dtype ( _NDArray_2 , DType . int64 ) \n",
"assume_shape ( _NDArray_2 , TupleInt ( Int ( 1000000 ))) \n",
"assume_value_one_of ( _NDArray_2 , TupleValue ( Value . int ( Int ( 0 ))) + TupleValue ( Value . int ( Int ( 1 )))) \n",
"_NDArray_3 = asarray ( reshape ( asarray ( _NDArray_2 ), TupleInt ( Int ( - 1 )))) \n",
"_NDArray_4 = astype ( unique_counts ( _NDArray_3 )[ Int ( 1 )], asarray ( _NDArray_1 ) . dtype ) / NDArray . scalar ( Value . float ( Float ( 1000000.0 ))) \n",
"_NDArray_5 = zeros ( \n",
" TupleInt ( unique_inverse ( _NDArray_3 )[ Int ( 0 )] . shape [ Int ( 0 )]) + TupleInt ( asarray ( _NDArray_1 ) . shape [ Int ( 1 )]), \n",
" OptionalDType . some ( asarray ( _NDArray_1 ) . dtype ), \n",
" OptionalDevice . some ( asarray ( _NDArray_1 ) . device ), \n",
") \n",
"_MultiAxisIndexKey_1 = MultiAxisIndexKey ( MultiAxisIndexKeyItem . slice ( Slice ())) \n",
"_IndexKey_1 = IndexKey . multi_axis ( MultiAxisIndexKey ( MultiAxisIndexKeyItem . int ( Int ( 0 ))) + _MultiAxisIndexKey_1 ) \n",
"_OptionalIntOrTuple_1 = OptionalIntOrTuple . some ( IntOrTuple . int ( Int ( 0 ))) \n",
"_NDArray_5 [ _IndexKey_1 ] = mean ( asarray ( _NDArray_1 )[ ndarray_index ( unique_inverse ( _NDArray_3 )[ Int ( 1 )] == NDArray . scalar ( Value . int ( Int ( 0 ))))], _OptionalIntOrTuple_1 ) \n",
"_IndexKey_2 = IndexKey . multi_axis ( MultiAxisIndexKey ( MultiAxisIndexKeyItem . int ( Int ( 1 ))) + _MultiAxisIndexKey_1 ) \n",
"_NDArray_5 [ _IndexKey_2 ] = mean ( asarray ( _NDArray_1 )[ ndarray_index ( unique_inverse ( _NDArray_3 )[ Int ( 1 )] == NDArray . scalar ( Value . int ( Int ( 1 ))))], _OptionalIntOrTuple_1 ) \n",
"_NDArray_6 = unique_values ( concat ( TupleNDArray ( unique_values ( asarray ( _NDArray_3 ))))) \n",
"_NDArray_7 = concat ( \n",
" TupleNDArray ( asarray ( _NDArray_1 )[ ndarray_index ( _NDArray_3 == _NDArray_6 [ IndexKey . int ( Int ( 0 ))])] - _NDArray_5 [ _IndexKey_1 ]) \n",
" + TupleNDArray ( asarray ( _NDArray_1 )[ ndarray_index ( _NDArray_3 == _NDArray_6 [ IndexKey . int ( Int ( 1 ))])] - _NDArray_5 [ _IndexKey_2 ]), \n",
" OptionalInt . some ( Int ( 0 )), \n",
") \n",
"_NDArray_8 = std ( _NDArray_7 , _OptionalIntOrTuple_1 ) \n",
"_NDArray_8 [ ndarray_index ( std ( _NDArray_7 , _OptionalIntOrTuple_1 ) == NDArray . scalar ( Value . int ( Int ( 0 ))))] = NDArray . scalar ( Value . float ( Float ( 1.0 ))) \n",
"_TupleNDArray_1 = svd ( \n",
" sqrt ( asarray ( NDArray . scalar ( Value . float ( Float ( 1.0 ) / Float . from_int ( asarray ( _NDArray_1 ) . shape [ Int ( 0 )] - _NDArray_6 . shape [ Int ( 0 )]))))) * ( _NDArray_7 / _NDArray_8 ), FALSE \n",
") \n",
"_Slice_1 = Slice ( OptionalInt . none , OptionalInt . some ( sum ( astype ( _TupleNDArray_1 [ Int ( 1 )] > NDArray . scalar ( Value . float ( Float ( 0.0001 ))), DType . int32 )) . to_value () . to_int )) \n",
"_NDArray_9 = ( _TupleNDArray_1 [ Int ( 2 )][ IndexKey . multi_axis ( MultiAxisIndexKey ( MultiAxisIndexKeyItem . slice ( _Slice_1 )) + _MultiAxisIndexKey_1 )] / _NDArray_8 ) . T / _TupleNDArray_1 [ \n",
" Int ( 1 ) \n",
"][ IndexKey . slice ( _Slice_1 )] \n",
"_TupleNDArray_2 = svd ( \n",
" ( \n",
" sqrt ( \n",
" ( NDArray . scalar ( Value . int ( asarray ( _NDArray_1 ) . shape [ Int ( 0 )])) * _NDArray_4 ) \n",
" * NDArray . scalar ( Value . float ( Float ( 1.0 ) / Float . from_int ( _NDArray_6 . shape [ Int ( 0 )] - Int ( 1 )))) \n",
" ) \n",
" * ( _NDArray_5 - ( _NDArray_4 @ _NDArray_5 )) . T \n",
" ) . T \n",
" @ _NDArray_9 , \n",
" FALSE , \n",
") \n",
"( \n",
" ( asarray ( _NDArray_1 ) - ( _NDArray_4 @ _NDArray_5 )) \n",
" @ ( \n",
" _NDArray_9 \n",
" @ _TupleNDArray_2 [ Int ( 2 )] . T [ \n",
" IndexKey . multi_axis ( \n",
" _MultiAxisIndexKey_1 \n",
" + MultiAxisIndexKey ( \n",
" MultiAxisIndexKeyItem . slice ( \n",
" Slice ( \n",
" OptionalInt . none , \n",
" OptionalInt . some ( \n",
" sum ( astype ( _TupleNDArray_2 [ Int ( 1 )] > ( NDArray . scalar ( Value . float ( Float ( 0.0001 ))) * _TupleNDArray_2 [ Int ( 1 )][ IndexKey . int ( Int ( 0 ))]), DType . int32 )) \n",
" . to_value () \n",
" . to_int \n",
" ), \n",
" ) \n",
" ) \n",
" ) \n",
" ) \n",
" ] \n",
" ) \n",
")[ IndexKey . multi_axis ( _MultiAxisIndexKey_1 + MultiAxisIndexKey ( MultiAxisIndexKeyItem . slice ( Slice ( OptionalInt . none , OptionalInt . some ( _NDArray_6 . shape [ Int ( 0 )] - Int ( 1 ))))))] \n",
" \n"
],
"text/latex": [
"\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}1} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{var}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{X}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}dtype}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}shape}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{,} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1000000}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{20}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}isfinite}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{var}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{y}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}dtype}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int64}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}shape}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1000000}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}value\\PYZus{}one\\PYZus{}of}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{TupleValue}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleValue}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{=} \\PY{n}{asarray}\\PY{p}{(}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{)}\\PY{p}{,} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}4} \\PY{o}{=} \\PY{n}{astype}\\PY{p}{(}\\PY{n}{unique\\PYZus{}counts}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{)}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{,} \\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{o}{.}\\PY{n}{dtype}\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1000000.0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}5} \\PY{o}{=} \\PY{n}{zeros}\\PY{p}{(}\n",
" \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{unique\\PYZus{}inverse}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{)}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n}{OptionalDType}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{o}{.}\\PY{n}{dtype}\\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n}{OptionalDevice}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{o}{.}\\PY{n}{device}\\PY{p}{)}\\PY{p}{,}\n",
"\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1} \\PY{o}{=} \\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{Slice}\\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1} \\PY{o}{=} \\PY{n}{OptionalIntOrTuple}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{IntOrTuple}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1}\\PY{p}{]} \\PY{o}{=} \\PY{n}{mean}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{unique\\PYZus{}inverse}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{)}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2}\\PY{p}{]} \\PY{o}{=} \\PY{n}{mean}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{unique\\PYZus{}inverse}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{)}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}6} \\PY{o}{=} \\PY{n}{unique\\PYZus{}values}\\PY{p}{(}\\PY{n}{concat}\\PY{p}{(}\\PY{n}{TupleNDArray}\\PY{p}{(}\\PY{n}{unique\\PYZus{}values}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}7} \\PY{o}{=} \\PY{n}{concat}\\PY{p}{(}\n",
" \\PY{n}{TupleNDArray}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{==} \\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1}\\PY{p}{]}\\PY{p}{)}\n",
" \\PY{o}{+} \\PY{n}{TupleNDArray}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{==} \\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2}\\PY{p}{]}\\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,}\n",
"\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}8} \\PY{o}{=} \\PY{n}{std}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}8}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{std}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1} \\PY{o}{=} \\PY{n}{svd}\\PY{p}{(}\n",
" \\PY{n}{sqrt}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{)} \\PY{o}{/} \\PY{n}{Float}\\PY{o}{.}\\PY{n}{from\\PYZus{}int}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7} \\PY{o}{/} \\PY{n}{\\PYZus{}NDArray\\PYZus{}8}\\PY{p}{)}\\PY{p}{,} \\PY{n}{FALSE}\n",
"\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}Slice\\PYZus{}1} \\PY{o}{=} \\PY{n}{Slice}\\PY{p}{(}\\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,} \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{astype}\\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZgt{}} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{0.0001}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int32}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}int}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}9} \\PY{o}{=} \\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{\\PYZus{}Slice\\PYZus{}1}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{/} \\PY{n}{\\PYZus{}NDArray\\PYZus{}8}\\PY{p}{)}\\PY{o}{.}\\PY{n}{T} \\PY{o}{/} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\n",
" \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\n",
"\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{\\PYZus{}Slice\\PYZus{}1}\\PY{p}{)}\\PY{p}{]}\n",
"\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2} \\PY{o}{=} \\PY{n}{svd}\\PY{p}{(}\n",
" \\PY{p}{(}\n",
" \\PY{n}{sqrt}\\PY{p}{(}\n",
" \\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{)}\n",
" \\PY{o}{*} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{)} \\PY{o}{/} \\PY{n}{Float}\\PY{o}{.}\\PY{n}{from\\PYZus{}int}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
" \\PY{p}{)}\n",
" \\PY{o}{*} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5} \\PY{o}{\\PYZhy{}} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}4} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{T}\n",
" \\PY{p}{)}\\PY{o}{.}\\PY{n}{T}\n",
" \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}9}\\PY{p}{,}\n",
" \\PY{n}{FALSE}\\PY{p}{,}\n",
"\\PY{p}{)}\n",
"\\PY{p}{(}\n",
" \\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)} \\PY{o}{\\PYZhy{}} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}4} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{)}\\PY{p}{)}\n",
" \\PY{o}{@} \\PY{p}{(}\n",
" \\PY{n}{\\PYZus{}NDArray\\PYZus{}9}\n",
" \\PY{o}{@} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{]}\\PY{o}{.}\\PY{n}{T}\\PY{p}{[}\n",
" \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\n",
" \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\n",
" \\PY{o}{+} \\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\n",
" \\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\n",
" \\PY{n}{Slice}\\PY{p}{(}\n",
" \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,}\n",
" \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\n",
" \\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{astype}\\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZgt{}} \\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{0.0001}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int32}\\PY{p}{)}\\PY{p}{)}\n",
" \\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\n",
" \\PY{o}{.}\\PY{n}{to\\PYZus{}int}\n",
" \\PY{p}{)}\\PY{p}{,}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\n",
" \\PY{p}{]}\n",
" \\PY{p}{)}\n",
"\\PY{p}{)}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1} \\PY{o}{+} \\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{Slice}\\PY{p}{(}\\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,} \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\n",
"\\end{Verbatim}\n"
],
"text/plain": [
"_NDArray_1 = NDArray.var(\"X\")\n",
"assume_dtype(_NDArray_1, DType.float64)\n",
"assume_shape(_NDArray_1, TupleInt(Int(1000000)) + TupleInt(Int(20)))\n",
"assume_isfinite(_NDArray_1)\n",
"_NDArray_2 = NDArray.var(\"y\")\n",
"assume_dtype(_NDArray_2, DType.int64)\n",
"assume_shape(_NDArray_2, TupleInt(Int(1000000)))\n",
"assume_value_one_of(_NDArray_2, TupleValue(Value.int(Int(0))) + TupleValue(Value.int(Int(1))))\n",
"_NDArray_3 = asarray(reshape(asarray(_NDArray_2), TupleInt(Int(-1))))\n",
"_NDArray_4 = astype(unique_counts(_NDArray_3)[Int(1)], asarray(_NDArray_1).dtype) / NDArray.scalar(Value.float(Float(1000000.0)))\n",
"_NDArray_5 = zeros(\n",
" TupleInt(unique_inverse(_NDArray_3)[Int(0)].shape[Int(0)]) + TupleInt(asarray(_NDArray_1).shape[Int(1)]),\n",
" OptionalDType.some(asarray(_NDArray_1).dtype),\n",
" OptionalDevice.some(asarray(_NDArray_1).device),\n",
")\n",
"_MultiAxisIndexKey_1 = MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice()))\n",
"_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(0))) + _MultiAxisIndexKey_1)\n",
"_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0)))\n",
"_NDArray_5[_IndexKey_1] = mean(asarray(_NDArray_1)[ndarray_index(unique_inverse(_NDArray_3)[Int(1)] == NDArray.scalar(Value.int(Int(0))))], _OptionalIntOrTuple_1)\n",
"_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(1))) + _MultiAxisIndexKey_1)\n",
"_NDArray_5[_IndexKey_2] = mean(asarray(_NDArray_1)[ndarray_index(unique_inverse(_NDArray_3)[Int(1)] == NDArray.scalar(Value.int(Int(1))))], _OptionalIntOrTuple_1)\n",
"_NDArray_6 = unique_values(concat(TupleNDArray(unique_values(asarray(_NDArray_3)))))\n",
"_NDArray_7 = concat(\n",
" TupleNDArray(asarray(_NDArray_1)[ndarray_index(_NDArray_3 == _NDArray_6[IndexKey.int(Int(0))])] - _NDArray_5[_IndexKey_1])\n",
" + TupleNDArray(asarray(_NDArray_1)[ndarray_index(_NDArray_3 == _NDArray_6[IndexKey.int(Int(1))])] - _NDArray_5[_IndexKey_2]),\n",
" OptionalInt.some(Int(0)),\n",
")\n",
"_NDArray_8 = std(_NDArray_7, _OptionalIntOrTuple_1)\n",
"_NDArray_8[ndarray_index(std(_NDArray_7, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))\n",
"_TupleNDArray_1 = svd(\n",
" sqrt(asarray(NDArray.scalar(Value.float(Float(1.0) / Float.from_int(asarray(_NDArray_1).shape[Int(0)] - _NDArray_6.shape[Int(0)]))))) * (_NDArray_7 / _NDArray_8), FALSE\n",
")\n",
"_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))\n",
"_NDArray_9 = (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(_Slice_1)) + _MultiAxisIndexKey_1)] / _NDArray_8).T / _TupleNDArray_1[\n",
" Int(1)\n",
"][IndexKey.slice(_Slice_1)]\n",
"_TupleNDArray_2 = svd(\n",
" (\n",
" sqrt(\n",
" (NDArray.scalar(Value.int(asarray(_NDArray_1).shape[Int(0)])) * _NDArray_4)\n",
" * NDArray.scalar(Value.float(Float(1.0) / Float.from_int(_NDArray_6.shape[Int(0)] - Int(1))))\n",
" )\n",
" * (_NDArray_5 - (_NDArray_4 @ _NDArray_5)).T\n",
" ).T\n",
" @ _NDArray_9,\n",
" FALSE,\n",
")\n",
"(\n",
" (asarray(_NDArray_1) - (_NDArray_4 @ _NDArray_5))\n",
" @ (\n",
" _NDArray_9\n",
" @ _TupleNDArray_2[Int(2)].T[\n",
" IndexKey.multi_axis(\n",
" _MultiAxisIndexKey_1\n",
" + MultiAxisIndexKey(\n",
" MultiAxisIndexKeyItem.slice(\n",
" Slice(\n",
" OptionalInt.none,\n",
" OptionalInt.some(\n",
" sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32))\n",
" .to_value()\n",
" .to_int\n",
" ),\n",
" )\n",
" )\n",
" )\n",
" )\n",
" ]\n",
" )\n",
")[IndexKey.multi_axis(_MultiAxisIndexKey_1 + MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(_NDArray_6.shape[Int(0)] - Int(1))))))]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from egglog import EGraph\n",
"from egglog.exp.array_api import array_api_module\n",
"\n",
"with EGraph([array_api_module]) as egraph:\n",
" X_r2 = run_lda(X_arr, y_arr)\n",
" egraph.display(n_inline_leaves=3, split_primitive_outputs=True)\n",
"X_r2"
]
},
{
"cell_type": "markdown",
"id": "580da17b",
"metadata": {},
"source": [
"We now have extracted out a program which is semantically equivalent to the original call! One thing you might notice\n",
"is that the expression has more types than customary NumPy code. Every object is lifted into a strongly typed `egglog`\n",
"class. This is so that when we run optimizations, we know the types of all the objects. It still is compatible with\n",
"normal Python objects, but they are [converted](type-promotion) when they are passed as argument.\n",
"\n",
"## Optimizing our result\n",
"\n",
"Now that we have the an expression, we can run our rewrite rules to \"optimize\" it, extracting out the lowest cost\n",
"(smallest) expression afterword:\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4d3cd4f3",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"_NDArray_1 = NDArray . var ( "X" ) \n",
"assume_dtype ( _NDArray_1 , DType . float64 ) \n",
"assume_shape ( _NDArray_1 , TupleInt ( Int ( 1000000 )) + TupleInt ( Int ( 20 ))) \n",
"assume_isfinite ( _NDArray_1 ) \n",
"_NDArray_2 = NDArray . var ( "y" ) \n",
"assume_dtype ( _NDArray_2 , DType . int64 ) \n",
"assume_shape ( _NDArray_2 , TupleInt ( Int ( 1000000 ))) \n",
"assume_value_one_of ( _NDArray_2 , TupleValue ( Value . int ( Int ( 0 ))) + TupleValue ( Value . int ( Int ( 1 )))) \n",
"_NDArray_3 = astype ( unique_counts ( _NDArray_2 )[ Int ( 1 )], DType . float64 ) / NDArray . scalar ( Value . float ( Float ( 1000000.0 ))) \n",
"_NDArray_4 = zeros ( TupleInt ( Int ( 2 )) + TupleInt ( Int ( 20 )), OptionalDType . some ( DType . float64 ), OptionalDevice . some ( _NDArray_1 . device )) \n",
"_MultiAxisIndexKey_1 = MultiAxisIndexKey ( MultiAxisIndexKeyItem . slice ( Slice ())) \n",
"_IndexKey_1 = IndexKey . multi_axis ( MultiAxisIndexKey ( MultiAxisIndexKeyItem . int ( Int ( 0 ))) + _MultiAxisIndexKey_1 ) \n",
"_OptionalIntOrTuple_1 = OptionalIntOrTuple . some ( IntOrTuple . int ( Int ( 0 ))) \n",
"_NDArray_4 [ _IndexKey_1 ] = mean ( _NDArray_1 [ ndarray_index ( unique_inverse ( _NDArray_2 )[ Int ( 1 )] == NDArray . scalar ( Value . int ( Int ( 0 ))))], _OptionalIntOrTuple_1 ) \n",
"_IndexKey_2 = IndexKey . multi_axis ( MultiAxisIndexKey ( MultiAxisIndexKeyItem . int ( Int ( 1 ))) + _MultiAxisIndexKey_1 ) \n",
"_NDArray_4 [ _IndexKey_2 ] = mean ( _NDArray_1 [ ndarray_index ( unique_inverse ( _NDArray_2 )[ Int ( 1 )] == NDArray . scalar ( Value . int ( Int ( 1 ))))], _OptionalIntOrTuple_1 ) \n",
"_NDArray_5 = concat ( \n",
" TupleNDArray ( _NDArray_1 [ ndarray_index ( _NDArray_2 == NDArray . scalar ( Value . int ( Int ( 0 ))))] - _NDArray_4 [ _IndexKey_1 ]) \n",
" + TupleNDArray ( _NDArray_1 [ ndarray_index ( _NDArray_2 == NDArray . scalar ( Value . int ( Int ( 1 ))))] - _NDArray_4 [ _IndexKey_2 ]), \n",
" OptionalInt . some ( Int ( 0 )), \n",
") \n",
"_NDArray_6 = std ( _NDArray_5 , _OptionalIntOrTuple_1 ) \n",
"_NDArray_6 [ ndarray_index ( std ( _NDArray_5 , _OptionalIntOrTuple_1 ) == NDArray . scalar ( Value . int ( Int ( 0 ))))] = NDArray . scalar ( Value . float ( Float ( 1.0 ))) \n",
"_TupleNDArray_1 = svd ( sqrt ( NDArray . scalar ( Value . float ( Float ( 1.0 ) / Float . from_int ( Int ( 999998 ))))) * ( _NDArray_5 / _NDArray_6 ), FALSE ) \n",
"_Slice_1 = Slice ( OptionalInt . none , OptionalInt . some ( sum ( astype ( _TupleNDArray_1 [ Int ( 1 )] > NDArray . scalar ( Value . float ( Float ( 0.0001 ))), DType . int32 )) . to_value () . to_int )) \n",
"_NDArray_7 = ( _TupleNDArray_1 [ Int ( 2 )][ IndexKey . multi_axis ( MultiAxisIndexKey ( MultiAxisIndexKeyItem . slice ( _Slice_1 )) + _MultiAxisIndexKey_1 )] / _NDArray_6 ) . T / _TupleNDArray_1 [ \n",
" Int ( 1 ) \n",
"][ IndexKey . slice ( _Slice_1 )] \n",
"_TupleNDArray_2 = svd ( \n",
" ( sqrt (( NDArray . scalar ( Value . int ( Int ( 1000000 ))) * _NDArray_3 ) * NDArray . scalar ( Value . float ( Float ( 1.0 )))) * ( _NDArray_4 - ( _NDArray_3 @ _NDArray_4 )) . T ) . T @ _NDArray_7 , FALSE \n",
") \n",
"( \n",
" ( _NDArray_1 - ( _NDArray_3 @ _NDArray_4 )) \n",
" @ ( \n",
" _NDArray_7 \n",
" @ _TupleNDArray_2 [ Int ( 2 )] . T [ \n",
" IndexKey . multi_axis ( \n",
" _MultiAxisIndexKey_1 \n",
" + MultiAxisIndexKey ( \n",
" MultiAxisIndexKeyItem . slice ( \n",
" Slice ( \n",
" OptionalInt . none , \n",
" OptionalInt . some ( \n",
" sum ( astype ( _TupleNDArray_2 [ Int ( 1 )] > ( NDArray . scalar ( Value . float ( Float ( 0.0001 ))) * _TupleNDArray_2 [ Int ( 1 )][ IndexKey . int ( Int ( 0 ))]), DType . int32 )) \n",
" . to_value () \n",
" . to_int \n",
" ), \n",
" ) \n",
" ) \n",
" ) \n",
" ) \n",
" ] \n",
" ) \n",
")[ IndexKey . multi_axis ( _MultiAxisIndexKey_1 + MultiAxisIndexKey ( MultiAxisIndexKeyItem . slice ( Slice ( OptionalInt . none , OptionalInt . some ( Int ( 1 ))))))] \n",
" \n"
],
"text/latex": [
"\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}1} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{var}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{X}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}dtype}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}shape}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{,} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1000000}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{20}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}isfinite}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{var}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{y}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}dtype}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int64}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}shape}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1000000}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}value\\PYZus{}one\\PYZus{}of}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{TupleValue}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleValue}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{=} \\PY{n}{astype}\\PY{p}{(}\\PY{n}{unique\\PYZus{}counts}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{)}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1000000.0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}4} \\PY{o}{=} \\PY{n}{zeros}\\PY{p}{(}\\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{20}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{OptionalDType}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{)}\\PY{p}{,} \\PY{n}{OptionalDevice}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{o}{.}\\PY{n}{device}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1} \\PY{o}{=} \\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{Slice}\\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1} \\PY{o}{=} \\PY{n}{OptionalIntOrTuple}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{IntOrTuple}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1}\\PY{p}{]} \\PY{o}{=} \\PY{n}{mean}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{unique\\PYZus{}inverse}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{)}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2}\\PY{p}{]} \\PY{o}{=} \\PY{n}{mean}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{unique\\PYZus{}inverse}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{)}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}5} \\PY{o}{=} \\PY{n}{concat}\\PY{p}{(}\n",
" \\PY{n}{TupleNDArray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1}\\PY{p}{]}\\PY{p}{)}\n",
" \\PY{o}{+} \\PY{n}{TupleNDArray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2}\\PY{p}{]}\\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,}\n",
"\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}6} \\PY{o}{=} \\PY{n}{std}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{std}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1} \\PY{o}{=} \\PY{n}{svd}\\PY{p}{(}\\PY{n}{sqrt}\\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{)} \\PY{o}{/} \\PY{n}{Float}\\PY{o}{.}\\PY{n}{from\\PYZus{}int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{999998}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5} \\PY{o}{/} \\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{p}{)}\\PY{p}{,} \\PY{n}{FALSE}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}Slice\\PYZus{}1} \\PY{o}{=} \\PY{n}{Slice}\\PY{p}{(}\\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,} \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{astype}\\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZgt{}} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{0.0001}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int32}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}int}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}7} \\PY{o}{=} \\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{\\PYZus{}Slice\\PYZus{}1}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{/} \\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{p}{)}\\PY{o}{.}\\PY{n}{T} \\PY{o}{/} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\n",
" \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\n",
"\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{\\PYZus{}Slice\\PYZus{}1}\\PY{p}{)}\\PY{p}{]}\n",
"\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2} \\PY{o}{=} \\PY{n}{svd}\\PY{p}{(}\n",
" \\PY{p}{(}\\PY{n}{sqrt}\\PY{p}{(}\\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1000000}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{)} \\PY{o}{*} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}4} \\PY{o}{\\PYZhy{}} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{T}\\PY{p}{)}\\PY{o}{.}\\PY{n}{T} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}7}\\PY{p}{,} \\PY{n}{FALSE}\n",
"\\PY{p}{)}\n",
"\\PY{p}{(}\n",
" \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1} \\PY{o}{\\PYZhy{}} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{)}\\PY{p}{)}\n",
" \\PY{o}{@} \\PY{p}{(}\n",
" \\PY{n}{\\PYZus{}NDArray\\PYZus{}7}\n",
" \\PY{o}{@} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{]}\\PY{o}{.}\\PY{n}{T}\\PY{p}{[}\n",
" \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\n",
" \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\n",
" \\PY{o}{+} \\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\n",
" \\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\n",
" \\PY{n}{Slice}\\PY{p}{(}\n",
" \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,}\n",
" \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\n",
" \\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{astype}\\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZgt{}} \\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{0.0001}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int32}\\PY{p}{)}\\PY{p}{)}\n",
" \\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\n",
" \\PY{o}{.}\\PY{n}{to\\PYZus{}int}\n",
" \\PY{p}{)}\\PY{p}{,}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\n",
" \\PY{p}{]}\n",
" \\PY{p}{)}\n",
"\\PY{p}{)}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1} \\PY{o}{+} \\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{Slice}\\PY{p}{(}\\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,} \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\n",
"\\end{Verbatim}\n"
],
"text/plain": [
"_NDArray_1 = NDArray.var(\"X\")\n",
"assume_dtype(_NDArray_1, DType.float64)\n",
"assume_shape(_NDArray_1, TupleInt(Int(1000000)) + TupleInt(Int(20)))\n",
"assume_isfinite(_NDArray_1)\n",
"_NDArray_2 = NDArray.var(\"y\")\n",
"assume_dtype(_NDArray_2, DType.int64)\n",
"assume_shape(_NDArray_2, TupleInt(Int(1000000)))\n",
"assume_value_one_of(_NDArray_2, TupleValue(Value.int(Int(0))) + TupleValue(Value.int(Int(1))))\n",
"_NDArray_3 = astype(unique_counts(_NDArray_2)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float(1000000.0)))\n",
"_NDArray_4 = zeros(TupleInt(Int(2)) + TupleInt(Int(20)), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device))\n",
"_MultiAxisIndexKey_1 = MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice()))\n",
"_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(0))) + _MultiAxisIndexKey_1)\n",
"_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0)))\n",
"_NDArray_4[_IndexKey_1] = mean(_NDArray_1[ndarray_index(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(0))))], _OptionalIntOrTuple_1)\n",
"_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(1))) + _MultiAxisIndexKey_1)\n",
"_NDArray_4[_IndexKey_2] = mean(_NDArray_1[ndarray_index(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(1))))], _OptionalIntOrTuple_1)\n",
"_NDArray_5 = concat(\n",
" TupleNDArray(_NDArray_1[ndarray_index(_NDArray_2 == NDArray.scalar(Value.int(Int(0))))] - _NDArray_4[_IndexKey_1])\n",
" + TupleNDArray(_NDArray_1[ndarray_index(_NDArray_2 == NDArray.scalar(Value.int(Int(1))))] - _NDArray_4[_IndexKey_2]),\n",
" OptionalInt.some(Int(0)),\n",
")\n",
"_NDArray_6 = std(_NDArray_5, _OptionalIntOrTuple_1)\n",
"_NDArray_6[ndarray_index(std(_NDArray_5, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))\n",
"_TupleNDArray_1 = svd(sqrt(NDArray.scalar(Value.float(Float(1.0) / Float.from_int(Int(999998))))) * (_NDArray_5 / _NDArray_6), FALSE)\n",
"_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))\n",
"_NDArray_7 = (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(_Slice_1)) + _MultiAxisIndexKey_1)] / _NDArray_6).T / _TupleNDArray_1[\n",
" Int(1)\n",
"][IndexKey.slice(_Slice_1)]\n",
"_TupleNDArray_2 = svd(\n",
" (sqrt((NDArray.scalar(Value.int(Int(1000000))) * _NDArray_3) * NDArray.scalar(Value.float(Float(1.0)))) * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T).T @ _NDArray_7, FALSE\n",
")\n",
"(\n",
" (_NDArray_1 - (_NDArray_3 @ _NDArray_4))\n",
" @ (\n",
" _NDArray_7\n",
" @ _TupleNDArray_2[Int(2)].T[\n",
" IndexKey.multi_axis(\n",
" _MultiAxisIndexKey_1\n",
" + MultiAxisIndexKey(\n",
" MultiAxisIndexKeyItem.slice(\n",
" Slice(\n",
" OptionalInt.none,\n",
" OptionalInt.some(\n",
" sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32))\n",
" .to_value()\n",
" .to_int\n",
" ),\n",
" )\n",
" )\n",
" )\n",
" )\n",
" ]\n",
" )\n",
")[IndexKey.multi_axis(_MultiAxisIndexKey_1 + MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(Int(1))))))]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"egraph = EGraph([array_api_module])\n",
"egraph.register(X_r2)\n",
"egraph.run(10000)\n",
"X_r2_optimized = egraph.extract(X_r2)\n",
"X_r2_optimized"
]
},
{
"cell_type": "markdown",
"id": "30ea4ea4",
"metadata": {},
"source": [
"We see that, for example, expressions that referenced the shape of our input arrays have been resolved to their\n",
"concrete values: the shape passed to `zeros` is now simply `TupleInt(Int(2)) + TupleInt(Int(20))`.\n",
"\n",
"We can also take a look at the e-graph itself. Even though it's quite large, we can see that equivalent\n",
"expressions show up in the same group, or \"e-class\":\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6417b9e5",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"egraph.display(n_inline_leaves=3, split_primitive_outputs=True)"
]
},
{
"cell_type": "markdown",
"id": "21e4ee3a",
"metadata": {},
"source": [
"## Translating for Numba\n",
"\n",
"We are getting closer to a form we could translate back to Numba, but we have to make a few changes. Numba doesn't\n",
"support the `axis` keyword for `mean` or `std`, but it does support it for `sum`, so we have to rewrite every `mean`\n",
"and `std` call in terms of `sum`, with a rule like this (defined in [`egglog.exp.array_api_numba`](https://github.com/egraphs-good/egglog-python/blob/main/python/egglog/exp/array_api_numba.py)):\n",
"\n",
"```python\n",
"axis = OptionalIntOrTuple.some(IntOrTuple.int(i))\n",
"rewrite(std(x, axis)).to(sqrt(mean(square(abs(x - mean(x, axis, keepdims=TRUE))), axis)))\n",
"```\n",
"\n",
"We can run those additional rewrites now to get a new extracted version:\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9e79f88e",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"_NDArray_1 = NDArray . var ( "X" ) \n",
"assume_dtype ( _NDArray_1 , DType . float64 ) \n",
"assume_shape ( _NDArray_1 , TupleInt ( Int ( 1000000 )) + TupleInt ( Int ( 20 ))) \n",
"assume_isfinite ( _NDArray_1 ) \n",
"_NDArray_2 = NDArray . var ( "y" ) \n",
"assume_dtype ( _NDArray_2 , DType . int64 ) \n",
"assume_shape ( _NDArray_2 , TupleInt ( Int ( 1000000 ))) \n",
"assume_value_one_of ( _NDArray_2 , TupleValue ( Value . int ( Int ( 0 ))) + TupleValue ( Value . int ( Int ( 1 )))) \n",
"_NDArray_3 = astype ( \n",
" NDArray . vector ( TupleValue ( sum ( _NDArray_2 == NDArray . scalar ( Value . int ( Int ( 0 )))) . to_value ()) + TupleValue ( sum ( _NDArray_2 == NDArray . scalar ( Value . int ( Int ( 1 )))) . to_value ())), \n",
" DType . float64 , \n",
") / NDArray . scalar ( Value . float ( Float ( 1000000.0 ))) \n",
"_NDArray_4 = zeros ( TupleInt ( Int ( 2 )) + TupleInt ( Int ( 20 )), OptionalDType . some ( DType . float64 ), OptionalDevice . some ( _NDArray_1 . device )) \n",
"_MultiAxisIndexKey_1 = MultiAxisIndexKey ( MultiAxisIndexKeyItem . slice ( Slice ())) \n",
"_IndexKey_1 = IndexKey . multi_axis ( MultiAxisIndexKey ( MultiAxisIndexKeyItem . int ( Int ( 0 ))) + _MultiAxisIndexKey_1 ) \n",
"_NDArray_5 = _NDArray_1 [ ndarray_index ( _NDArray_2 == NDArray . scalar ( Value . int ( Int ( 0 ))))] \n",
"_OptionalIntOrTuple_1 = OptionalIntOrTuple . some ( IntOrTuple . int ( Int ( 0 ))) \n",
"_NDArray_4 [ _IndexKey_1 ] = sum ( _NDArray_5 , _OptionalIntOrTuple_1 ) / NDArray . scalar ( Value . int ( _NDArray_5 . shape [ Int ( 0 )])) \n",
"_IndexKey_2 = IndexKey . multi_axis ( MultiAxisIndexKey ( MultiAxisIndexKeyItem . int ( Int ( 1 ))) + _MultiAxisIndexKey_1 ) \n",
"_NDArray_6 = _NDArray_1 [ ndarray_index ( _NDArray_2 == NDArray . scalar ( Value . int ( Int ( 1 ))))] \n",
"_NDArray_4 [ _IndexKey_2 ] = sum ( _NDArray_6 , _OptionalIntOrTuple_1 ) / NDArray . scalar ( Value . int ( _NDArray_6 . shape [ Int ( 0 )])) \n",
"_NDArray_7 = concat ( TupleNDArray ( _NDArray_5 - _NDArray_4 [ _IndexKey_1 ]) + TupleNDArray ( _NDArray_6 - _NDArray_4 [ _IndexKey_2 ]), OptionalInt . some ( Int ( 0 ))) \n",
"_NDArray_8 = square ( _NDArray_7 - expand_dims ( sum ( _NDArray_7 , _OptionalIntOrTuple_1 ) / NDArray . scalar ( Value . int ( _NDArray_7 . shape [ Int ( 0 )])))) \n",
"_NDArray_9 = sqrt ( sum ( _NDArray_8 , _OptionalIntOrTuple_1 ) / NDArray . scalar ( Value . int ( _NDArray_8 . shape [ Int ( 0 )]))) \n",
"_NDArray_10 = copy ( _NDArray_9 ) \n",
"_NDArray_10 [ ndarray_index ( _NDArray_9 == NDArray . scalar ( Value . int ( Int ( 0 ))))] = NDArray . scalar ( Value . float ( Float ( 1.0 ))) \n",
"_TupleNDArray_1 = svd ( sqrt ( NDArray . scalar ( Value . float ( Float ( 1.0 ) / Float . from_int ( Int ( 999998 ))))) * ( _NDArray_7 / _NDArray_10 ), FALSE ) \n",
"_Slice_1 = Slice ( OptionalInt . none , OptionalInt . some ( sum ( astype ( _TupleNDArray_1 [ Int ( 1 )] > NDArray . scalar ( Value . float ( Float ( 0.0001 ))), DType . int32 )) . to_value () . to_int )) \n",
"_NDArray_11 = ( _TupleNDArray_1 [ Int ( 2 )][ IndexKey . multi_axis ( MultiAxisIndexKey ( MultiAxisIndexKeyItem . slice ( _Slice_1 )) + _MultiAxisIndexKey_1 )] / _NDArray_10 ) . T / _TupleNDArray_1 [ \n",
" Int ( 1 ) \n",
"][ IndexKey . slice ( _Slice_1 )] \n",
"_TupleNDArray_2 = svd ( \n",
" ( sqrt (( NDArray . scalar ( Value . int ( Int ( 1000000 ))) * _NDArray_3 ) * NDArray . scalar ( Value . float ( Float ( 1.0 )))) * ( _NDArray_4 - ( _NDArray_3 @ _NDArray_4 )) . T ) . T @ _NDArray_11 , FALSE \n",
") \n",
"( \n",
" ( _NDArray_1 - ( _NDArray_3 @ _NDArray_4 )) \n",
" @ ( \n",
" _NDArray_11 \n",
" @ _TupleNDArray_2 [ Int ( 2 )] . T [ \n",
" IndexKey . multi_axis ( \n",
" _MultiAxisIndexKey_1 \n",
" + MultiAxisIndexKey ( \n",
" MultiAxisIndexKeyItem . slice ( \n",
" Slice ( \n",
" OptionalInt . none , \n",
" OptionalInt . some ( \n",
" sum ( astype ( _TupleNDArray_2 [ Int ( 1 )] > ( NDArray . scalar ( Value . float ( Float ( 0.0001 ))) * _TupleNDArray_2 [ Int ( 1 )][ IndexKey . int ( Int ( 0 ))]), DType . int32 )) \n",
" . to_value () \n",
" . to_int \n",
" ), \n",
" ) \n",
" ) \n",
" ) \n",
" ) \n",
" ] \n",
" ) \n",
")[ IndexKey . multi_axis ( _MultiAxisIndexKey_1 + MultiAxisIndexKey ( MultiAxisIndexKeyItem . slice ( Slice ( OptionalInt . none , OptionalInt . some ( Int ( 1 ))))))] \n",
" \n"
],
"text/latex": [
"\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}1} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{var}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{X}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}dtype}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}shape}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{,} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1000000}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{20}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}isfinite}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{var}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{y}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}dtype}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int64}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}shape}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1000000}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}value\\PYZus{}one\\PYZus{}of}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{TupleValue}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleValue}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{=} \\PY{n}{astype}\\PY{p}{(}\n",
" \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{vector}\\PY{p}{(}\\PY{n}{TupleValue}\\PY{p}{(}\\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleValue}\\PY{p}{(}\\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{,}\n",
"\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1000000.0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}4} \\PY{o}{=} \\PY{n}{zeros}\\PY{p}{(}\\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleInt}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{20}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{OptionalDType}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{)}\\PY{p}{,} \\PY{n}{OptionalDevice}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{o}{.}\\PY{n}{device}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1} \\PY{o}{=} \\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{Slice}\\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}5} \\PY{o}{=} \\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\n",
"\\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1} \\PY{o}{=} \\PY{n}{OptionalIntOrTuple}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{IntOrTuple}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1}\\PY{p}{]} \\PY{o}{=} \\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}6} \\PY{o}{=} \\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2}\\PY{p}{]} \\PY{o}{=} \\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}7} \\PY{o}{=} \\PY{n}{concat}\\PY{p}{(}\\PY{n}{TupleNDArray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1}\\PY{p}{]}\\PY{p}{)} \\PY{o}{+} \\PY{n}{TupleNDArray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}6} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2}\\PY{p}{]}\\PY{p}{)}\\PY{p}{,} \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}8} \\PY{o}{=} \\PY{n}{square}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7} \\PY{o}{\\PYZhy{}} \\PY{n}{expand\\PYZus{}dims}\\PY{p}{(}\\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}9} \\PY{o}{=} \\PY{n}{sqrt}\\PY{p}{(}\\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}8}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}8}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}10} \\PY{o}{=} \\PY{n}{copy}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}9}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}10}\\PY{p}{[}\\PY{n}{ndarray\\PYZus{}index}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}9} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1} \\PY{o}{=} \\PY{n}{svd}\\PY{p}{(}\\PY{n}{sqrt}\\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{)} \\PY{o}{/} \\PY{n}{Float}\\PY{o}{.}\\PY{n}{from\\PYZus{}int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{999998}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7} \\PY{o}{/} \\PY{n}{\\PYZus{}NDArray\\PYZus{}10}\\PY{p}{)}\\PY{p}{,} \\PY{n}{FALSE}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}Slice\\PYZus{}1} \\PY{o}{=} \\PY{n}{Slice}\\PY{p}{(}\\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,} \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{astype}\\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZgt{}} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{0.0001}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int32}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}int}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}11} \\PY{o}{=} \\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{\\PYZus{}Slice\\PYZus{}1}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{/} \\PY{n}{\\PYZus{}NDArray\\PYZus{}10}\\PY{p}{)}\\PY{o}{.}\\PY{n}{T} \\PY{o}{/} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\n",
" \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\n",
"\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{\\PYZus{}Slice\\PYZus{}1}\\PY{p}{)}\\PY{p}{]}\n",
"\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2} \\PY{o}{=} \\PY{n}{svd}\\PY{p}{(}\n",
" \\PY{p}{(}\\PY{n}{sqrt}\\PY{p}{(}\\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1000000}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{)} \\PY{o}{*} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}4} \\PY{o}{\\PYZhy{}} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{T}\\PY{p}{)}\\PY{o}{.}\\PY{n}{T} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}11}\\PY{p}{,} \\PY{n}{FALSE}\n",
"\\PY{p}{)}\n",
"\\PY{p}{(}\n",
" \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1} \\PY{o}{\\PYZhy{}} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{)}\\PY{p}{)}\n",
" \\PY{o}{@} \\PY{p}{(}\n",
" \\PY{n}{\\PYZus{}NDArray\\PYZus{}11}\n",
" \\PY{o}{@} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{]}\\PY{o}{.}\\PY{n}{T}\\PY{p}{[}\n",
" \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\n",
" \\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1}\n",
" \\PY{o}{+} \\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\n",
" \\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\n",
" \\PY{n}{Slice}\\PY{p}{(}\n",
" \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,}\n",
" \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\n",
" \\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{astype}\\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZgt{}} \\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{0.0001}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int32}\\PY{p}{)}\\PY{p}{)}\n",
" \\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\n",
" \\PY{o}{.}\\PY{n}{to\\PYZus{}int}\n",
" \\PY{p}{)}\\PY{p}{,}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\n",
" \\PY{p}{]}\n",
" \\PY{p}{)}\n",
"\\PY{p}{)}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{\\PYZus{}MultiAxisIndexKey\\PYZus{}1} \\PY{o}{+} \\PY{n}{MultiAxisIndexKey}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{Slice}\\PY{p}{(}\\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,} \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\n",
"\\end{Verbatim}\n"
],
"text/plain": [
"_NDArray_1 = NDArray.var(\"X\")\n",
"assume_dtype(_NDArray_1, DType.float64)\n",
"assume_shape(_NDArray_1, TupleInt(Int(1000000)) + TupleInt(Int(20)))\n",
"assume_isfinite(_NDArray_1)\n",
"_NDArray_2 = NDArray.var(\"y\")\n",
"assume_dtype(_NDArray_2, DType.int64)\n",
"assume_shape(_NDArray_2, TupleInt(Int(1000000)))\n",
"assume_value_one_of(_NDArray_2, TupleValue(Value.int(Int(0))) + TupleValue(Value.int(Int(1))))\n",
"_NDArray_3 = astype(\n",
" NDArray.vector(TupleValue(sum(_NDArray_2 == NDArray.scalar(Value.int(Int(0)))).to_value()) + TupleValue(sum(_NDArray_2 == NDArray.scalar(Value.int(Int(1)))).to_value())),\n",
" DType.float64,\n",
") / NDArray.scalar(Value.float(Float(1000000.0)))\n",
"_NDArray_4 = zeros(TupleInt(Int(2)) + TupleInt(Int(20)), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device))\n",
"_MultiAxisIndexKey_1 = MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice()))\n",
"_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(0))) + _MultiAxisIndexKey_1)\n",
"_NDArray_5 = _NDArray_1[ndarray_index(_NDArray_2 == NDArray.scalar(Value.int(Int(0))))]\n",
"_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0)))\n",
"_NDArray_4[_IndexKey_1] = sum(_NDArray_5, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_5.shape[Int(0)]))\n",
"_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(1))) + _MultiAxisIndexKey_1)\n",
"_NDArray_6 = _NDArray_1[ndarray_index(_NDArray_2 == NDArray.scalar(Value.int(Int(1))))]\n",
"_NDArray_4[_IndexKey_2] = sum(_NDArray_6, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_6.shape[Int(0)]))\n",
"_NDArray_7 = concat(TupleNDArray(_NDArray_5 - _NDArray_4[_IndexKey_1]) + TupleNDArray(_NDArray_6 - _NDArray_4[_IndexKey_2]), OptionalInt.some(Int(0)))\n",
"_NDArray_8 = square(_NDArray_7 - expand_dims(sum(_NDArray_7, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_7.shape[Int(0)]))))\n",
"_NDArray_9 = sqrt(sum(_NDArray_8, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_8.shape[Int(0)])))\n",
"_NDArray_10 = copy(_NDArray_9)\n",
"_NDArray_10[ndarray_index(_NDArray_9 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))\n",
"_TupleNDArray_1 = svd(sqrt(NDArray.scalar(Value.float(Float(1.0) / Float.from_int(Int(999998))))) * (_NDArray_7 / _NDArray_10), FALSE)\n",
"_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))\n",
"_NDArray_11 = (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(_Slice_1)) + _MultiAxisIndexKey_1)] / _NDArray_10).T / _TupleNDArray_1[\n",
" Int(1)\n",
"][IndexKey.slice(_Slice_1)]\n",
"_TupleNDArray_2 = svd(\n",
" (sqrt((NDArray.scalar(Value.int(Int(1000000))) * _NDArray_3) * NDArray.scalar(Value.float(Float(1.0)))) * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T).T @ _NDArray_11, FALSE\n",
")\n",
"(\n",
" (_NDArray_1 - (_NDArray_3 @ _NDArray_4))\n",
" @ (\n",
" _NDArray_11\n",
" @ _TupleNDArray_2[Int(2)].T[\n",
" IndexKey.multi_axis(\n",
" _MultiAxisIndexKey_1\n",
" + MultiAxisIndexKey(\n",
" MultiAxisIndexKeyItem.slice(\n",
" Slice(\n",
" OptionalInt.none,\n",
" OptionalInt.some(\n",
" sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32))\n",
" .to_value()\n",
" .to_int\n",
" ),\n",
" )\n",
" )\n",
" )\n",
" )\n",
" ]\n",
" )\n",
")[IndexKey.multi_axis(_MultiAxisIndexKey_1 + MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(Int(1))))))]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from egglog.exp.array_api_numba import array_api_numba_module\n",
"\n",
"egraph = EGraph([array_api_numba_module])\n",
"egraph.register(X_r2_optimized)\n",
"egraph.run(10000)\n",
"X_r2_numba = egraph.extract(X_r2_optimized)\n",
"X_r2_numba"
]
},
{
"cell_type": "markdown",
"id": "969490bb",
"metadata": {},
"source": [
"## Compiling back to Python source\n",
"\n",
"Now we finally have a version that we could run with Numba! However, this isn't in NumPy code. What Numba needs\n",
"is a function that uses `numpy`, not our typed dialect.\n",
"\n",
"So we use another module that provides a translation of all our methods into Python strings. The rules in it look like this:\n",
"\n",
"```python\n",
"# the sqrt of an array should use the `np.sqrt` function and be assigned to its own variable, so it can be reused\n",
"rewrite(ndarray_program(sqrt(x))).to((Program(\"np.sqrt(\") + ndarray_program(x) + \")\").assign())\n",
"\n",
"# To compile a setitem call, we first compile the source, assign it to a variable, then add an assignment statement\n",
"mod_x = copy(x)\n",
"mod_x[idx] = y\n",
"assigned_x = ndarray_program(x).assign()\n",
"yield rewrite(ndarray_program(mod_x)).to(\n",
" assigned_x.statement(assigned_x + \"[\" + index_key_program(idx) + \"] = \" + ndarray_program(y))\n",
")\n",
"```\n",
"\n",
"We pull in all those rewrite rules from the [`egglog.exp.array_api_program_gen` module](https://github.com/egraphs-good/egglog-python/blob/main/python/egglog/exp/array_api_program_gen.py).\n",
"They depend on another module, [`egglog.exp.program_gen` module](https://github.com/egraphs-good/egglog-python/blob/main/python/egglog/exp/program_gen.py), which provides generic translations\n",
"from expressions and statements into strings.\n",
"\n",
"We can run these rules to get out a Python function object:\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "3aeae673",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"def __fn(X, y):\n",
" assert X.dtype == np.dtype(np.float64)\n",
" assert X.shape == (1000000, 20,)\n",
" assert np.all(np.isfinite(X))\n",
" assert y.dtype == np.dtype(np.int64)\n",
" assert y.shape == (1000000,)\n",
" assert set(np.unique(y)) == set((0, 1,))\n",
" _0 = y == np.array(0)\n",
" _1 = np.sum(_0)\n",
" _2 = y == np.array(1)\n",
" _3 = np.sum(_2)\n",
" _4 = np.array((_1, _3,)).astype(np.dtype(np.float64))\n",
" _5 = _4 / np.array(1000000.0)\n",
" _6 = np.zeros((2, 20,), dtype=np.dtype(np.float64))\n",
" _7 = np.sum(X[_0], axis=0)\n",
" _8 = _7 / np.array(X[_0].shape[0])\n",
" _6[0, :] = _8\n",
" _9 = np.sum(X[_2], axis=0)\n",
" _10 = _9 / np.array(X[_2].shape[0])\n",
" _6[1, :] = _10\n",
" _11 = _5 @ _6\n",
" _12 = X - _11\n",
" _13 = np.sqrt(np.array((1.0 / 999998)))\n",
" _14 = X[_0] - _6[0, :]\n",
" _15 = X[_2] - _6[1, :]\n",
" _16 = np.concatenate((_14, _15,), axis=0)\n",
" _17 = np.sum(_16, axis=0)\n",
" _18 = _17 / np.array(_16.shape[0])\n",
" _19 = np.expand_dims(_18, 0)\n",
" _20 = _16 - _19\n",
" _21 = np.square(_20)\n",
" _22 = np.sum(_21, axis=0)\n",
" _23 = _22 / np.array(_21.shape[0])\n",
" _24 = np.sqrt(_23)\n",
" _25 = _24 == np.array(0)\n",
" _24[_25] = np.array(1.0)\n",
" _26 = _16 / _24\n",
" _27 = _13 * _26\n",
" _28 = np.linalg.svd(_27, full_matrices=False)\n",
" _29 = _28[1] > np.array(0.0001)\n",
" _30 = _29.astype(np.dtype(np.int32))\n",
" _31 = np.sum(_30)\n",
" _32 = _28[2][:_31, :] / _24\n",
" _33 = _32.T / _28[1][:_31]\n",
" _34 = np.array(1000000) * _5\n",
" _35 = _34 * np.array(1.0)\n",
" _36 = np.sqrt(_35)\n",
" _37 = _6 - _11\n",
" _38 = _36 * _37.T\n",
" _39 = _38.T @ _33\n",
" _40 = np.linalg.svd(_39, full_matrices=False)\n",
" _41 = np.array(0.0001) * _40[1][0]\n",
" _42 = _40[1] > _41\n",
" _43 = _42.astype(np.dtype(np.int32))\n",
" _44 = np.sum(_43)\n",
" _45 = _33 @ _40[2].T[:, :_44]\n",
" _46 = _12 @ _45\n",
" return _46[:, :1]\n",
"\n"
]
}
],
"source": [
"from egglog.exp.array_api_program_gen import (\n",
" ndarray_function_two,\n",
" array_api_module_string,\n",
")\n",
"\n",
"egraph = EGraph([array_api_module_string])\n",
"fn_program = ndarray_function_two(X_r2_numba, X_orig, y_orig)\n",
"egraph.register(fn_program)\n",
"egraph.run(10000)\n",
"fn = egraph.load_object(egraph.extract(fn_program.py_object))\n",
"import inspect\n",
"\n",
"print(inspect.getsource(fn))"
]
},
{
"cell_type": "markdown",
"id": "6e0405c8",
"metadata": {},
"source": [
"We can verify that the function gives the same result:\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a807d66c",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"assert np.allclose(run_lda(X_np, y_np), fn(X_np, y_np))"
]
},
{
"cell_type": "markdown",
"id": "b2a3f1ed",
"metadata": {},
"source": [
"Although it isn't the prettiest, we can see that it has only emitted each expression once, for common subexpression\n",
"elimination, and preserves the \"imperative\" aspects of setitem.\n",
"\n",
"## Compiling to Numba\n",
"\n",
"Now we finally have a function we can run with numba:\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "39a69f23",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/xn/05ktz3056kqd9n8frgd6236h0000gn/T/egglog-9e61d62c-d17d-495b-b8db-f1eb3b38dcbb.py:56: NumbaPerformanceWarning: '@' is faster on contiguous arrays, called on (Array(float64, 2, 'C', False, aligned=True), Array(float64, 2, 'A', False, aligned=True))\n",
" _45 = _33 @ _40[2].T[:, :_44]\n"
]
}
],
"source": [
"import numba\n",
"import os\n",
"\n",
"fn_numba = numba.njit(fastmath=True)(fn)\n",
"assert np.allclose(run_lda(X_np, y_np), fn_numba(X_np, y_np))"
]
},
{
"cell_type": "markdown",
"id": "078d41b3",
"metadata": {},
"source": [
"## Evaluating performance\n",
"\n",
"Let's see if it actually made anything quicker! Let's run a number of trials for the original function, our\n",
"extracted version, and the optimized extracted version:\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "27a0cafc",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" original \n",
" extracted \n",
" extracted numba \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" 1.482975 \n",
" 1.609354 \n",
" 1.086486 \n",
" \n",
" \n",
" 1 \n",
" 1.498656 \n",
" 1.504704 \n",
" 1.145331 \n",
" \n",
" \n",
" 2 \n",
" 1.500998 \n",
" 1.557253 \n",
" 1.090356 \n",
" \n",
" \n",
" 3 \n",
" 1.519732 \n",
" 1.548800 \n",
" 1.122623 \n",
" \n",
" \n",
" 4 \n",
" 1.500420 \n",
" 1.501195 \n",
" 1.113089 \n",
" \n",
" \n",
" 5 \n",
" 1.587211 \n",
" 1.522518 \n",
" 1.176842 \n",
" \n",
" \n",
" 6 \n",
" 1.499479 \n",
" 1.526887 \n",
" 1.095296 \n",
" \n",
" \n",
" 7 \n",
" 1.639910 \n",
" 1.500859 \n",
" 1.086477 \n",
" \n",
" \n",
" 8 \n",
" 1.525145 \n",
" 1.559202 \n",
" 1.103662 \n",
" \n",
" \n",
" 9 \n",
" 1.535601 \n",
" 1.474299 \n",
" 1.074152 \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" original extracted extracted numba\n",
"0 1.482975 1.609354 1.086486\n",
"1 1.498656 1.504704 1.145331\n",
"2 1.500998 1.557253 1.090356\n",
"3 1.519732 1.548800 1.122623\n",
"4 1.500420 1.501195 1.113089\n",
"5 1.587211 1.522518 1.176842\n",
"6 1.499479 1.526887 1.095296\n",
"7 1.639910 1.500859 1.086477\n",
"8 1.525145 1.559202 1.103662\n",
"9 1.535601 1.474299 1.074152"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import timeit\n",
"import pandas as pd\n",
"\n",
"stmts = {\n",
" \"original\": \"run_lda(X_np, y_np)\",\n",
" \"extracted\": \"fn(X_np, y_np)\",\n",
" \"extracted numba\": \"fn_numba(X_np, y_np)\",\n",
"}\n",
"df = pd.DataFrame.from_dict(\n",
" {name: timeit.repeat(stmt, globals=globals(), number=1, repeat=10) for name, stmt in stmts.items()}\n",
")\n",
"\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "9488c513",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeoAAAHqCAYAAADLbQ06AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAs/0lEQVR4nO3de3TU5Z3H8c8vgQzkNiEICZeBKBBEKoEN4oW2EFqMWjmiZ0VB1iDUiqDUpVDN1gpYK6utAioIu1WR6Gq9ge5WxQuEmxblEsSIQGJoAEMQhAwBDZA8+weHqYEkJDCXZzLv1zlzDr/bzDeZYT55nt/z/H6OMcYIAABYKSrUBQAAgPoR1AAAWIygBgDAYgQ1AAAWI6gBALAYQQ0AgMUIagAALEZQAwBgsYgLamOMvF6vuM4LACAcRFxQHzp0SG63W4cOHQp1KQAAnFHEBTUAAOGEoAYAwGIENQAAFiOoAQCwGEENAIDFCGoAACxGUAMAYDGCGgAAixHUAABYjKAGAMBiBDUAABYjqAEAsBhBDQCAxQhqAAAsRlADAGCxFqEuAKG3tHCP5i0v0rbySqWnxGtCVndl904NdVkAAEmOMcaEuohg8nq9crvdqqioUGJiYqjLCbmlhXt0R976WuscR5o/OpOwBgAL0PUd4eYtLzptnTHSvPziEFQDADgVQR3htpVX1rl+e/mhIFcCAKgLQR3h0lPi61zfIyUhyJUAAOpCUEe4CVnd5Ti11zmONHFwt9AUBACohaCOcNm9UzV/dKYyPEmKjYlWhidJC0Zn6koGkgGAFRj1DQCAxWhRAwBgMYIaAACLEdQAAFiMoAYAwGIENQAAFiOoAQCwGEENAIDFCGoAACxGUAMAYDGCGgAAixHUAABYjKAGAMBiBDUAABZrEeoCAISXpYV7NG95kbaVVyo9JV4Tsrorm9uiAgHDbS4BNNrSwj26I299rXWOI80fnUlYAwFC1zeARpu3vOi0dcZI8/KLQ1ANEBkIagCNtq28ss7128sPBbkSIHIQ1AAaLT0lvs71PVISglwJEDlCGtQrV67UsGHD1LFjRzmOoyVLlpzxmKqqKv3ud79T165d5XK5lJaWpmeffTbwxQLQhKzucpza6xxHmji4W2gKAiJASEd9Hz58WBkZGRo7dqxuuOGGRh0zYsQIlZeX65lnnlH37t1VVlammpqaAFcKQJKye6dq/uhMzcsv1vbyQ+qRkqCJg7vpSgaSAQFjzahvx3G0ePFiDR8+vN593n33Xd1888366quvlJycfFavw6hvAEA4Catz1G+99Zb69++vRx99VJ06dVJ6erqmTJmi7777LtSlAQAQEGF1wZOvvvpKq1evVqtWrbR48WLt27dPEyZM0P79+/Xcc8/VeUxVVZWqqqp8y16vN1jlAgBwzsKqRV1TUyPHcfTiiy9qwIABuuaaa/T444/r+eefr7dVPXPmTLndbt/D4/EEuWoAAM5eWAV1hw4d1KlTJ7ndbt+6Xr16yRijXbt21XlMbm6uKioqfI+dO3cGq1wAAM5ZWAX1wIED9fXXX6uy8p8XXdi2bZuioqLUuXPnOo9xuVxKTEys9QAAIFyENKgrKytVUFCggoICSVJJSYkKCgpUWloq6URr+NZbb/XtP2rUKLVt21a33XabvvjiC61cuVJTp07V2LFj1bp161D8CAAABFRIg3rdunXq16+f+vXrJ0maPHmy+vXrpwceeECSVFZW5gttSYqPj9f777+vgwcPqn///rrllls0bNgwPfHEEyGpHwCAQLNmHnWwMI8aABBOwmp6FgKD+wsDgL1oUUc47i8MAHYLq1Hf8D/uLwwAdiOoIxz3FwYAuxHUEY77CwOA3QjqCMf9hQHAbgR1hDt5f+EMT5JiY6KV4UnSgtGZ3F8YACzBqG8AACxGixoAAIsR1AAAWIygBgDAYgQ1AAAWI6gBALAYQQ0AgMUIagAALEZQAwBgMYIaAACLEdQAAF
iMoAYAwGItQl0AgPC1tHCP5i0v0rbySqWnxGtCVndlc0MXwK+4KQeABtUXxksL9+iOvPW19nUcaf7oTMIa8CO6vgHU62QYb9pVoe+OVWvTrgqNf2G9L7xPZYw0L784BJUCzRdBDaBeDYXxtvLKOo/ZXn4o0GUBEYWgBlCvhsI4PSW+zm09UhICWRIQcQhqAPVqKIwnZHWX49Re7zjSxMHdglAZEDkIagD1aiiMs3unav7oTGV4khQbE60MT5IWjM7UlQwkA/yKUd8AGrS0cI/m5Rdre/kh9UhJ0BUXtNVHxfuYkgUECUENoNGYkgUEH13fABqNKVlA8BHUABqNKVlA8BHUABqNKVlA8BHUABqNKVlA8BHUABqNKVlA8DHqGwAAi9GiBgDAYtyPGtxTGAAsRtd3hOMCFgBgN7q+IxwXsAAAuxHUEY4LWACA3QjqCMcFLADAbgR1hOMCFgBgNwaT4bTbGE4c3I0LWEQ4ZgIA9iCoUS++rCMTMwEAu9D1jTqd/LLetKtC3x2r1qZdFRr/wnotLdwT6tIQYMwEAOzCBU9Qp4a+rGlVNW9nmglATwsQXLSoUSembUWuhmYC0NMCBB9BDS0t3KPrnlqtXr9/V9c9tVpLC/cwbSuCNTQTgG5xIPgI6ghXXwvpim7nMW0rQjV0K0t6WoDg4xx1hKuvhfTRV/s1f3Qm07YiVHbv1DrPO6enxGvTrorT1tPTAgQOQR3hGmoh1fdljcg1Iau7xr+wXj+c1ElPCxBYdH1HOM5Foyka6hYHEBi0qCMcLSQ0FT0tQHDRoo5wtJAAwG5cQhQAAIvRogYAwGIENQAAFmMwGQC/4BrgQGBwjhrAOePWmEDghLTre+XKlRo2bJg6duwox3G0ZMmSBvfPz8+X4zinPfbs4YYAQChxDXAgcELa9X348GFlZGRo7NixuuGGGxp93NatW2u1htu3bx+I8sJOQ12PdEuisc70WalrO9cABwLHmq5vx3G0ePFiDR8+vN598vPzlZWVpQMHDigpKemsXqe5dn031PUoqcFuSQIeJ52pC7u+7V2TY7Vj/5HTni/Dk6Q3Jw4MaM1AcxeWg8n69u2rqqoq/ehHP9L06dM1cGD9XwRVVVWqqqryLXu93mCUGHQNdj3W8bfYD7slf/jFe/LuWXUF/A+3EdbNU0Ofo+zeqfVul04ENle4A/wvrKZndejQQfPnz9frr7+u119/XR6PR4MHD9aGDRvqPWbmzJlyu92+h8fjCWLFwdNQ12ND2xr6Yua8Y+Q5Uxd2fdv3HqriCndAgIRVi7pnz57q2bOnb/mKK65QcXGxZs2apby8vDqPyc3N1eTJk33LXq+3WYZ1g7cfNKbebdv21H0OcXv5oboa4r5taJ7OdBvLhrZzDXAgMMKqRV2XAQMGqKjo9JbfSS6XS4mJibUezdGErO5ynNrrTnY9NrStobtncWetyNPQZ6Ux2wH4X9gHdUFBgTp06BDqMkKuoZtrNLTtbAMezdOZbtLCTVyA4AvpqO/Kykpfa7hfv356/PHHlZWVpeTkZHXp0kW5ubnavXu3Fi1aJEmaPXu2zj//fPXu3Vvff/+9/vKXv+jJJ5/Ue++9p5/97GeNes3mOur7XCwt3KN5+cXaXn5IPVISNHFwN98Xb0PbAACBF9Jz1OvWrVNWVpZv+eS55JycHC1cuFBlZWUqLS31bT969Kh+85vfaPfu3YqNjVWfPn30wQcf1HoONF1D5xY57wgAoWXNPOpgoUUNAAgnYX+OGgCA5oygBgDAYmE1jxqBwWVCAcBenKOOcNyeEADsRtd3hOMyoQBgN4I6wnF7QgCwG0Ed4bhMKADYjaCOcFwmFADsxmAycJlQNAmzBIDgIqgBNBqzBIDgo+sbQKMxSwAIPoIaQKMxSwAIPoIaQKMxSwAIPoIaQKMxSwAIPoIaQKNl907V/NGZyvAkKTYmWhmeJC0YncksASCAGPUNAIDFaFEDAGAxghoAAIsR1AAAWIygBg
DAYgQ1AAAWI6gBALAYQQ0AgMUIagAALEZQAwBgMYIaAACLEdQAAFiMoAYAwGIENQAAFiOoAQCwGEENAIDFCGoAACxGUAMAYDGCGgAAixHUAABYjKAGAMBiBDUAABYjqAEAsBhBDQCAxQhqAAAsRlADAGAxghoAAIsR1AAAWIygBgDAYgQ1AAAWI6gBALAYQQ0AgMUIagAALEZQAwBgMYIaAACLEdQAAFiMoAYAwGIENQAAFiOoAQCwGEENAIDFCGoAACxGUAMAYLGQBvXKlSs1bNgwdezYUY7jaMmSJY0+ds2aNWrRooX69u0bsPoAAAi1kAb14cOHlZGRoblz5zbpuIMHD+rWW2/Vz372swBVBgCAHVqE8sWvvvpqXX311U0+bvz48Ro1apSio6Ob1AoHACDchN056ueee05fffWVpk2bFupSAAAIuJC2qJtq+/btuu+++7Rq1Sq1aNG40quqqlRVVeVb9nq9gSoPAAC/C5sWdXV1tUaNGqUZM2YoPT290cfNnDlTbrfb9/B4PAGsEgAA/3KMMSbURUiS4zhavHixhg8fXuf2gwcPqk2bNoqOjvatq6mpkTFG0dHReu+99zRkyJDTjqurRe3xeFRRUaHExES//xwAAPhT2HR9JyYmavPmzbXWzZs3T8uWLdNrr72m888/v87jXC6XXC5XMEoEAMDvQhrUlZWVKioq8i2XlJSooKBAycnJ6tKli3Jzc7V7924tWrRIUVFR+tGPflTr+Pbt26tVq1anrQcAoLkIaVCvW7dOWVlZvuXJkydLknJycrRw4UKVlZWptLQ0VOUBABBy1pyjDhav1yu32805agBAWAibUd8AAEQighoAAIsR1AAAWIygBgDAYgQ1AAAWI6gBALAYQQ0AgMUIagAALEZQAwBgMYIaAACLnVVQFxcX6/7779fIkSO1d+9eSdI777yjwsJCvxYHAECka3JQr1ixQhdffLHWrl2rN954Q5WVlZKkTZs2adq0aX4vEACASNbkoL7vvvv00EMP6f3331dMTIxv/ZAhQ/T3v//dr8UBABDpmhzUmzdv1vXXX3/a+vbt22vfvn1+KQoAAJzQ5KBOSkpSWVnZaes3btyoTp06+aUoAABwQpOD+uabb9a9996rPXv2yHEc1dTUaM2aNZoyZYpuvfXWQNQIAEDEcowxpikHHD16VBMnTtTChQtVXV2tFi1aqLq6WqNGjdLChQsVHR0dqFr9wuv1yu12q6KiQomJiaEuBwCABjU5qE8qLS3V559/rsrKSvXr1089evTwd20BQVADAMLJWQd1uCKoAQDhpEVTDzDG6LXXXtPy5cu1d+9e1dTU1Nr+xhtv+K04AAAiXZOD+p577tGCBQuUlZWllJQUOY4TiLoAAIDOous7OTlZL7zwgq655ppA1RRQdH0DAMJJk6dnud1uXXDBBYGoBQAAnKLJQT19+nTNmDFD3333XSDqAQAAP9Dkc9QjRozQSy+9pPbt2ystLU0tW7astX3Dhg1+Kw4AgEjX5KDOycnR+vXrNXr0aAaTAQAQYE0eTBYXF6elS5fqxz/+caBqCigGkwEAwkmTz1F7PB4CDgCAIGlyUD/22GP67W9/qx07dgSgHAAA8ENN7vpu06aNjhw5ouPHjys2Nva0wWTffvutXwv0N7q+AQDhpMmDyWbPnh2AMgAAQF24KQcAABZrVIva6/X6Qs3r9Ta4L+EHAID/NCqo27Rpo7KyMrVv315JSUl1zp02xshxHFVXV/u9SAAAIlWjgnrZsmVKTk6WJD333HPyeDyKjo6utU9NTY1KS0v9XyEAABGsyeeoo6Ojfa3rH9q/f7/at29vfYuac9QAgHDS5HnUJ7u4T1VZWalWrVr5pSgAAHBCo6dnTZ48WZLkOI5+//vfKzY21returpaa9euVd++ff1eIAAAkazRQb1x40ZJJ1rUmzdvVkxMjG9bTEyMMjIyNGXKFP9XCABABGvyOerbbrtNc+bMCdvzu5yjBg
CEEy54AgCAxZo8mAwAAAQPQQ0AgMUIagAALEZQAwBgMYIaAACLEdQAAFiMoAYAwGIENQAAFiOoAQCwGEENAIDFCGoAACxGUAMAYDGCGgAAixHUAABYjKAGAMBiBDUAABYjqAEAsFhIg3rlypUaNmyYOnbsKMdxtGTJkgb3X716tQYOHKi2bduqdevWuvDCCzVr1qzgFAsAQAi0COWLHz58WBkZGRo7dqxuuOGGM+4fFxenu+66S3369FFcXJxWr16tO+64Q3FxcfrVr34VhIoBAAguxxhjQl2EJDmOo8WLF2v48OFNOu6GG25QXFyc8vLyGrW/1+uV2+1WRUWFEhMTz6JSAACCJ6zPUW/cuFEfffSRBg0aVO8+VVVV8nq9tR4AAISLsAzqzp07y+VyqX///po4caJ++ctf1rvvzJkz5Xa7fQ+PxxPESgEAODdhGdSrVq3SunXrNH/+fM2ePVsvvfRSvfvm5uaqoqLC99i5c2cQKwUA4NyEdDDZ2Tr//PMlSRdffLHKy8s1ffp0jRw5ss59XS6XXC5XMMsDAMBvwrJF/UM1NTWqqqoKdRkAAARESFvUlZWVKioq8i2XlJSooKBAycnJ6tKli3Jzc7V7924tWrRIkjR37lx16dJFF154oaQT87D//Oc/a9KkSSGpHwCAQAtpUK9bt05ZWVm+5cmTJ0uScnJytHDhQpWVlam0tNS3vaamRrm5uSopKVGLFi3UrVs3PfLII7rjjjuCXjsAAMFgzTzqYGEeNQAgnIT9OWoAAJozghoAAIsR1AAAWIygBgDAYgQ1AAAWI6gBALAYQQ0AgMUIagAALEZQAwBgMYIaAACLEdQAAFiMoAYAwGIENQAAFiOoAQCwGEENAIDFCGoAACxGUAMAYDGCGgAAixHUAABYjKAGAMBiBDUAABYjqAEAsBhBDQCAxQhqAAAsRlADAGAxghoAAIsR1AAAWIygBgDAYgQ1AAAWI6gBALAYQQ0AgMUIagAALEZQAwBgMYIaAACLEdQAAFiMoAYAwGIENQAAFiOoAQCwGEENAIDFCGoAACxGUAMAYDGCGgAAixHUAABYjKAGAMBiBDUAABYjqAEAsBhBDQCAxQhqAAAsRlADAGAxghoAAIsR1AAAWIygBgDAYgQ1AAAWI6gBALAYQQ0AgMUIagAALEZQAwBgMYIaAACLhTSoV65cqWHDhqljx45yHEdLlixpcP833nhDQ4cOVbt27ZSYmKjLL79cS5cuDU6xAACEQEiD+vDhw8rIyNDcuXMbtf/KlSs1dOhQvf3221q/fr2ysrI0bNgwbdy4McCVAgAQGo4xxoS6CElyHEeLFy/W8OHDm3Rc7969ddNNN+mBBx5o1P5er1dut1sVFRVKTEw8i0oBAAieFqEu4FzU1NTo0KFDSk5OrnefqqoqVVVV+Za9Xm8wSgMAwC/CejDZn//8Z1VWVmrEiBH17jNz5ky53W7fw+PxBLFCAADOTdgG9f/8z/9oxowZeuWVV9S+fft698vNzVVFRYXvsXPnziBWCQDAuQnLru+XX35Zv/zlL/Xqq6/q5z//eYP7ulwuuVyuIFUGAIB/hV2L+qWXXtJtt92ml156Sb/4xS9CXQ4AAAEV0hZ1ZWWlioqKfMslJSUqKChQcnKyunTpotzcXO3evVuLFi2SdKK7OycnR3PmzNGll16qPXv2SJJat24tt9sdkp8BAIBACun0rPz8fGVlZZ22PicnRwsXLtSYMWO0Y8cO5efnS5IGDx6sFStW1Lt/YzA9CwAQTqyZRx0sBDUAIJyE3TlqAAAiCUENAIDFCGoAACxGUAMAYDGCGgAAixHUAABYLCwvIQoAsN/Swj2at7xI28orlZ4SrwlZ3ZXdOzXUZYUd5lEDAPxuaeEe3ZG3vtY6x5Hmj84krJuIrm8AgN/NW1502jpjpHn5xSGoJrwR1AAAv9tWXlnn+u3lh4JcSfgjqAEAfpeeEl/n+h4pCUGuJPwR1AAAv5uQ1V2OU3
ud40gTB3cLTUFhjKAGAPhddu9UzR+dqQxPkmJjopXhSdKC0Zm6koFkTcaobwAALEaLGgAAixHUAABYjKAGAMBiBDUAABYjqAEAsBhBDQCAxQhqAAAsRlADAGAxghoAAIsR1AAAWIygBgDAYgQ1AAAWI6gBALAYQQ0AgMVahLoAAEDztLRwj+YtL9K28kqlp8RrQlZ3ZXM/6ibjftQAAL9bWrhHd+Str7XOcaT5ozMJ6yai6xsA4Hfzlhedts4YaV5+cQiqCW8ENQDA77aVV9a5fnv5oSBXEv4IagCA36WnxNe5vkdKQpArCX8ENQDA7yZkdZfj1F7nONLEwd1CU1AYI6gBAH6X3TtV80dnKsOTpNiYaGV4krRgdKauZCBZkzHqGwAAi9GiBgDAYgQ1AAAWI6gBALAYQQ0AgMUIagAALEZQAwBgMYIaAACLcZtLAEDAcKvLc8cFTwAAAcGtLv2Drm8AQEBwq0v/IKgBAAHBrS79g6AGAAQEt7r0D4IaABAQ3OrSPwhqAEBAcKtL/2B6FgDAr5iS5V9MzwIA+A1TsvyPrm8AgN8wJcv/CGoAgN8wJcv/CGoAgN8wJcv/CGoAgN8wJcv/GEwGAPCrpYV7NC+/WNvLD6l9gkuSVO6tYgT4WQppi3rlypUaNmyYOnbsKMdxtGTJkgb3Lysr06hRo5Senq6oqCjdc889QakTANB42b1T9ebEgZp1U1/t2H9EO/Yf0XfHqrVpV4XGv7BeSwv3hLrEsBLSoD58+LAyMjI0d+7cRu1fVVWldu3a6f7771dGRkaAqwMAnAtGgPtHSC94cvXVV+vqq69u9P5paWmaM2eOJOnZZ58NVFkAAD9gBLh/MJgMABAQjAD3j2Yf1FVVVfJ6vbUeAIDAYwS4fzT7a33PnDlTM2bMCHUZABARTr3O9/ifdtNHX+3X9vJD6pGSoImDu3FTjiZq9kGdm5uryZMn+5a9Xq88Hk8IKwKA5unU63xv2lWhz3ZXcJ3vc9Tsg9rlcsnlcoW6DABo9hoa5U1Qn72QBnVlZaWKiv75xpaUlKigoEDJycnq0qWLcnNztXv3bi1atMi3T0FBge/Yb775RgUFBYqJidFFF10U7PIBAD/QlFHe3Aqz8UJ6ZbL8/HxlZWWdtj4nJ0cLFy7UmDFjtGPHDuXn5/u2OaeOTJDUtWtX7dixo1GvyZXJACAwrntqtTbtqjhtfYYnSW9OHOhb5laYTcMlRAEAfrG0cI/Gv7BeP0wVx5EWjM6sNYCssYGOE5r9OWoAQOA0NMr75HW+f/1yQa3ubS6E0jQENQDgrDQ0ylvSadvGv7Be80dnKj0lvs4WNRdCqVuzv+AJACAwGhrl3dA2LoTSNAQ1AOCsNNSF3dC27N6pmj86UxmeJMXGRCvDk3TaeWz8E13fAICz0mAXtjENdm9n905lhHcj0aIGAJyVhrqw6d72H6ZnAQDO2tLCPZqXX1zntbwb2obGI6gBALAY56gBAAHBZUL9gxY1AMDvuEyo/zCYDADgdw3No0bTENQAAL/jMqH+Q1ADAPwuPSW+zvVcJrTpCGoAgN8xj9p/CGoAgN9xmVD/YdQ3AAAWo0UNAIDFCGoAACxGUAMAYDGCGgAAixHUAABYjKAGAMBiBDUAABYjqAEAsBhBDQCAxQhqAAAsRlADAGAxghoAAIsR1AAAWIygBgDAYi1CXUCwnbyrp9frDXElAIBIl5CQIMdxGtwn4oL60KFDkiSPxxPiSgAAka6iokKJiYkN7uOYk03MCFFTU6Ovv/66UX/FRBKv1yuPx6OdO3ee8UMD8HlBY/FZaRgt6jpERUWpc+fOoS7DWomJifxnQqPxeUFj8Vk5ewwmAwDAYgQ1AAAWI6ghSXK5XJo2bZpcLleoS0EY4POCxuKzcu4ibjAZAADhhBY1AAAWI6gBALAYQd3MTZ8+XX379m3SMYMHD9Y999wT8jqAugTi84ngsvk9tPG7iqBu5qZMmaIPP/ywSce88cYb+s
Mf/hCgimCD/Px8OY6jgwcPBuX1bP5iDle8h5Ej4i54EimMMaqurlZ8fLzi4+ObdGxycnKAqkK4OXr0qGJiYkJdBs4B72H4o0UdRqqqqjRp0iS1b99erVq10o9//GN9+umnkv751/U777yjzMxMuVwurV69+rRunOPHj2vSpElKSkpS27Ztde+99yonJ0fDhw/37XPqX85paWl6+OGHNXbsWCUkJKhLly76r//6r1q13XvvvUpPT1dsbKwuuOAC/f73v9exY8cC+euIeDU1NZo5c6bOP/98tW7dWhkZGXrttddkjNHPf/5zZWdn+25C8+2336pz58564IEHtGPHDmVlZUmS2rRpI8dxNGbMGEkn3vu77rpL99xzj8477zxlZ2dLkh5//HFdfPHFiouLk8fj0YQJE1RZWVmrnjVr1mjw4MGKjY1VmzZtlJ2drQMHDmjMmDFasWKF5syZI8dx5DiOduzYIUn6/PPPdfXVVys+Pl4pKSn6t3/7N+3bt8/3nIcPH9att96q+Ph4dejQQY899liAf6vBU9/7Jyni38OT31t5eXlKS0uT2+3WzTff7LtXg3Tie2n27Nm1juvbt6+mT5/uW3YcRwsWLNC1116r2NhY9erVSx9//LGKioo0ePBgxcXF6YorrlBxcfFpNSxYsEAej0exsbEaMWKEKioqfNs+/fRTDR06VOedd57cbrcGDRqkDRs2nPHnOmsGYWPSpEmmY8eO5u233zaFhYUmJyfHtGnTxuzfv98sX77cSDJ9+vQx7733nikqKjL79+8306ZNMxkZGb7neOihh0xycrJ54403zJYtW8z48eNNYmKiue6663z7DBo0yPz617/2LXft2tUkJyebuXPnmu3bt5uZM2eaqKgo8+WXX/r2+cMf/mDWrFljSkpKzFtvvWVSUlLMI4884tt+ah04dw899JC58MILzbvvvmuKi4vNc889Z1wul8nPzze7du0ybdq0MbNnzzbGGHPjjTeaAQMGmGPHjpnjx4+b119/3UgyW7duNWVlZebgwYPGmBPvfXx8vJk6dar58ssvfe/xrFmzzLJly0xJSYn58MMPTc+ePc2dd97pq2Xjxo3G5XKZO++80xQUFJjPP//cPPnkk+abb74xBw8eNJdffrm5/fbbTVlZmSkrKzPHjx83Bw4cMO3atTO5ublmy5YtZsOGDWbo0KEmKyvL97x33nmn6dKli/nggw/MZ599Zq699lqTkJBQ6/MZrhp6/4wxEf0eTps2zcTHx5sbbrjBbN682axcudKkpqaa//iP//Dt07VrVzNr1qxax2VkZJhp06b5liWZTp06mb/+9a9m69atZvjw4SYtLc0MGTLEvPvuu+aLL74wl112mbnqqqtqvXZcXJwZMmSI2bhxo1mxYoXp3r27GTVqlG+fDz/80OTl5ZktW7aYL774wowbN86kpKQYr9fbyHe/aQjqMFFZWWlatmxpXnzxRd+6o0ePmo4dO5pHH33UF9RLliypddypAZmSkmL+9Kc/+ZaPHz9uunTpcsagHj16tG+5pqbGtG/f3jz99NP11vunP/3JZGZm1lsHzs33339vYmNjzUcffVRr/bhx48zIkSONMca88sorplWrVua+++4zcXFxZtu2bb79Tn5eDhw4UOv4QYMGmX79+p3x9V999VXTtm1b3/LIkSPNwIED693/1M+UMSf+uLvyyitrrdu5c6cvfA4dOmRiYmLMK6+84tu+f/9+07p167AP6sa8f8ZE7ns4bdo0ExsbWyv4pk6dai699FLfcmOD+v777/ctf/zxx0aSeeaZZ3zrXnrpJdOqVatarx0dHW127drlW/fOO++YqKgoU1ZWVme91dXVJiEhwfzv//5vvT/TueAcdZgoLi7WsWPHNHDgQN+6li1basCAAdqyZYsuueQSSVL//v3rfY6KigqVl5drwIABvnXR0dHKzMxUTU1Ng6/fp08f378dx1Fqaqr27t3rW/fXv/5VTzzxhIqLi1
VZWanjx49zAf4AKioq0pEjRzR06NBa648ePap+/fpJkm688UYtXrxY//mf/6mnn35aPXr0aNRzZ2Zmnrbugw8+0MyZM/Xll1/K6/Xq+PHj+v7773XkyBHFxsaqoKBAN954Y5N+hk2bNmn58uV1jqEoLi7Wd999p6NHj+rSSy/1rU9OTlbPnj2b9Do2asz7J0X2e5iWlqaEhATfcocOHWp95zTWD7+7UlJSJEkXX3xxrXXff/+9vF6v7zurS5cu6tSpk2+fyy+/XDU1Ndq6datSU1NVXl6u+++/X/n5+dq7d6+qq6t15MgRlZaWNrm+xiCom5m4uLiAPG/Lli1rLTuO4wv3jz/+WLfccotmzJih7Oxsud1uvfzyy83qfKJtTp5b/Nvf/lbrC0WS71KNR44c0fr16xUdHa3t27c3+rlP/Qzt2LFD1157re6880798Y9/VHJyslavXq1x48bp6NGjio2NVevWrc/qZxg2bJgeeeSR07Z16NBBRUVFTX7OcNGY90+K7Pewoe8c6cSdEM0pF9asa1zMD5/n5O0k61p3psbKD+Xk5Gj//v2aM2eOunbtKpfLpcsvv1xHjx5t9HM0BYPJwkS3bt0UExOjNWvW+NYdO3ZMn376qS666KJGPYfb7VZKSopvAJokVVdXn/MgiI8++khdu3bV7373O/Xv3189evTQP/7xj3N6TjTsoosuksvlUmlpqbp3717r4fF4JEm/+c1vFBUVpXfeeUdPPPGEli1b5jv+5Cjg6urqM77W+vXrVVNTo8cee0yXXXaZ0tPT9fXXX9fap0+fPg1OA4yJiTnttf7lX/5FhYWFSktLO+1niIuLU7du3dSyZUutXbvWd8yBAwe0bdu2M/+CLNeY90/iPWxIu3btVFZW5lv2er0qKSk55+eVpNLS0lq/n7///e+Kiory9QSsWbNGkyZN0jXXXKPevXvL5XLVGkDnbwR1mIiLi9Odd96pqVOn6t1339UXX3yh22+/XUeOHNG4ceMa/Tx33323Zs6cqTfffFNbt27Vr3/9ax04cOCMNy5vSI8ePVRaWqqXX35ZxcXFeuKJJ7R48eKzfj6cWUJCgqZMmaJ///d/1/PPP6/i4mJt2LBBTz75pJ5//nn97W9/07PPPqsXX3xRQ4cO1dSpU5WTk6MDBw5Ikrp27SrHcfR///d/+uabb04b/ftD3bt317Fjx/Tkk0/qq6++Ul5enubPn19rn9zcXH366aeaMGGCPvvsM3355Zd6+umnfV9eaWlpWrt2rXbs2KF9+/appqZGEydO1LfffquRI0fq008/VXFxsZYuXarbbrvNN7Vw3Lhxmjp1qpYtW6bPP/9cY8aMUVRU+H9tnen9k8R7eAZDhgxRXl6eVq1apc2bNysnJ0fR0dHn/LyS1KpVK+Xk5GjTpk1atWqVJk2apBEjRig1NVXSie+8vLw8bdmyRWvXrtUtt9xyVj0SjRaQM98IiO+++87cfffd5rzzzjMul8sMHDjQfPLJJ8aY+geWnDqI69ixY+auu+4yiYmJpk2bNubee+81N954o7n55pt9+9Q1mOxMgzamTp1q2rZta+Lj481NN91kZs2aZdxud7114NzV1NSY2bNnm549e5qWLVuadu3amezsbJOfn29SUlLMww8/7Nv36NGjJjMz04wYMcK37sEHHzSpqanGcRyTk5NjjKl7wJAxxjz++OOmQ4cOpnXr1iY7O9ssWrTotM9bfn6+ueKKK4zL5TJJSUkmOzvbt33r1q3msssuM61btzaSTElJiTHGmG3btpnrr7/eJCUlmdatW5sLL7zQ3HPPPaampsYYY8yhQ4fM6NGjTWxsrElJSTGPPvpovTWGm/revxUrVpi9e/dG9HtY1/fFrFmzTNeuXX3LFRUV5qabbjKJiYnG4/GYhQsX1jmYbPHixb7lkpISI8ls3LjRt+7U786Trz1v3jzTsWNH06pVK/Ov//qv5ttvv/Uds2HDBtO/f3/TqlUr06NHD/
Pqq6/W+T3pL9w9K8LV1NSoV69eGjFiBFcjAwALMZgswvzjH//Qe++9p0GDBqmqqkpPPfWUSkpKNGrUqFCXBgCoQ/if7EGTREVFaeHChbrkkks0cOBAbd68WR988IF69eoV6tIAAHWg6xsAAIvRogYAwGIENQAAFiOoAQCwGEENAIDFCGoAACxGUAPNiDFGv/rVr5ScnCzHcVRQUBCSOnbs2BHS1weaE6ZnAc3IO++8o+uuu075+fm64IILdN5556lFi8Be12jMmDE6ePCglixZ4ltXXV2tb775JiivDzR3/A8CmpHi4mJ16NBBV1xxRUjriI6O9t3AAMC5oesbaCbGjBmju+++W6WlpXIcR2lpaUpLS9Ps2bNr7de3b19Nnz7dt+w4jv7yl7/o+uuvV2xsrHr06KG33nqr1jGFhYW69tprlZiYqISEBP3kJz9RcXGxpk+frueff15vvvmmHMeR4zjKz8+vs+t7xYoVGjBggFwulzp06KD77rtPx48f920fPHiwJk2apN/+9rdKTk5WampqrTqBSEVQA83EnDlz9OCDD6pz584qKyurdd/xM5kxY4ZGjBihzz77TNdcc41uueUWffvtt5Kk3bt366c//alcLpeWLVum9evXa+zYsTp+/LimTJmiESNG6KqrrlJZWZnKysrqbM3v3r1b11xzjS655BJt2rRJTz/9tJ555hk99NBDtfZ7/vnnFRcXp7Vr1+rRRx/Vgw8+qPfff//cfjFAmKPrG2gm3G63EhISzqrbecyYMRo5cqQk6eGHH9YTTzyhTz75RFdddZXmzp0rt9utl19+WS1btpQkpaen+45t3bq1qqqqGnzNefPmyePx6KmnnpLjOLrwwgv19ddf695779UDDzzguz9xnz59NG3aNEkn7vn71FNP6cMPP9TQoUOb9PMAzQktagDq06eP799xcXFKTEzU3r17JUkFBQX6yU9+4gvps7FlyxZdfvnlchzHt27gwIGqrKzUrl276qxDkjp06OCrA4hUBDXQjEVFRenUiR3Hjh07bb9TQ9hxHNXU1Eg60WIOlobqACIVQQ00Y+3atVNZWZlv2ev1qqSkpEnP0adPH61atarOgJekmJgYVVdXN/gcvXr10scff1zrj4Y1a9YoISFBnTt3blI9QKQhqIFmbMiQIcrLy9OqVau0efNm5eTkKDo6uknPcdddd8nr9ermm2/WunXrtH37duXl5Wnr1q2SpLS0NH322WfaunWr9u3bV2egT5gwQTt37tTdd9+tL7/8Um+++aamTZumyZMn+85PA6gb/0OAZiw3N1eDBg3Stddeq1/84hcaPny4unXr1qTnaNu2rZYtW6bKykoNGjRImZmZ+u///m9fN/Xtt9+unj17qn///mrXrp3WrFlz2nN06tRJb7/9tj755BNlZGRo/PjxGjdunO6//36//JxAc8aVyQAAsBgtagAALEZQAwBgMYIaAACLEdQAAFiMoAYAwGIENQAAFiOoAQCwGEENAIDFCGoAACxGUAMAYDGCGgAAixHUAABY7P8B0daK/soSNGUAAAAASUVORK5CYII=",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import seaborn as sns\n",
"\n",
"df_melt = pd.melt(df, var_name=\"function\", value_name=\"time\")\n",
"_ = sns.catplot(data=df_melt, x=\"function\", y=\"time\", kind=\"swarm\")"
]
},
{
"cell_type": "markdown",
"id": "83eab582",
"metadata": {},
"source": [
"We see that the numba version is in fact faster, and the other two are about the same. It isn't significantly faster through,\n",
"so we might want to run a profiler on the original function to see where most of the time is spent:\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "06d7777a",
"metadata": {},
"outputs": [],
"source": [
"%load_ext line_profiler"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f88942d6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Timer unit: 1e-09 s\n",
"\n",
"Total time: 1.41607 s\n",
"File: /var/folders/xn/05ktz3056kqd9n8frgd6236h0000gn/T/egglog-9e61d62c-d17d-495b-b8db-f1eb3b38dcbb.py\n",
"Function: __fn at line 1\n",
"\n",
"Line # Hits Time Per Hit % Time Line Contents\n",
"==============================================================\n",
" 1 def __fn(X, y):\n",
" 2 1 13000.0 13000.0 0.0 assert X.dtype == np.dtype(np.float64)\n",
" 3 1 2000.0 2000.0 0.0 assert X.shape == (1000000, 20,)\n",
" 4 1 23813000.0 2e+07 1.7 assert np.all(np.isfinite(X))\n",
" 5 1 11000.0 11000.0 0.0 assert y.dtype == np.dtype(np.int64)\n",
" 6 1 14000.0 14000.0 0.0 assert y.shape == (1000000,)\n",
" 7 1 23226000.0 2e+07 1.6 assert set(np.unique(y)) == set((0, 1,))\n",
" 8 1 542000.0 542000.0 0.0 _0 = y == np.array(0)\n",
" 9 1 488000.0 488000.0 0.0 _1 = np.sum(_0)\n",
" 10 1 493000.0 493000.0 0.0 _2 = y == np.array(1)\n",
" 11 1 454000.0 454000.0 0.0 _3 = np.sum(_2)\n",
" 12 1 14000.0 14000.0 0.0 _4 = np.array((_1, _3,)).astype(np.dtype(np.float64))\n",
" 13 1 9000.0 9000.0 0.0 _5 = _4 / np.array(1000000.0)\n",
" 14 1 4000.0 4000.0 0.0 _6 = np.zeros((2, 20,), dtype=np.dtype(np.float64))\n",
" 15 1 98376000.0 1e+08 6.9 _7 = np.sum(X[_0], axis=0)\n",
" 16 1 38374000.0 4e+07 2.7 _8 = _7 / np.array(X[_0].shape[0])\n",
" 17 1 6000.0 6000.0 0.0 _6[0, :] = _8\n",
" 18 1 45697000.0 5e+07 3.2 _9 = np.sum(X[_2], axis=0)\n",
" 19 1 35522000.0 4e+07 2.5 _10 = _9 / np.array(X[_2].shape[0])\n",
" 20 1 6000.0 6000.0 0.0 _6[1, :] = _10\n",
" 21 1 13000.0 13000.0 0.0 _11 = _5 @ _6\n",
" 22 1 33768000.0 3e+07 2.4 _12 = X - _11\n",
" 23 1 18000.0 18000.0 0.0 _13 = np.sqrt(np.array((1.0 / 999998)))\n",
" 24 1 50544000.0 5e+07 3.6 _14 = X[_0] - _6[0, :]\n",
" 25 1 55966000.0 6e+07 4.0 _15 = X[_2] - _6[1, :]\n",
" 26 1 26138000.0 3e+07 1.8 _16 = np.concatenate((_14, _15,), axis=0)\n",
" 27 1 23667000.0 2e+07 1.7 _17 = np.sum(_16, axis=0)\n",
" 28 1 26000.0 26000.0 0.0 _18 = _17 / np.array(_16.shape[0])\n",
" 29 1 45000.0 45000.0 0.0 _19 = np.expand_dims(_18, 0)\n",
" 30 1 33604000.0 3e+07 2.4 _20 = _16 - _19\n",
" 31 1 24774000.0 2e+07 1.7 _21 = np.square(_20)\n",
" 32 1 21671000.0 2e+07 1.5 _22 = np.sum(_21, axis=0)\n",
" 33 1 31000.0 31000.0 0.0 _23 = _22 / np.array(_21.shape[0])\n",
" 34 1 4000.0 4000.0 0.0 _24 = np.sqrt(_23)\n",
" 35 1 7000.0 7000.0 0.0 _25 = _24 == np.array(0)\n",
" 36 1 3000.0 3000.0 0.0 _24[_25] = np.array(1.0)\n",
" 37 1 32910000.0 3e+07 2.3 _26 = _16 / _24\n",
" 38 1 24105000.0 2e+07 1.7 _27 = _13 * _26\n",
" 39 1 814200000.0 8e+08 57.5 _28 = np.linalg.svd(_27, full_matrices=False)\n",
" 40 1 23000.0 23000.0 0.0 _29 = _28[1] > np.array(0.0001)\n",
" 41 1 10000.0 10000.0 0.0 _30 = _29.astype(np.dtype(np.int32))\n",
" 42 1 63000.0 63000.0 0.0 _31 = np.sum(_30)\n",
" 43 1 14000.0 14000.0 0.0 _32 = _28[2][:_31, :] / _24\n",
" 44 1 7000.0 7000.0 0.0 _33 = _32.T / _28[1][:_31]\n",
" 45 1 9000.0 9000.0 0.0 _34 = np.array(1000000) * _5\n",
" 46 1 4000.0 4000.0 0.0 _35 = _34 * np.array(1.0)\n",
" 47 1 3000.0 3000.0 0.0 _36 = np.sqrt(_35)\n",
" 48 1 5000.0 5000.0 0.0 _37 = _6 - _11\n",
" 49 1 4000.0 4000.0 0.0 _38 = _36 * _37.T\n",
" 50 1 11000.0 11000.0 0.0 _39 = _38.T @ _33\n",
" 51 1 70000.0 70000.0 0.0 _40 = np.linalg.svd(_39, full_matrices=False)\n",
" 52 1 6000.0 6000.0 0.0 _41 = np.array(0.0001) * _40[1][0]\n",
" 53 1 3000.0 3000.0 0.0 _42 = _40[1] > _41\n",
" 54 1 4000.0 4000.0 0.0 _43 = _42.astype(np.dtype(np.int32))\n",
" 55 1 18000.0 18000.0 0.0 _44 = np.sum(_43)\n",
" 56 1 8000.0 8000.0 0.0 _45 = _33 @ _40[2].T[:, :_44]\n",
" 57 1 7242000.0 7e+06 0.5 _46 = _12 @ _45\n",
" 58 1 7000.0 7000.0 0.0 return _46[:, :1]"
]
}
],
"source": [
"%lprun -f fn fn(X_np, y_np)"
]
},
{
"cell_type": "markdown",
"id": "7bf27fb6",
"metadata": {},
"source": [
"We see that most of the time is spent in the SVD funciton, which [wouldn't be improved much by numba](https://github.com/numba/numba/issues/2423)\n",
"since it is will call out to LAPACK, just like NumPy. The only savings would come from the other parts of the progarm,\n",
"which can be inlined into\n",
"\n",
"## Conclusion\n",
"\n",
"To recap, in this tutorial we:\n",
"\n",
"1. Tried using a normal scikit-learn LDA function on some test data.\n",
"2. Built up an abstract array and called it with that instead\n",
"3. Optimized it and translated it to work with Numba\n",
"4. Compiled it to a standalone Python funciton, which was optimized with Numba\n",
"5. Verified that this improved our performance with this test data.\n",
"\n",
"The implementation of the Array API provided here is experimental, and not complete, but at least serves to show it is\n",
"possible to build an API like that with `egglog`.\n"
]
}
],
"metadata": {
"file_format": "mystnb",
"kernelspec": {
"display_name": "egglog-python",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"mystnb": {
"execution_mode": "off"
}
},
"nbformat": 4,
"nbformat_minor": 5
}