_NDArray_1 = NDArray.var("X")\n",
"assume_dtype(_NDArray_1, DType.float64)\n",
"assume_shape(_NDArray_1, TupleInt.from_vec(Vec[Int](Int(150), Int(4))))\n",
"assume_isfinite(_NDArray_1)\n",
"_NDArray_2 = NDArray.var("y")\n",
"assume_dtype(_NDArray_2, DType.int64)\n",
"assume_shape(_NDArray_2, TupleInt.from_vec(Vec[Int](Int(150))))\n",
"_TupleValue_1 = TupleValue.from_vec(Vec[Value](Value.int(Int(0)), Value.int(Int(1)), Value.int(Int(2))))\n",
"assume_value_one_of(_NDArray_2, _TupleValue_1)\n",
"_NDArray_3 = zeros(\n",
" TupleInt.from_vec(Vec[Int](NDArray.vector(_TupleValue_1).shape[Int(0)], asarray(_NDArray_1).shape[Int(1)])),\n",
" OptionalDType.some(asarray(_NDArray_1).dtype),\n",
" OptionalDevice.some(asarray(_NDArray_1).device),\n",
")\n",
"_MultiAxisIndexKeyItem_1 = MultiAxisIndexKeyItem.slice(Slice())\n",
"_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(0)), _MultiAxisIndexKeyItem_1)))\n",
"_IndexKey_2 = IndexKey.ndarray(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(0))))\n",
"_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0)))\n",
"_NDArray_3[_IndexKey_1] = mean(asarray(_NDArray_1)[_IndexKey_2], _OptionalIntOrTuple_1)\n",
"_IndexKey_3 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(1)), _MultiAxisIndexKeyItem_1)))\n",
"_IndexKey_4 = IndexKey.ndarray(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(1))))\n",
"_NDArray_3[_IndexKey_3] = mean(asarray(_NDArray_1)[_IndexKey_4], _OptionalIntOrTuple_1)\n",
"_IndexKey_5 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1)))\n",
"_IndexKey_6 = IndexKey.ndarray(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(2))))\n",
"_NDArray_3[_IndexKey_5] = mean(asarray(_NDArray_1)[_IndexKey_6], _OptionalIntOrTuple_1)\n",
"_NDArray_4 = zeros(TupleInt.from_vec(Vec[Int](Int(3), Int(4))), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device))\n",
"_IndexKey_7 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(0)), _MultiAxisIndexKeyItem_1)))\n",
"_NDArray_4[_IndexKey_7] = mean(_NDArray_1[_IndexKey_2], _OptionalIntOrTuple_1)\n",
"_IndexKey_8 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(1)), _MultiAxisIndexKeyItem_1)))\n",
"_NDArray_4[_IndexKey_8] = mean(_NDArray_1[_IndexKey_4], _OptionalIntOrTuple_1)\n",
"_IndexKey_9 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1)))\n",
"_NDArray_4[_IndexKey_9] = mean(_NDArray_1[_IndexKey_6], _OptionalIntOrTuple_1)\n",
"_NDArray_5 = concat(\n",
" TupleNDArray.from_vec(\n",
" Vec[NDArray](\n",
" _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(0))))] - _NDArray_4[_IndexKey_7],\n",
" _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(1))))] - _NDArray_4[_IndexKey_8],\n",
" _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(2))))] - _NDArray_4[_IndexKey_9],\n",
" )\n",
" ),\n",
" OptionalInt.some(Int(0)),\n",
")\n",
"_NDArray_6 = std(_NDArray_5, _OptionalIntOrTuple_1)\n",
"_NDArray_6[IndexKey.ndarray(std(_NDArray_5, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(\n",
" Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1"))))\n",
")\n",
"_TupleNDArray_1 = svd(\n",
" sqrt(\n",
" asarray(\n",
" NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("147"))))),\n",
" OptionalDType.some(DType.float64),\n",
" OptionalBool.none,\n",
" OptionalDevice.some(_NDArray_1.device),\n",
" )\n",
" )\n",
" * (_NDArray_5 / _NDArray_6),\n",
" Boolean(False),\n",
")\n",
"_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))\n",
"_NDArray_7 = asarray(reshape(asarray(_NDArray_2), TupleInt.from_vec(Vec[Int](Int(-1)))))\n",
"_NDArray_8 = unique_values(concat(TupleNDArray.from_vec(Vec[NDArray](unique_values(asarray(_NDArray_7))))))\n",
"_NDArray_9 = std(\n",
" concat(\n",
" TupleNDArray.from_vec(\n",
" Vec[NDArray](\n",
" asarray(_NDArray_1)[IndexKey.ndarray(_NDArray_7 == _NDArray_8[IndexKey.int(Int(0))])] - _NDArray_3[_IndexKey_1],\n",
" asarray(_NDArray_1)[IndexKey.ndarray(_NDArray_7 == _NDArray_8[IndexKey.int(Int(1))])] - _NDArray_3[_IndexKey_3],\n",
" asarray(_NDArray_1)[IndexKey.ndarray(_NDArray_7 == _NDArray_8[IndexKey.int(Int(2))])] - _NDArray_3[_IndexKey_5],\n",
" )\n",
" ),\n",
" OptionalInt.some(Int(0)),\n",
" ),\n",
" _OptionalIntOrTuple_1,\n",
")\n",
"_NDArray_10 = copy(_NDArray_9)\n",
"_NDArray_10[IndexKey.ndarray(_NDArray_9 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))\n",
"_NDArray_11 = astype(unique_counts(_NDArray_2)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("150"), BigInt.from_string("1")))))\n",
"_TupleNDArray_2 = svd(\n",
" (\n",
" sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_11) * NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("2"))))))\n",
" * (_NDArray_4 - (_NDArray_11 @ _NDArray_4)).T\n",
" ).T\n",
" @ (\n",
" (\n",
" _TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))]\n",
" / _NDArray_6\n",
" ).T\n",
" / _TupleNDArray_1[Int(1)][IndexKey.slice(_Slice_1)]\n",
" ),\n",
" Boolean(False),\n",
")\n",
"(\n",
" (asarray(_NDArray_1) - ((astype(unique_counts(_NDArray_2)[Int(1)], asarray(_NDArray_1).dtype) / NDArray.scalar(Value.float(Float(150.0)))) @ _NDArray_3))\n",
" @ (\n",
" (\n",
" (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))] / _NDArray_10).T\n",
" / _TupleNDArray_1[Int(1)][IndexKey.slice(_Slice_1)]\n",
" )\n",
" @ _TupleNDArray_2[Int(2)].T[\n",
" IndexKey.multi_axis(\n",
" MultiAxisIndexKey.from_vec(\n",
" Vec(\n",
" _MultiAxisIndexKeyItem_1,\n",
" MultiAxisIndexKeyItem.slice(\n",
" Slice(\n",
" OptionalInt.none,\n",
" OptionalInt.some(\n",
" sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32))\n",
" .to_value()\n",
" .to_int\n",
" ),\n",
" )\n",
" ),\n",
" )\n",
" )\n",
" )\n",
" ]\n",
" )\n",
")[IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(_MultiAxisIndexKeyItem_1, MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(Int(2)))))))]\n",
" \n"
],
"text/latex": [
"\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}1} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{var}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{X}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}dtype}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}shape}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{,} \\PY{n}{TupleInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{]}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{150}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{4}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}isfinite}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{var}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{y}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}dtype}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int64}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}shape}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{TupleInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{]}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{150}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}TupleValue\\PYZus{}1} \\PY{o}{=} \\PY{n}{TupleValue}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{Value}\\PY{p}{]}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}value\\PYZus{}one\\PYZus{}of}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{\\PYZus{}TupleValue\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{=} \\PY{n}{zeros}\\PY{p}{(}\n",
" \\PY{n}{TupleInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{]}\\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{vector}\\PY{p}{(}\\PY{n}{\\PYZus{}TupleValue\\PYZus{}1}\\PY{p}{)}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{,} \\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n}{OptionalDType}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{o}{.}\\PY{n}{dtype}\\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n}{OptionalDevice}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{o}{.}\\PY{n}{device}\\PY{p}{)}\\PY{p}{,}\n",
"\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}MultiAxisIndexKeyItem\\PYZus{}1} \\PY{o}{=} \\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{Slice}\\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{\\PYZus{}MultiAxisIndexKeyItem\\PYZus{}1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{ndarray}\\PY{p}{(}\\PY{n}{unique\\PYZus{}inverse}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{)}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1} \\PY{o}{=} \\PY{n}{OptionalIntOrTuple}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{IntOrTuple}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1}\\PY{p}{]} \\PY{o}{=} \\PY{n}{mean}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2}\\PY{p}{]}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}IndexKey\\PYZus{}3} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{\\PYZus{}MultiAxisIndexKeyItem\\PYZus{}1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}IndexKey\\PYZus{}4} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{ndarray}\\PY{p}{(}\\PY{n}{unique\\PYZus{}inverse}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{)}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}3}\\PY{p}{]} \\PY{o}{=} \\PY{n}{mean}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}4}\\PY{p}{]}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}IndexKey\\PYZus{}5} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{\\PYZus{}MultiAxisIndexKeyItem\\PYZus{}1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}IndexKey\\PYZus{}6} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{ndarray}\\PY{p}{(}\\PY{n}{unique\\PYZus{}inverse}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{)}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}5}\\PY{p}{]} \\PY{o}{=} \\PY{n}{mean}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}6}\\PY{p}{]}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}4} \\PY{o}{=} \\PY{n}{zeros}\\PY{p}{(}\\PY{n}{TupleInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{]}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{3}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{4}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{OptionalDType}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{)}\\PY{p}{,} \\PY{n}{OptionalDevice}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{o}{.}\\PY{n}{device}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}IndexKey\\PYZus{}7} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{MultiAxisIndexKeyItem}\\PY{p}{]}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{\\PYZus{}MultiAxisIndexKeyItem\\PYZus{}1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}7}\\PY{p}{]} \\PY{o}{=} \\PY{n}{mean}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2}\\PY{p}{]}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}IndexKey\\PYZus{}8} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{MultiAxisIndexKeyItem}\\PY{p}{]}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{\\PYZus{}MultiAxisIndexKeyItem\\PYZus{}1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}8}\\PY{p}{]} \\PY{o}{=} \\PY{n}{mean}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}4}\\PY{p}{]}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}IndexKey\\PYZus{}9} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{MultiAxisIndexKeyItem}\\PY{p}{]}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{\\PYZus{}MultiAxisIndexKeyItem\\PYZus{}1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}9}\\PY{p}{]} \\PY{o}{=} \\PY{n}{mean}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}6}\\PY{p}{]}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}5} \\PY{o}{=} \\PY{n}{concat}\\PY{p}{(}\n",
" \\PY{n}{TupleNDArray}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\n",
" \\PY{n}{Vec}\\PY{p}{[}\\PY{n}{NDArray}\\PY{p}{]}\\PY{p}{(}\n",
" \\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{ndarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}7}\\PY{p}{]}\\PY{p}{,}\n",
" \\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{ndarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}8}\\PY{p}{]}\\PY{p}{,}\n",
" \\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{ndarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}9}\\PY{p}{]}\\PY{p}{,}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,}\n",
"\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}6} \\PY{o}{=} \\PY{n}{std}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{ndarray}\\PY{p}{(}\\PY{n}{std}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\n",
" \\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{o}{.}\\PY{n}{rational}\\PY{p}{(}\\PY{n}{BigRat}\\PY{p}{(}\\PY{n}{BigInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}string}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{1}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{BigInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}string}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{1}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1} \\PY{o}{=} \\PY{n}{svd}\\PY{p}{(}\n",
" \\PY{n}{sqrt}\\PY{p}{(}\n",
" \\PY{n}{asarray}\\PY{p}{(}\n",
" \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{o}{.}\\PY{n}{rational}\\PY{p}{(}\\PY{n}{BigRat}\\PY{p}{(}\\PY{n}{BigInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}string}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{1}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{BigInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}string}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{147}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n}{OptionalDType}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n}{OptionalBool}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,}\n",
" \\PY{n}{OptionalDevice}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{o}{.}\\PY{n}{device}\\PY{p}{)}\\PY{p}{,}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\n",
" \\PY{o}{*} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5} \\PY{o}{/} \\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n}{Boolean}\\PY{p}{(}\\PY{k+kc}{False}\\PY{p}{)}\\PY{p}{,}\n",
"\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}Slice\\PYZus{}1} \\PY{o}{=} \\PY{n}{Slice}\\PY{p}{(}\\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,} \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{astype}\\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZgt{}} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{0.0001}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int32}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}int}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}7} \\PY{o}{=} \\PY{n}{asarray}\\PY{p}{(}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{)}\\PY{p}{,} \\PY{n}{TupleInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{]}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}8} \\PY{o}{=} \\PY{n}{unique\\PYZus{}values}\\PY{p}{(}\\PY{n}{concat}\\PY{p}{(}\\PY{n}{TupleNDArray}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{NDArray}\\PY{p}{]}\\PY{p}{(}\\PY{n}{unique\\PYZus{}values}\\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}9} \\PY{o}{=} \\PY{n}{std}\\PY{p}{(}\n",
" \\PY{n}{concat}\\PY{p}{(}\n",
" \\PY{n}{TupleNDArray}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\n",
" \\PY{n}{Vec}\\PY{p}{[}\\PY{n}{NDArray}\\PY{p}{]}\\PY{p}{(}\n",
" \\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{ndarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7} \\PY{o}{==} \\PY{n}{\\PYZus{}NDArray\\PYZus{}8}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1}\\PY{p}{]}\\PY{p}{,}\n",
" \\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{ndarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7} \\PY{o}{==} \\PY{n}{\\PYZus{}NDArray\\PYZus{}8}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}3}\\PY{p}{]}\\PY{p}{,}\n",
" \\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{ndarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7} \\PY{o}{==} \\PY{n}{\\PYZus{}NDArray\\PYZus{}8}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}5}\\PY{p}{]}\\PY{p}{,}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,}\n",
" \\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{,}\n",
"\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}10} \\PY{o}{=} \\PY{n}{copy}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}9}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}10}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{ndarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}9} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{1.0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}11} \\PY{o}{=} \\PY{n}{astype}\\PY{p}{(}\\PY{n}{unique\\PYZus{}counts}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{)}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{o}{.}\\PY{n}{rational}\\PY{p}{(}\\PY{n}{BigRat}\\PY{p}{(}\\PY{n}{BigInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}string}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{150}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{BigInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}string}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{1}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2} \\PY{o}{=} \\PY{n}{svd}\\PY{p}{(}\n",
" \\PY{p}{(}\n",
" \\PY{n}{sqrt}\\PY{p}{(}\\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{150}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{n}{\\PYZus{}NDArray\\PYZus{}11}\\PY{p}{)} \\PY{o}{*} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{o}{.}\\PY{n}{rational}\\PY{p}{(}\\PY{n}{BigRat}\\PY{p}{(}\\PY{n}{BigInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}string}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{1}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{BigInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}string}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{2}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
" \\PY{o}{*} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}4} \\PY{o}{\\PYZhy{}} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}11} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{T}\n",
" \\PY{p}{)}\\PY{o}{.}\\PY{n}{T}\n",
" \\PY{o}{@} \\PY{p}{(}\n",
" \\PY{p}{(}\n",
" \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{MultiAxisIndexKeyItem}\\PY{p}{]}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{\\PYZus{}Slice\\PYZus{}1}\\PY{p}{)}\\PY{p}{,} \\PY{n}{\\PYZus{}MultiAxisIndexKeyItem\\PYZus{}1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\n",
" \\PY{o}{/} \\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\n",
" \\PY{p}{)}\\PY{o}{.}\\PY{n}{T}\n",
" \\PY{o}{/} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{\\PYZus{}Slice\\PYZus{}1}\\PY{p}{)}\\PY{p}{]}\n",
" \\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n}{Boolean}\\PY{p}{(}\\PY{k+kc}{False}\\PY{p}{)}\\PY{p}{,}\n",
"\\PY{p}{)}\n",
"\\PY{p}{(}\n",
" \\PY{p}{(}\\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)} \\PY{o}{\\PYZhy{}} \\PY{p}{(}\\PY{p}{(}\\PY{n}{astype}\\PY{p}{(}\\PY{n}{unique\\PYZus{}counts}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{)}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{,} \\PY{n}{asarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\\PY{o}{.}\\PY{n}{dtype}\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{150.0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{)}\\PY{p}{)}\n",
" \\PY{o}{@} \\PY{p}{(}\n",
" \\PY{p}{(}\n",
" \\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{\\PYZus{}Slice\\PYZus{}1}\\PY{p}{)}\\PY{p}{,} \\PY{n}{\\PYZus{}MultiAxisIndexKeyItem\\PYZus{}1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]} \\PY{o}{/} \\PY{n}{\\PYZus{}NDArray\\PYZus{}10}\\PY{p}{)}\\PY{o}{.}\\PY{n}{T}\n",
" \\PY{o}{/} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{\\PYZus{}Slice\\PYZus{}1}\\PY{p}{)}\\PY{p}{]}\n",
" \\PY{p}{)}\n",
" \\PY{o}{@} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{]}\\PY{o}{.}\\PY{n}{T}\\PY{p}{[}\n",
" \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\n",
" \\PY{n}{MultiAxisIndexKey}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\n",
" \\PY{n}{Vec}\\PY{p}{(}\n",
" \\PY{n}{\\PYZus{}MultiAxisIndexKeyItem\\PYZus{}1}\\PY{p}{,}\n",
" \\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\n",
" \\PY{n}{Slice}\\PY{p}{(}\n",
" \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,}\n",
" \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\n",
" \\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{astype}\\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZgt{}} \\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{0.0001}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int32}\\PY{p}{)}\\PY{p}{)}\n",
" \\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\n",
" \\PY{o}{.}\\PY{n}{to\\PYZus{}int}\n",
" \\PY{p}{)}\\PY{p}{,}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\\PY{p}{,}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\n",
" \\PY{p}{]}\n",
" \\PY{p}{)}\n",
"\\PY{p}{)}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{(}\\PY{n}{\\PYZus{}MultiAxisIndexKeyItem\\PYZus{}1}\\PY{p}{,} \\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{Slice}\\PY{p}{(}\\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,} \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\n",
"\\end{Verbatim}\n"
],
"text/plain": [
"_NDArray_1 = NDArray.var(\"X\")\n",
"assume_dtype(_NDArray_1, DType.float64)\n",
"assume_shape(_NDArray_1, TupleInt.from_vec(Vec[Int](Int(150), Int(4))))\n",
"assume_isfinite(_NDArray_1)\n",
"_NDArray_2 = NDArray.var(\"y\")\n",
"assume_dtype(_NDArray_2, DType.int64)\n",
"assume_shape(_NDArray_2, TupleInt.from_vec(Vec[Int](Int(150))))\n",
"_TupleValue_1 = TupleValue.from_vec(Vec[Value](Value.int(Int(0)), Value.int(Int(1)), Value.int(Int(2))))\n",
"assume_value_one_of(_NDArray_2, _TupleValue_1)\n",
"_NDArray_3 = zeros(\n",
" TupleInt.from_vec(Vec[Int](NDArray.vector(_TupleValue_1).shape[Int(0)], asarray(_NDArray_1).shape[Int(1)])),\n",
" OptionalDType.some(asarray(_NDArray_1).dtype),\n",
" OptionalDevice.some(asarray(_NDArray_1).device),\n",
")\n",
"_MultiAxisIndexKeyItem_1 = MultiAxisIndexKeyItem.slice(Slice())\n",
"_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(0)), _MultiAxisIndexKeyItem_1)))\n",
"_IndexKey_2 = IndexKey.ndarray(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(0))))\n",
"_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0)))\n",
"_NDArray_3[_IndexKey_1] = mean(asarray(_NDArray_1)[_IndexKey_2], _OptionalIntOrTuple_1)\n",
"_IndexKey_3 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(1)), _MultiAxisIndexKeyItem_1)))\n",
"_IndexKey_4 = IndexKey.ndarray(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(1))))\n",
"_NDArray_3[_IndexKey_3] = mean(asarray(_NDArray_1)[_IndexKey_4], _OptionalIntOrTuple_1)\n",
"_IndexKey_5 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1)))\n",
"_IndexKey_6 = IndexKey.ndarray(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(2))))\n",
"_NDArray_3[_IndexKey_5] = mean(asarray(_NDArray_1)[_IndexKey_6], _OptionalIntOrTuple_1)\n",
"_NDArray_4 = zeros(TupleInt.from_vec(Vec[Int](Int(3), Int(4))), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device))\n",
"_IndexKey_7 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(0)), _MultiAxisIndexKeyItem_1)))\n",
"_NDArray_4[_IndexKey_7] = mean(_NDArray_1[_IndexKey_2], _OptionalIntOrTuple_1)\n",
"_IndexKey_8 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(1)), _MultiAxisIndexKeyItem_1)))\n",
"_NDArray_4[_IndexKey_8] = mean(_NDArray_1[_IndexKey_4], _OptionalIntOrTuple_1)\n",
"_IndexKey_9 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1)))\n",
"_NDArray_4[_IndexKey_9] = mean(_NDArray_1[_IndexKey_6], _OptionalIntOrTuple_1)\n",
"_NDArray_5 = concat(\n",
" TupleNDArray.from_vec(\n",
" Vec[NDArray](\n",
" _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(0))))] - _NDArray_4[_IndexKey_7],\n",
" _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(1))))] - _NDArray_4[_IndexKey_8],\n",
" _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(2))))] - _NDArray_4[_IndexKey_9],\n",
" )\n",
" ),\n",
" OptionalInt.some(Int(0)),\n",
")\n",
"_NDArray_6 = std(_NDArray_5, _OptionalIntOrTuple_1)\n",
"_NDArray_6[IndexKey.ndarray(std(_NDArray_5, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(\n",
" Value.float(Float.rational(BigRat(BigInt.from_string(\"1\"), BigInt.from_string(\"1\"))))\n",
")\n",
"_TupleNDArray_1 = svd(\n",
" sqrt(\n",
" asarray(\n",
" NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string(\"1\"), BigInt.from_string(\"147\"))))),\n",
" OptionalDType.some(DType.float64),\n",
" OptionalBool.none,\n",
" OptionalDevice.some(_NDArray_1.device),\n",
" )\n",
" )\n",
" * (_NDArray_5 / _NDArray_6),\n",
" Boolean(False),\n",
")\n",
"_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))\n",
"_NDArray_7 = asarray(reshape(asarray(_NDArray_2), TupleInt.from_vec(Vec[Int](Int(-1)))))\n",
"_NDArray_8 = unique_values(concat(TupleNDArray.from_vec(Vec[NDArray](unique_values(asarray(_NDArray_7))))))\n",
"_NDArray_9 = std(\n",
" concat(\n",
" TupleNDArray.from_vec(\n",
" Vec[NDArray](\n",
" asarray(_NDArray_1)[IndexKey.ndarray(_NDArray_7 == _NDArray_8[IndexKey.int(Int(0))])] - _NDArray_3[_IndexKey_1],\n",
" asarray(_NDArray_1)[IndexKey.ndarray(_NDArray_7 == _NDArray_8[IndexKey.int(Int(1))])] - _NDArray_3[_IndexKey_3],\n",
" asarray(_NDArray_1)[IndexKey.ndarray(_NDArray_7 == _NDArray_8[IndexKey.int(Int(2))])] - _NDArray_3[_IndexKey_5],\n",
" )\n",
" ),\n",
" OptionalInt.some(Int(0)),\n",
" ),\n",
" _OptionalIntOrTuple_1,\n",
")\n",
"_NDArray_10 = copy(_NDArray_9)\n",
"_NDArray_10[IndexKey.ndarray(_NDArray_9 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))\n",
"_NDArray_11 = astype(unique_counts(_NDArray_2)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string(\"150\"), BigInt.from_string(\"1\")))))\n",
"_TupleNDArray_2 = svd(\n",
" (\n",
" sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_11) * NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string(\"1\"), BigInt.from_string(\"2\"))))))\n",
" * (_NDArray_4 - (_NDArray_11 @ _NDArray_4)).T\n",
" ).T\n",
" @ (\n",
" (\n",
" _TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))]\n",
" / _NDArray_6\n",
" ).T\n",
" / _TupleNDArray_1[Int(1)][IndexKey.slice(_Slice_1)]\n",
" ),\n",
" Boolean(False),\n",
")\n",
"(\n",
" (asarray(_NDArray_1) - ((astype(unique_counts(_NDArray_2)[Int(1)], asarray(_NDArray_1).dtype) / NDArray.scalar(Value.float(Float(150.0)))) @ _NDArray_3))\n",
" @ (\n",
" (\n",
" (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))] / _NDArray_10).T\n",
" / _TupleNDArray_1[Int(1)][IndexKey.slice(_Slice_1)]\n",
" )\n",
" @ _TupleNDArray_2[Int(2)].T[\n",
" IndexKey.multi_axis(\n",
" MultiAxisIndexKey.from_vec(\n",
" Vec(\n",
" _MultiAxisIndexKeyItem_1,\n",
" MultiAxisIndexKeyItem.slice(\n",
" Slice(\n",
" OptionalInt.none,\n",
" OptionalInt.some(\n",
" sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32))\n",
" .to_value()\n",
" .to_int\n",
" ),\n",
" )\n",
" ),\n",
" )\n",
" )\n",
" )\n",
" ]\n",
" )\n",
")[IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(_MultiAxisIndexKeyItem_1, MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(Int(2)))))))]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"res"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to run this, scikit-learn treated these objects as \"array like\", meaning they conformed to [the Array API](https://data-apis.org/array-api/latest/).\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conversions: From Python to egglog\n",
"\n",
"_Use conversions if you want your egglog API to be called with existing Python objects, without manually upcasting them_\n",
"\n",
"_We will see this in our example:_\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"class LinearDiscriminantAnalysis:\n",
" ...\n",
" def fit(self, X, y):\n",
" ...\n",
" _, cnts = xp.unique_counts(y) # non-negative ints\n",
" self.priors_ = xp.astype(cnts, X.dtype) / float(y.shape[0])\n",
"```\n",
"\n",
"Ends up resulting in this expression:\n",
"\n",
"```python\n",
"astype(unique_counts(_NDArray_3)[Int(1)], asarray(_NDArray_1).dtype) / NDArray.scalar(Value.float(Float(1000000.0)))\n",
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"How?\n",
"\n",
"We have exposed a global conversion logic, where if you pass an arg to egglog and it isn't the correct type, it will try to upcast the arg to the required egglog type.\n",
"\n",
"There is a graph of all conversions and it will find the shortest path from the input to the desired type and automatically upcast to that.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example\n",
"\n",
"For example, in indexing, if we do a slice (i.e. `1:10:2`), we convert this to our custom egglog `Slice` expressions:\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class Slice(Expr):\n",
" def __init__(\n",
" self,\n",
" start: OptionalInt = OptionalInt.none,\n",
" stop: OptionalInt = OptionalInt.none,\n",
" step: OptionalInt = OptionalInt.none,\n",
" ) -> None: ...\n",
"\n",
"\n",
"converter(\n",
" slice,\n",
" Slice,\n",
" lambda x: Slice(\n",
" convert(x.start, OptionalInt),\n",
" convert(x.stop, OptionalInt),\n",
" convert(x.step, OptionalInt),\n",
" ),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class A(Expr):\n",
" def __init__(self) -> None: ...\n",
" def __getitem__(self, s: Slice) -> Int: ..."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"_NDArray_1 = NDArray.var("X")\n",
"assume_dtype(_NDArray_1, DType.float64)\n",
"assume_shape(_NDArray_1, TupleInt.from_vec(Vec[Int](Int(150), Int(4))))\n",
"assume_isfinite(_NDArray_1)\n",
"_NDArray_2 = NDArray.var("y")\n",
"assume_dtype(_NDArray_2, DType.int64)\n",
"assume_shape(_NDArray_2, TupleInt.from_vec(Vec[Int](Int(150))))\n",
"assume_value_one_of(_NDArray_2, TupleValue.from_vec(Vec[Value](Value.int(Int(0)), Value.int(Int(1)), Value.int(Int(2)))))\n",
"_NDArray_3 = astype(\n",
" NDArray.vector(\n",
" TupleValue.from_vec(\n",
" Vec[Value](\n",
" sum(_NDArray_2 == NDArray.scalar(Value.int(Int(0)))).to_value(),\n",
" sum(_NDArray_2 == NDArray.scalar(Value.int(Int(1)))).to_value(),\n",
" sum(_NDArray_2 == NDArray.scalar(Value.int(Int(2)))).to_value(),\n",
" )\n",
" )\n",
" ),\n",
" DType.float64,\n",
") / NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("150"), BigInt.from_string("1")))))\n",
"_NDArray_4 = zeros(TupleInt.from_vec(Vec[Int](Int(3), Int(4))), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device))\n",
"_MultiAxisIndexKeyItem_1 = MultiAxisIndexKeyItem.slice(Slice())\n",
"_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(0)), _MultiAxisIndexKeyItem_1)))\n",
"_NDArray_5 = _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(0))))]\n",
"_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0)))\n",
"_NDArray_4[_IndexKey_1] = sum(_NDArray_5, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_5.shape[Int(0)]))\n",
"_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(1)), _MultiAxisIndexKeyItem_1)))\n",
"_NDArray_6 = _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(1))))]\n",
"_NDArray_4[_IndexKey_2] = sum(_NDArray_6, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_6.shape[Int(0)]))\n",
"_IndexKey_3 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1)))\n",
"_NDArray_7 = _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(2))))]\n",
"_NDArray_4[_IndexKey_3] = sum(_NDArray_7, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_7.shape[Int(0)]))\n",
"_NDArray_8 = concat(\n",
" TupleNDArray.from_vec(Vec[NDArray](_NDArray_5 - _NDArray_4[_IndexKey_1], _NDArray_6 - _NDArray_4[_IndexKey_2], _NDArray_7 - _NDArray_4[_IndexKey_3])), OptionalInt.some(Int(0))\n",
")\n",
"_NDArray_9 = square(_NDArray_8 - expand_dims(sum(_NDArray_8, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_8.shape[Int(0)]))))\n",
"_NDArray_10 = sqrt(sum(_NDArray_9, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_9.shape[Int(0)])))\n",
"_NDArray_11 = copy(_NDArray_10)\n",
"_NDArray_11[IndexKey.ndarray(_NDArray_10 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(\n",
" Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1"))))\n",
")\n",
"_TupleNDArray_1 = svd(\n",
" sqrt(\n",
" asarray(\n",
" NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("147"))))),\n",
" OptionalDType.some(DType.float64),\n",
" OptionalBool.none,\n",
" OptionalDevice.some(_NDArray_1.device),\n",
" )\n",
" )\n",
" * (_NDArray_8 / _NDArray_11),\n",
" Boolean(False),\n",
")\n",
"_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))\n",
"_NDArray_12 = (\n",
" _TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))]\n",
" / _NDArray_11\n",
").T / _TupleNDArray_1[Int(1)][IndexKey.slice(_Slice_1)]\n",
"_TupleNDArray_2 = svd(\n",
" (\n",
" sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_3) * NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("2"))))))\n",
" * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T\n",
" ).T\n",
" @ _NDArray_12,\n",
" Boolean(False),\n",
")\n",
"(\n",
" (_NDArray_1 - (_NDArray_3 @ _NDArray_4))\n",
" @ (\n",
" _NDArray_12\n",
" @ _TupleNDArray_2[Int(2)].T[\n",
" IndexKey.multi_axis(\n",
" MultiAxisIndexKey.from_vec(\n",
" Vec[MultiAxisIndexKeyItem](\n",
" _MultiAxisIndexKeyItem_1,\n",
" MultiAxisIndexKeyItem.slice(\n",
" Slice(\n",
" OptionalInt.none,\n",
" OptionalInt.some(\n",
" sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32))\n",
" .to_value()\n",
" .to_int\n",
" ),\n",
" )\n",
" ),\n",
" )\n",
" )\n",
" )\n",
" ]\n",
" )\n",
")[\n",
" IndexKey.multi_axis(\n",
" MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](_MultiAxisIndexKeyItem_1, MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(Int(2))))))\n",
" )\n",
"]\n",
" \n"
],
"text/latex": [
"\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}1} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{var}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{X}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}dtype}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}shape}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{,} \\PY{n}{TupleInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{]}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{150}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{4}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}isfinite}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{var}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{y}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}dtype}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int64}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}shape}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{TupleInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{]}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{150}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{assume\\PYZus{}value\\PYZus{}one\\PYZus{}of}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2}\\PY{p}{,} \\PY{n}{TupleValue}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{Value}\\PY{p}{]}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{=} \\PY{n}{astype}\\PY{p}{(}\n",
" \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{vector}\\PY{p}{(}\n",
" \\PY{n}{TupleValue}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\n",
" \\PY{n}{Vec}\\PY{p}{[}\\PY{n}{Value}\\PY{p}{]}\\PY{p}{(}\n",
" \\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\\PY{p}{,}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{,}\n",
"\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{o}{.}\\PY{n}{rational}\\PY{p}{(}\\PY{n}{BigRat}\\PY{p}{(}\\PY{n}{BigInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}string}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{150}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{BigInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}string}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{1}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}4} \\PY{o}{=} \\PY{n}{zeros}\\PY{p}{(}\\PY{n}{TupleInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{]}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{3}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{4}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{OptionalDType}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{)}\\PY{p}{,} \\PY{n}{OptionalDevice}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{o}{.}\\PY{n}{device}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}MultiAxisIndexKeyItem\\PYZus{}1} \\PY{o}{=} \\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{Slice}\\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{MultiAxisIndexKeyItem}\\PY{p}{]}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{\\PYZus{}MultiAxisIndexKeyItem\\PYZus{}1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}5} \\PY{o}{=} \\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{ndarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\n",
"\\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1} \\PY{o}{=} \\PY{n}{OptionalIntOrTuple}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{IntOrTuple}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1}\\PY{p}{]} \\PY{o}{=} \\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{MultiAxisIndexKeyItem}\\PY{p}{]}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{\\PYZus{}MultiAxisIndexKeyItem\\PYZus{}1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}6} \\PY{o}{=} \\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{ndarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2}\\PY{p}{]} \\PY{o}{=} \\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}6}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}IndexKey\\PYZus{}3} \\PY{o}{=} \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{MultiAxisIndexKeyItem}\\PY{p}{]}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{\\PYZus{}MultiAxisIndexKeyItem\\PYZus{}1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}7} \\PY{o}{=} \\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{ndarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}2} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}3}\\PY{p}{]} \\PY{o}{=} \\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}7}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}8} \\PY{o}{=} \\PY{n}{concat}\\PY{p}{(}\n",
" \\PY{n}{TupleNDArray}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{NDArray}\\PY{p}{]}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}5} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}1}\\PY{p}{]}\\PY{p}{,} \\PY{n}{\\PYZus{}NDArray\\PYZus{}6} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}2}\\PY{p}{]}\\PY{p}{,} \\PY{n}{\\PYZus{}NDArray\\PYZus{}7} \\PY{o}{\\PYZhy{}} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{[}\\PY{n}{\\PYZus{}IndexKey\\PYZus{}3}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}9} \\PY{o}{=} \\PY{n}{square}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}8} \\PY{o}{\\PYZhy{}} \\PY{n}{expand\\PYZus{}dims}\\PY{p}{(}\\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}8}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}8}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}10} \\PY{o}{=} \\PY{n}{sqrt}\\PY{p}{(}\\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}9}\\PY{p}{,} \\PY{n}{\\PYZus{}OptionalIntOrTuple\\PYZus{}1}\\PY{p}{)} \\PY{o}{/} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}9}\\PY{o}{.}\\PY{n}{shape}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}11} \\PY{o}{=} \\PY{n}{copy}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}10}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}11}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{ndarray}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}10} \\PY{o}{==} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]} \\PY{o}{=} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\n",
" \\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{o}{.}\\PY{n}{rational}\\PY{p}{(}\\PY{n}{BigRat}\\PY{p}{(}\\PY{n}{BigInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}string}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{1}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{BigInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}string}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{1}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1} \\PY{o}{=} \\PY{n}{svd}\\PY{p}{(}\n",
" \\PY{n}{sqrt}\\PY{p}{(}\n",
" \\PY{n}{asarray}\\PY{p}{(}\n",
" \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{o}{.}\\PY{n}{rational}\\PY{p}{(}\\PY{n}{BigRat}\\PY{p}{(}\\PY{n}{BigInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}string}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{1}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{BigInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}string}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{147}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n}{OptionalDType}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{DType}\\PY{o}{.}\\PY{n}{float64}\\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n}{OptionalBool}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,}\n",
" \\PY{n}{OptionalDevice}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1}\\PY{o}{.}\\PY{n}{device}\\PY{p}{)}\\PY{p}{,}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\n",
" \\PY{o}{*} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}8} \\PY{o}{/} \\PY{n}{\\PYZus{}NDArray\\PYZus{}11}\\PY{p}{)}\\PY{p}{,}\n",
" \\PY{n}{Boolean}\\PY{p}{(}\\PY{k+kc}{False}\\PY{p}{)}\\PY{p}{,}\n",
"\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}Slice\\PYZus{}1} \\PY{o}{=} \\PY{n}{Slice}\\PY{p}{(}\\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,} \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{astype}\\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZgt{}} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{0.0001}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int32}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\\PY{o}{.}\\PY{n}{to\\PYZus{}int}\\PY{p}{)}\\PY{p}{)}\n",
"\\PY{n}{\\PYZus{}NDArray\\PYZus{}12} \\PY{o}{=} \\PY{p}{(}\n",
" \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\\PY{n}{MultiAxisIndexKey}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{MultiAxisIndexKeyItem}\\PY{p}{]}\\PY{p}{(}\\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{\\PYZus{}Slice\\PYZus{}1}\\PY{p}{)}\\PY{p}{,} \\PY{n}{\\PYZus{}MultiAxisIndexKeyItem\\PYZus{}1}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\n",
" \\PY{o}{/} \\PY{n}{\\PYZus{}NDArray\\PYZus{}11}\n",
"\\PY{p}{)}\\PY{o}{.}\\PY{n}{T} \\PY{o}{/} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}1}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{\\PYZus{}Slice\\PYZus{}1}\\PY{p}{)}\\PY{p}{]}\n",
"\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2} \\PY{o}{=} \\PY{n}{svd}\\PY{p}{(}\n",
" \\PY{p}{(}\n",
" \\PY{n}{sqrt}\\PY{p}{(}\\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{150}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{n}{\\PYZus{}NDArray\\PYZus{}3}\\PY{p}{)} \\PY{o}{*} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{o}{.}\\PY{n}{rational}\\PY{p}{(}\\PY{n}{BigRat}\\PY{p}{(}\\PY{n}{BigInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}string}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{1}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{BigInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}string}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{2}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
" \\PY{o}{*} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}4} \\PY{o}{\\PYZhy{}} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{)}\\PY{p}{)}\\PY{o}{.}\\PY{n}{T}\n",
" \\PY{p}{)}\\PY{o}{.}\\PY{n}{T}\n",
" \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}12}\\PY{p}{,}\n",
" \\PY{n}{Boolean}\\PY{p}{(}\\PY{k+kc}{False}\\PY{p}{)}\\PY{p}{,}\n",
"\\PY{p}{)}\n",
"\\PY{p}{(}\n",
" \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}1} \\PY{o}{\\PYZhy{}} \\PY{p}{(}\\PY{n}{\\PYZus{}NDArray\\PYZus{}3} \\PY{o}{@} \\PY{n}{\\PYZus{}NDArray\\PYZus{}4}\\PY{p}{)}\\PY{p}{)}\n",
" \\PY{o}{@} \\PY{p}{(}\n",
" \\PY{n}{\\PYZus{}NDArray\\PYZus{}12}\n",
" \\PY{o}{@} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{]}\\PY{o}{.}\\PY{n}{T}\\PY{p}{[}\n",
" \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\n",
" \\PY{n}{MultiAxisIndexKey}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\n",
" \\PY{n}{Vec}\\PY{p}{[}\\PY{n}{MultiAxisIndexKeyItem}\\PY{p}{]}\\PY{p}{(}\n",
" \\PY{n}{\\PYZus{}MultiAxisIndexKeyItem\\PYZus{}1}\\PY{p}{,}\n",
" \\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\n",
" \\PY{n}{Slice}\\PY{p}{(}\n",
" \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,}\n",
" \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\n",
" \\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{astype}\\PY{p}{(}\\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]} \\PY{o}{\\PYZgt{}} \\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{scalar}\\PY{p}{(}\\PY{n}{Value}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{n}{Float}\\PY{p}{(}\\PY{l+m+mf}{0.0001}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{n}{\\PYZus{}TupleNDArray\\PYZus{}2}\\PY{p}{[}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{)}\\PY{p}{]}\\PY{p}{[}\\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{int}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)}\\PY{p}{]}\\PY{p}{)}\\PY{p}{,} \\PY{n}{DType}\\PY{o}{.}\\PY{n}{int32}\\PY{p}{)}\\PY{p}{)}\n",
" \\PY{o}{.}\\PY{n}{to\\PYZus{}value}\\PY{p}{(}\\PY{p}{)}\n",
" \\PY{o}{.}\\PY{n}{to\\PYZus{}int}\n",
" \\PY{p}{)}\\PY{p}{,}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\\PY{p}{,}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\n",
" \\PY{p}{)}\n",
" \\PY{p}{]}\n",
" \\PY{p}{)}\n",
"\\PY{p}{)}\\PY{p}{[}\n",
" \\PY{n}{IndexKey}\\PY{o}{.}\\PY{n}{multi\\PYZus{}axis}\\PY{p}{(}\n",
" \\PY{n}{MultiAxisIndexKey}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{[}\\PY{n}{MultiAxisIndexKeyItem}\\PY{p}{]}\\PY{p}{(}\\PY{n}{\\PYZus{}MultiAxisIndexKeyItem\\PYZus{}1}\\PY{p}{,} \\PY{n}{MultiAxisIndexKeyItem}\\PY{o}{.}\\PY{n}{slice}\\PY{p}{(}\\PY{n}{Slice}\\PY{p}{(}\\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{none}\\PY{p}{,} \\PY{n}{OptionalInt}\\PY{o}{.}\\PY{n}{some}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n",
" \\PY{p}{)}\n",
"\\PY{p}{]}\n",
"\\end{Verbatim}\n"
],
"text/plain": [
"_NDArray_1 = NDArray.var(\"X\")\n",
"assume_dtype(_NDArray_1, DType.float64)\n",
"assume_shape(_NDArray_1, TupleInt.from_vec(Vec[Int](Int(150), Int(4))))\n",
"assume_isfinite(_NDArray_1)\n",
"_NDArray_2 = NDArray.var(\"y\")\n",
"assume_dtype(_NDArray_2, DType.int64)\n",
"assume_shape(_NDArray_2, TupleInt.from_vec(Vec[Int](Int(150))))\n",
"assume_value_one_of(_NDArray_2, TupleValue.from_vec(Vec[Value](Value.int(Int(0)), Value.int(Int(1)), Value.int(Int(2)))))\n",
"_NDArray_3 = astype(\n",
" NDArray.vector(\n",
" TupleValue.from_vec(\n",
" Vec[Value](\n",
" sum(_NDArray_2 == NDArray.scalar(Value.int(Int(0)))).to_value(),\n",
" sum(_NDArray_2 == NDArray.scalar(Value.int(Int(1)))).to_value(),\n",
" sum(_NDArray_2 == NDArray.scalar(Value.int(Int(2)))).to_value(),\n",
" )\n",
" )\n",
" ),\n",
" DType.float64,\n",
") / NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string(\"150\"), BigInt.from_string(\"1\")))))\n",
"_NDArray_4 = zeros(TupleInt.from_vec(Vec[Int](Int(3), Int(4))), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device))\n",
"_MultiAxisIndexKeyItem_1 = MultiAxisIndexKeyItem.slice(Slice())\n",
"_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(0)), _MultiAxisIndexKeyItem_1)))\n",
"_NDArray_5 = _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(0))))]\n",
"_OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0)))\n",
"_NDArray_4[_IndexKey_1] = sum(_NDArray_5, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_5.shape[Int(0)]))\n",
"_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(1)), _MultiAxisIndexKeyItem_1)))\n",
"_NDArray_6 = _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(1))))]\n",
"_NDArray_4[_IndexKey_2] = sum(_NDArray_6, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_6.shape[Int(0)]))\n",
"_IndexKey_3 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1)))\n",
"_NDArray_7 = _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(2))))]\n",
"_NDArray_4[_IndexKey_3] = sum(_NDArray_7, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_7.shape[Int(0)]))\n",
"_NDArray_8 = concat(\n",
" TupleNDArray.from_vec(Vec[NDArray](_NDArray_5 - _NDArray_4[_IndexKey_1], _NDArray_6 - _NDArray_4[_IndexKey_2], _NDArray_7 - _NDArray_4[_IndexKey_3])), OptionalInt.some(Int(0))\n",
")\n",
"_NDArray_9 = square(_NDArray_8 - expand_dims(sum(_NDArray_8, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_8.shape[Int(0)]))))\n",
"_NDArray_10 = sqrt(sum(_NDArray_9, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_9.shape[Int(0)])))\n",
"_NDArray_11 = copy(_NDArray_10)\n",
"_NDArray_11[IndexKey.ndarray(_NDArray_10 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(\n",
" Value.float(Float.rational(BigRat(BigInt.from_string(\"1\"), BigInt.from_string(\"1\"))))\n",
")\n",
"_TupleNDArray_1 = svd(\n",
" sqrt(\n",
" asarray(\n",
" NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string(\"1\"), BigInt.from_string(\"147\"))))),\n",
" OptionalDType.some(DType.float64),\n",
" OptionalBool.none,\n",
" OptionalDevice.some(_NDArray_1.device),\n",
" )\n",
" )\n",
" * (_NDArray_8 / _NDArray_11),\n",
" Boolean(False),\n",
")\n",
"_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))\n",
"_NDArray_12 = (\n",
" _TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))]\n",
" / _NDArray_11\n",
").T / _TupleNDArray_1[Int(1)][IndexKey.slice(_Slice_1)]\n",
"_TupleNDArray_2 = svd(\n",
" (\n",
" sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_3) * NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string(\"1\"), BigInt.from_string(\"2\"))))))\n",
" * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T\n",
" ).T\n",
" @ _NDArray_12,\n",
" Boolean(False),\n",
")\n",
"(\n",
" (_NDArray_1 - (_NDArray_3 @ _NDArray_4))\n",
" @ (\n",
" _NDArray_12\n",
" @ _TupleNDArray_2[Int(2)].T[\n",
" IndexKey.multi_axis(\n",
" MultiAxisIndexKey.from_vec(\n",
" Vec[MultiAxisIndexKeyItem](\n",
" _MultiAxisIndexKeyItem_1,\n",
" MultiAxisIndexKeyItem.slice(\n",
" Slice(\n",
" OptionalInt.none,\n",
" OptionalInt.some(\n",
" sum(astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32))\n",
" .to_value()\n",
" .to_int\n",
" ),\n",
" )\n",
" ),\n",
" )\n",
" )\n",
" )\n",
" ]\n",
" )\n",
")[\n",
" IndexKey.multi_axis(\n",
" MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](_MultiAxisIndexKeyItem_1, MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(Int(2))))))\n",
" )\n",
"]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from egglog.exp.array_api_numba import array_api_numba_schedule\n",
"\n",
"egraph.register(res)\n",
"egraph.run(array_api_numba_schedule)\n",
"simplified_res = egraph.extract(res)\n",
"\n",
"simplified_res"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Program Gen\n",
"\n",
"_Generate an imperative program from your e-graph with replacement rules that walk the graph in a fixed order_\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have a program, what do we do with it?\n",
"\n",
"Well we showed how we can use eager evaluation to get a result, but what if we don't want to do the computation in egglog, but instead export a program so we can execute that back in Python or in this case feed it to Python?\n",
"\n",
"Well in this case we have designed a `Program` object which we can use to convert a funtional egglog expression back to imperative Python code:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"def __fn(X, y):\n",
" assert X.dtype == np.dtype(np.float64)\n",
" assert X.shape == (150, 4, )\n",
" assert np.all(np.isfinite(X))\n",
" assert y.dtype == np.dtype(np.int64)\n",
" assert y.shape == (150, )\n",
" assert set(np.unique(y)) == set((0, 1, 2, ))\n",
" _0 = y == np.array(0)\n",
" _1 = np.sum(_0)\n",
" _2 = y == np.array(1)\n",
" _3 = np.sum(_2)\n",
" _4 = y == np.array(2)\n",
" _5 = np.sum(_4)\n",
" _6 = np.array((_1, _3, _5, )).astype(np.dtype(np.float64))\n",
" _7 = _6 / np.array(float(150))\n",
" _8 = np.zeros((3, 4, ), dtype=np.dtype(np.float64))\n",
" _9 = np.sum(X[_0], axis=0)\n",
" _10 = _9 / np.array(X[_0].shape[0])\n",
" _8[0, :,] = _10\n",
" _11 = np.sum(X[_2], axis=0)\n",
" _12 = _11 / np.array(X[_2].shape[0])\n",
" _8[1, :,] = _12\n",
" _13 = np.sum(X[_4], axis=0)\n",
" _14 = _13 / np.array(X[_4].shape[0])\n",
" _8[2, :,] = _14\n",
" _15 = _7 @ _8\n",
" _16 = X - _15\n",
" _17 = np.sqrt(np.asarray(np.array(float(1 / 147)), np.dtype(np.float64)))\n",
" _18 = X[_0] - _8[0, :,]\n",
" _19 = X[_2] - _8[1, :,]\n",
" _20 = X[_4] - _8[2, :,]\n",
" _21 = np.concatenate((_18, _19, _20, ), axis=0)\n",
" _22 = np.sum(_21, axis=0)\n",
" _23 = _22 / np.array(_21.shape[0])\n",
" _24 = np.expand_dims(_23, 0)\n",
" _25 = _21 - _24\n",
" _26 = np.square(_25)\n",
" _27 = np.sum(_26, axis=0)\n",
" _28 = _27 / np.array(_26.shape[0])\n",
" _29 = np.sqrt(_28)\n",
" _30 = _29 == np.array(0)\n",
" _29[_30] = np.array(float(1))\n",
" _31 = _21 / _29\n",
" _32 = _17 * _31\n",
" _33 = np.linalg.svd(_32, full_matrices=False)\n",
" _34 = _33[1] > np.array(0.0001)\n",
" _35 = _34.astype(np.dtype(np.int32))\n",
" _36 = np.sum(_35)\n",
" _37 = _33[2][:_36, :,] / _29\n",
" _38 = _37.T / _33[1][:_36]\n",
" _39 = np.array(150) * _7\n",
" _40 = _39 * np.array(float(1 / 2))\n",
" _41 = np.sqrt(_40)\n",
" _42 = _8 - _15\n",
" _43 = _41 * _42.T\n",
" _44 = _43.T @ _38\n",
" _45 = np.linalg.svd(_44, full_matrices=False)\n",
" _46 = np.array(0.0001) * _45[1][0]\n",
" _47 = _45[1] > _46\n",
" _48 = _47.astype(np.dtype(np.int32))\n",
" _49 = np.sum(_48)\n",
" _50 = _38 @ _45[2].T[:, :_49,]\n",
" _51 = _16 @ _50\n",
" return _51[:, :2,]\n",
"\n"
]
}
],
"source": [
"from egglog.exp.array_api_program_gen import *\n",
"import numpy as np\n",
"import inspect\n",
"\n",
"\n",
"egraph = EGraph()\n",
"fn_program = egraph.let(\n",
" \"fn_program\",\n",
" EvalProgram(ndarray_function_two_program(simplified_res, NDArray.var(\"X\"), NDArray.var(\"y\")), {\"np\": np}),\n",
")\n",
"egraph.run(array_api_program_gen_schedule)\n",
"fn = egraph.extract(fn_program.as_py_object).value\n",
"\n",
"print(inspect.getsource(fn))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From there we can complete our work, by optimizing with numba and we can call with our original values:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/xn/05ktz3056kqd9n8frgd6236h0000gn/T/egglog-39d7fedf-0738-45c9-8c02-1e5921b60da3.py:62: NumbaPerformanceWarning: '@' is faster on contiguous arrays, called on (Array(float64, 2, 'C', False, aligned=True), Array(float64, 2, 'A', False, aligned=True))\n",
" _50 = _38 @ _45[2].T[:, :_49,]\n"
]
},
{
"data": {
"text/plain": [
"array([[ 8.06179978e+00, 3.00420621e-01],\n",
" [ 7.12868772e+00, -7.86660426e-01],\n",
" [ 7.48982797e+00, -2.65384488e-01],\n",
" [ 6.81320057e+00, -6.70631068e-01],\n",
" [ 8.13230933e+00, 5.14462530e-01],\n",
" [ 7.70194674e+00, 1.46172097e+00],\n",
" [ 7.21261762e+00, 3.55836209e-01],\n",
" [ 7.60529355e+00, -1.16338380e-02],\n",
" [ 6.56055159e+00, -1.01516362e+00],\n",
" [ 7.34305989e+00, -9.47319209e-01],\n",
" [ 8.39738652e+00, 6.47363392e-01],\n",
" [ 7.21929685e+00, -1.09646389e-01],\n",
" [ 7.32679599e+00, -1.07298943e+00],\n",
" [ 7.57247066e+00, -8.05464137e-01],\n",
" [ 9.84984300e+00, 1.58593698e+00],\n",
" [ 9.15823890e+00, 2.73759647e+00],\n",
" [ 8.58243141e+00, 1.83448945e+00],\n",
" [ 7.78075375e+00, 5.84339407e-01],\n",
" [ 8.07835876e+00, 9.68580703e-01],\n",
" [ 8.02097451e+00, 1.14050366e+00],\n",
" [ 7.49680227e+00, -1.88377220e-01],\n",
" [ 7.58648117e+00, 1.20797032e+00],\n",
" [ 8.68104293e+00, 8.77590154e-01],\n",
" [ 6.25140358e+00, 4.39696367e-01],\n",
" [ 6.55893336e+00, -3.89222752e-01],\n",
" [ 6.77138315e+00, -9.70634453e-01],\n",
" [ 6.82308032e+00, 4.63011612e-01],\n",
" [ 7.92461638e+00, 2.09638715e-01],\n",
" [ 7.99129024e+00, 8.63787128e-02],\n",
" [ 6.82946447e+00, -5.44960851e-01],\n",
" [ 6.75895493e+00, -7.59002759e-01],\n",
" [ 7.37495254e+00, 5.65844592e-01],\n",
" [ 9.12634625e+00, 1.22443267e+00],\n",
" [ 9.46768199e+00, 1.82522635e+00],\n",
" [ 7.06201386e+00, -6.63400423e-01],\n",
" [ 7.95876243e+00, -1.64961722e-01],\n",
" [ 8.61367201e+00, 4.03253602e-01],\n",
" [ 8.33041759e+00, 2.28133530e-01],\n",
" [ 6.93412007e+00, -7.05519379e-01],\n",
" [ 7.68823131e+00, -9.22362309e-03],\n",
" [ 7.91793715e+00, 6.75121313e-01],\n",
" [ 5.66188065e+00, -1.93435524e+00],\n",
" [ 7.24101468e+00, -2.72615132e-01],\n",
" [ 6.41443556e+00, 1.24730131e+00],\n",
" [ 6.85944381e+00, 1.05165396e+00],\n",
" [ 6.76470393e+00, -5.05151855e-01],\n",
" [ 8.08189937e+00, 7.63392750e-01],\n",
" [ 7.18676904e+00, -3.60986823e-01],\n",
" [ 8.31444876e+00, 6.44953177e-01],\n",
" [ 7.67196741e+00, -1.34893840e-01],\n",
" [-1.45927545e+00, 2.85437643e-02],\n",
" [-1.79770574e+00, 4.84385502e-01],\n",
" [-2.41694888e+00, -9.27840307e-02],\n",
" [-2.26247349e+00, -1.58725251e+00],\n",
" [-2.54867836e+00, -4.72204898e-01],\n",
" [-2.42996725e+00, -9.66132066e-01],\n",
" [-2.44848456e+00, 7.95961954e-01],\n",
" [-2.22666513e-01, -1.58467318e+00],\n",
" [-1.75020123e+00, -8.21180130e-01],\n",
" [-1.95842242e+00, -3.51563753e-01],\n",
" [-1.19376031e+00, -2.63445570e+00],\n",
" [-1.85892567e+00, 3.19006544e-01],\n",
" [-1.15809388e+00, -2.64340991e+00],\n",
" [-2.66605725e+00, -6.42504540e-01],\n",
" [-3.78367218e-01, 8.66389312e-02],\n",
" [-1.20117255e+00, 8.44373592e-02],\n",
" [-2.76810246e+00, 3.21995363e-02],\n",
" [-7.76854039e-01, -1.65916185e+00],\n",
" [-3.49805433e+00, -1.68495616e+00],\n",
" [-1.09042788e+00, -1.62658350e+00],\n",
" [-3.71589615e+00, 1.04451442e+00],\n",
" [-9.97610366e-01, -4.90530602e-01],\n",
" [-3.83525931e+00, -1.40595806e+00],\n",
" [-2.25741249e+00, -1.42679423e+00],\n",
" [-1.25571326e+00, -5.46424197e-01],\n",
" [-1.43755762e+00, -1.34424979e-01],\n",
" [-2.45906137e+00, -9.35277280e-01],\n",
" [-3.51848495e+00, 1.60588866e-01],\n",
" [-2.58979871e+00, -1.74611728e-01],\n",
" [ 3.07487884e-01, -1.31887146e+00],\n",
" [-1.10669179e+00, -1.75225371e+00],\n",
" [-6.05524589e-01, -1.94298038e+00],\n",
" [-8.98703769e-01, -9.04940034e-01],\n",
" [-4.49846635e+00, -8.82749915e-01],\n",
" [-2.93397799e+00, 2.73791065e-02],\n",
" [-2.10360821e+00, 1.19156767e+00],\n",
" [-2.14258208e+00, 8.87797815e-02],\n",
" [-2.47945603e+00, -1.94073927e+00],\n",
" [-1.32552574e+00, -1.62869550e-01],\n",
" [-1.95557887e+00, -1.15434826e+00],\n",
" [-2.40157020e+00, -1.59458341e+00],\n",
" [-2.29248878e+00, -3.32860296e-01],\n",
" [-1.27227224e+00, -1.21458428e+00],\n",
" [-2.93176055e-01, -1.79871509e+00],\n",
" [-2.00598883e+00, -9.05418042e-01],\n",
" [-1.18166311e+00, -5.37570242e-01],\n",
" [-1.61615645e+00, -4.70103580e-01],\n",
" [-1.42158879e+00, -5.51244626e-01],\n",
" [ 4.75973788e-01, -7.99905482e-01],\n",
" [-1.54948259e+00, -5.93363582e-01],\n",
" [-7.83947399e+00, 2.13973345e+00],\n",
" [-5.50747997e+00, -3.58139892e-02],\n",
" [-6.29200850e+00, 4.67175777e-01],\n",
" [-5.60545633e+00, -3.40738058e-01],\n",
" [-6.85055995e+00, 8.29825394e-01],\n",
" [-7.41816784e+00, -1.73117995e-01],\n",
" [-4.67799541e+00, -4.99095015e-01],\n",
" [-6.31692685e+00, -9.68980756e-01],\n",
" [-6.32773684e+00, -1.38328993e+00],\n",
" [-6.85281335e+00, 2.71758963e+00],\n",
" [-4.44072512e+00, 1.34723692e+00],\n",
" [-5.45009572e+00, -2.07736942e-01],\n",
" [-5.66033713e+00, 8.32713617e-01],\n",
" [-5.95823722e+00, -9.40175447e-02],\n",
" [-6.75926282e+00, 1.60023206e+00],\n",
" [-5.80704331e+00, 2.01019882e+00],\n",
" [-5.06601233e+00, -2.62733839e-02],\n",
" [-6.60881882e+00, 1.75163587e+00],\n",
" [-9.17147486e+00, -7.48255067e-01],\n",
" [-4.76453569e+00, -2.15573720e+00],\n",
" [-6.27283915e+00, 1.64948141e+00],\n",
" [-5.36071189e+00, 6.46120732e-01],\n",
" [-7.58119982e+00, -9.80722934e-01],\n",
" [-4.37150279e+00, -1.21297458e-01],\n",
" [-5.72317531e+00, 1.29327553e+00],\n",
" [-5.27915920e+00, -4.24582377e-02],\n",
" [-4.08087208e+00, 1.85936572e-01],\n",
" [-4.07703640e+00, 5.23238483e-01],\n",
" [-6.51910397e+00, 2.96976389e-01],\n",
" [-4.58371942e+00, -8.56815813e-01],\n",
" [-6.22824009e+00, -7.12719638e-01],\n",
" [-5.22048773e+00, 1.46819509e+00],\n",
" [-6.80015000e+00, 5.80895175e-01],\n",
" [-3.81515972e+00, -9.42985932e-01],\n",
" [-5.10748966e+00, -2.13059000e+00],\n",
" [-6.79671631e+00, 8.63090395e-01],\n",
" [-6.52449599e+00, 2.44503527e+00],\n",
" [-4.99550279e+00, 1.87768525e-01],\n",
" [-3.93985300e+00, 6.14020389e-01],\n",
" [-5.20383090e+00, 1.14476808e+00],\n",
" [-6.65308685e+00, 1.80531976e+00],\n",
" [-5.10555946e+00, 1.99218201e+00],\n",
" [-5.50747997e+00, -3.58139892e-02],\n",
" [-6.79601924e+00, 1.46068695e+00],\n",
" [-6.84735943e+00, 2.42895067e+00],\n",
" [-5.64500346e+00, 1.67771734e+00],\n",
" [-5.17956460e+00, -3.63475041e-01],\n",
" [-4.96774090e+00, 8.21140550e-01],\n",
" [-5.88614539e+00, 2.34509051e+00],\n",
" [-4.68315426e+00, 3.32033811e-01]])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from numba import njit\n",
"\n",
"njit(fn)(X_np, y_np)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These program rewrites work by first translating the code into an intermediate IR of just program assignments and statements, and then turning this into source. It does this by walking the graph first top to bottom, recording the first parent that saw every child. Then it goes from bottom to top, building up a larger set of statements as well as an expression representing each node. If a node has already been emitted somewhere else, we record that, and we always wait for all the children to complete before moving forward. That way, we can enforce some ordering and we know that everything will be emitted only once.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example\n",
"\n",
"Here is a small example, we might want to compile, let's say we want a function like this:\n",
"\n",
"```python\n",
"def __fn(x, y)\n",
" x[1] = 10\n",
" z = x + x\n",
" return sum(z) + y\n",
"```\n",
"\n",
"We can build up a functional program for this and then compile it to Python source:\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"def __fn(x, y):\n",
" x[1] = 10\n",
" _0 = x + x\n",
" return sum(_0) + y\n",
"\n"
]
}
],
"source": [
"from egglog.exp.program_gen import *\n",
"\n",
"x = Program(\"x\", is_identifier=True)\n",
"y = Program(\"y\", is_identifier=True)\n",
"# To reference x, we need to first emit the statement\n",
"x_modified = Program(\"x\").statement(x + \"[1] = 10\")\n",
"\n",
"z = (x_modified + \" + \" + x_modified).assign()\n",
"res = Program(\"sum(\") + z + \") + \" + y\n",
"fn = res.function_two(x, y)\n",
"\n",
"egraph = EGraph()\n",
"egraph.register(fn.compile())\n",
"egraph.run(program_gen_ruleset.saturate())\n",
"print(egraph.extract(fn.statements).value)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This happens, by first going top to bottom with the `compile`, which will put a total ordering on all nodes, by defining one, and only one, parent expression for each expression.\n",
"\n",
"Then, from the bottom up, for each node we compute an expression string and a list of statements string.\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2a45463df1ca40a2ac30f6f50da2b7b4",
"version_major": 2,
"version_minor": 1
},
"text/plain": [
"VisualizerWidget(egraphs=['{\"nodes\":{\"function-0-Program_statement\":{\"op\":\"·.statement\",\"children\":[\"function-…"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"egraph"
]
},
{
"cell_type": "markdown",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## PyObject - Python objects in EGraphs\n",
"\n",
"_We can also add Python objects directly to the e-graph as primitives and run rewrite rules that call back into Python_\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Future Work\n",
"\n",
"The next milestone use case is to be able to optimize functional array programs and rewrite them.\n",
"\n",
"To implement this we need to at least support **functions as values** and ideally also **generic types**.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example\n",
"\n",
"There is a concrete example [provided by Siu from the Numba project](https://gist.github.com/sklam/5e5737137d48d6e5b816d14a90076f1d).\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"We would want users to be able to write code like this:\n",
"\n",
"```python\n",
"def linalg_norm_loopnest_egglog(X: enp.NDArray, axis: enp.TupleInt) -> enp.NDArray:\n",
" # peel off the outer shape for result array\n",
" outshape = ShapeAPI(X.shape).deselect(axis).to_tuple()\n",
" # get only the inner shape for reduction\n",
" reduce_axis = ShapeAPI(X.shape).select(axis).to_tuple()\n",
"\n",
" return enp.NDArray.from_fn(\n",
" outshape,\n",
" X.dtype,\n",
" lambda k: enp.sqrt(\n",
" LoopNestAPI.from_tuple(reduce_axis)\n",
" .unwrap()\n",
" .reduce(lambda carry, i: carry + enp.real(enp.conj(x := X[i + k]) * x), init=0.0)\n",
" ).to_value(),\n",
" )\n",
"```\n",
"\n",
"Which would then be rewritten to:\n",
"\n",
"```python\n",
"def linalg_norm_array_api(X: enp.NDArray, axis: enp.TupleInt) -> enp.NDArray:\n",
" outdim = enp.range_(X.ndim).filter(lambda x: ~axis.contains(x))\n",
" outshape = convert(convert(X.shape, enp.NDArray)[outdim], enp.TupleInt)\n",
" row_axis, col_axis = axis\n",
" return enp.NDArray.from_fn(\n",
" outshape,\n",
" X.dtype,\n",
" lambda k: enp.sqrt(\n",
" enp.int_product(enp.range_(X.shape[row_axis]), enp.range_(X.shape[col_axis]))\n",
" .map_to_ndarray(lambda rc: enp.real(enp.conj(x := X[rc + k]) * x))\n",
" .sum()\n",
" ).to_value(),\n",
" )\n",
"```\n",
"\n",
"And finally could be lowered to Python as:\n",
"\n",
"```python\n",
"def linalg_norm_low_level(\n",
" X: np.ndarray[tuple, np.dtype[np.float64]], axis: tuple[int, int]\n",
") -> np.ndarray[tuple, np.dtype[np.float64]]:\n",
" # # If X ndim>=3 and axis is a 2-tuple\n",
" assert X.ndim >= 3\n",
" assert len(axis) == 2\n",
" # Y - 2\n",
" outdim = [dim for dim in range(X.ndim) if dim not in axis]\n",
"\n",
" outshape = tuple(np.asarray(X.shape)[outdim])\n",
"\n",
" res = np.zeros(outshape, dtype=X.dtype)\n",
" row_axis, col_axis = axis\n",
" for k in np.ndindex(outshape):\n",
" tmp = 0.0\n",
" for row in range(X.shape[row_axis]):\n",
" for col in range(X.shape[col_axis]):\n",
" idx = (row, col, *k)\n",
" x = X[idx]\n",
" tmp += (x.conj() * x).real\n",
" res[k] = np.sqrt(tmp)\n",
" return res\n",
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this talk I have gone through some details of what is needed to connect data science users to egglog:\n",
"\n",
"\n",
"\n",
"Overall, the idea is that if we can get egglog in more users hands, in particular for data intensive workloads where the tradeoff of time for pre-computation is worth it, than this can help drive exciting future research directions and also build meaningful useful tools for the scientific open source ecosystem in Python.\n",
"\n",
"**If you are building DSLs in Python, or more generally want to play with e-graphs, try out `egglog-python`!**\n",
"\n",
"Around on the e-graphs Zulip for any questions.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "egglog",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}