Matrix multiplication and Kronecker product.

(Figure: a Graphviz rendering of the e-graph built by this example. Each e-class is drawn as a cluster. The nodes are:

- the named dimensions `Dim.named("n")`, `Dim.named("m")` and `Dim.named("p")`;
- the named matrices `Matrix.named("A")`, `Matrix.named("B")` and `Matrix.named("C")`;
- the `Matrix.identity` terms;
- the `kron` and `· @ ·` (matrix multiplication) terms;
- the `· * ·` (dimension product) terms;
- the `·.nrows` and `·.ncols` terms.

Edges run from each operator node to the e-classes of its arguments.)


from __future__ import annotations

from egglog import *

# The single e-graph used by the whole example: every rule and named term
# below is registered into it, and it is run and checked at the end.
egraph = EGraph()


class Dim(Expr):
    """
    A dimension of a matrix.

    Dimensions are either integer literals, symbolic named dimensions, or
    products of other dimensions.

    >>> Dim(3) * Dim.named("n")
    Dim(3) * Dim.named("n")
    """

    @method(egg_fn="Lit")
    def __init__(self, value: i64Like) -> None:
        """A literal dimension holding the integer ``value``."""

    @method(egg_fn="NamedDim")
    @classmethod
    def named(cls, name: StringLike) -> Dim:  # type: ignore[empty-body]
        """A symbolic dimension identified by ``name``."""

    @method(egg_fn="Times")
    def __mul__(self, other: Dim) -> Dim:  # type: ignore[empty-body]
        """The product of two dimensions, written ``self * other``."""


# Pattern variables for the Dim rewrite rules. `n` is also used as a Dim
# pattern variable by the identity-matrix rules further down, and is later
# rebound to the concrete named dimension "n".
a, b, c, n = vars_("a b c n", Dim)
i, j = vars_("i j", i64)
egraph.register(
    # `*` on dimensions is associative; both directions are registered so the
    # e-graph can regroup products either way.
    rewrite(a * (b * c)).to((a * b) * c),
    rewrite((a * b) * c).to(a * (b * c)),
    # Constant folding: the product of two literal dimensions is a literal.
    rewrite(Dim(i) * Dim(j)).to(Dim(i * j)),
    # `*` on dimensions is commutative.
    rewrite(a * b).to(b * a),
)


class Matrix(Expr, egg_sort="MExpr"):
    """
    A matrix expression, declared in egglog under the sort name ``MExpr``.

    Leaves are named matrices or identity matrices; larger expressions are
    built with ``@`` (matrix multiplication) and ``kron``.
    """

    @method(egg_fn="Id")
    @classmethod
    def identity(cls, dim: Dim) -> Matrix:  # type: ignore[empty-body]
        """
        Create an identity matrix of the given dimension.

        ``dim`` is used as both the row and the column count.
        """

    @method(egg_fn="NamedMat")
    @classmethod
    def named(cls, name: StringLike) -> Matrix:  # type: ignore[empty-body]
        """
        Create a named (opaque) matrix identified by ``name``.
        """

    @method(egg_fn="MMul")
    def __matmul__(self, other: Matrix) -> Matrix:  # type: ignore[empty-body]
        """
        Matrix multiplication, written ``self @ other``.
        """

    @method(egg_fn="nrows")
    def nrows(self) -> Dim:  # type: ignore[empty-body]
        """
        Number of rows in the matrix, as a ``Dim`` expression.
        """

    @method(egg_fn="ncols")
    def ncols(self) -> Dim:  # type: ignore[empty-body]
        """
        Number of columns in the matrix, as a ``Dim`` expression.
        """


@function(egg_fn="Kron")
def kron(a: Matrix, b: Matrix) -> Matrix:  # type: ignore[empty-body]
    """
    Kronecker product of two matrices.

    https://en.wikipedia.org/wiki/Kronecker_product#Definition

    The result is an uninterpreted ``Kron(a, b)`` term. Its dimensions are
    given by the rewrite rules in this example: ``a.nrows() * b.nrows()``
    rows and ``a.ncols() * b.ncols()`` columns.

    :param a: The left factor.
    :param b: The right factor.
    """


# Pattern variables for the Matrix rules. A, B and C are later rebound to
# concrete named matrices once all rules have been registered.
A, B, C, D = vars_("A B C D", Matrix)
egraph.register(
    # The dimensions of a kronecker product are the product of the dimensions
    rewrite(kron(A, B).nrows()).to(A.nrows() * B.nrows()),
    rewrite(kron(A, B).ncols()).to(A.ncols() * B.ncols()),
    # The dimensions of a matrix multiplication are the number of rows of the first
    # matrix and the number of columns of the second matrix.
    rewrite((A @ B).nrows()).to(A.nrows()),
    rewrite((A @ B).ncols()).to(B.ncols()),
    # An identity matrix has `n` rows and `n` columns. `n` here is the Dim
    # pattern variable declared together with the Dim rules.
    rewrite(Matrix.identity(n).nrows()).to(n),
    rewrite(Matrix.identity(n).ncols()).to(n),
)
egraph.register(
    # Multiplication by an identity matrix is the same as the other matrix.
    # NOTE(review): these rules do not check that `n` matches the other
    # operand's dimension; they assume the input expression is well-formed.
    rewrite(Matrix.identity(n) @ A).to(A),
    rewrite(A @ Matrix.identity(n)).to(A),
    # Matrix multiplication is associative (both directions registered).
    rewrite(A @ (B @ C)).to((A @ B) @ C),
    rewrite((A @ B) @ C).to(A @ (B @ C)),
    # Kronecker product is associative (both directions registered).
    rewrite(kron(A, kron(B, C))).to(kron(kron(A, B), C)),
    rewrite(kron(kron(A, B), C)).to(kron(A, kron(B, C))),
    # Mixed-product property: kron(A, B) @ kron(C, D) == kron(A @ C, B @ D).
    # Splitting kron(A @ C, B @ D) is unconditional; presumably A @ C and
    # B @ D are assumed well-formed because they already exist as terms.
    rewrite(kron(A @ C, B @ D)).to(kron(A, B) @ kron(C, D)),
    # Merging a product of Kronecker products creates the new terms A @ C and
    # B @ D, so it is guarded by the equalities that make them defined.
    rewrite(kron(A, B) @ kron(C, D)).to(
        kron(A @ C, B @ D),
        # Only when the dimensions match
        eq(A.ncols()).to(C.nrows()),
        eq(B.ncols()).to(D.nrows()),
    ),
)
# For every matrix product and every Kronecker product present in the e-graph,
# add the row and column counts of both operands as terms. This presumably
# gives the dimension-guarded rewrites (e.g. `eq(A.ncols()).to(C.nrows())`)
# terms to match against. One rule is generated per combinator, in the order
# matrix product first, then Kronecker product.
egraph.register(
    *(
        rule(eq(C).to(combined)).then(
            A.ncols(),
            A.nrows(),
            B.ncols(),
            B.nrows(),
        )
        for combined in (A @ B, kron(A, B))
    )
)


# The named dimensions n, m and p; each label names both the e-graph binding
# and the underlying Dim.named term.
n, m, p = (egraph.let(label, Dim.named(label)) for label in ("n", "m", "p"))

# The named matrices A, B and C, built the same way.
A, B, C = (egraph.let(label, Matrix.named(label)) for label in ("A", "B", "C"))

# Declare A as n x n, B as m x m and C as p x p by unioning each matrix's row
# count and then its column count with its dimension.
egraph.register(
    *(
        union(side).with_(size)
        for mat, size in ((A, n), (B, m), (C, p))
        for side in (mat.nrows(), mat.ncols())
    )
)
# Create an example which should equal the kronecker product of A and B:
#   ex1 = kron(I_n, B) @ kron(A, I_m)
# The guarded mixed-product rewrite turns this into
# kron(I_n @ A, B @ I_m). Its conditions hold because
# I_n.ncols() = n = A.nrows() and B.ncols() = m = I_m.nrows() (from the
# unions above). The identity rules then reduce the result to kron(A, B).
ex1 = egraph.let("ex1", kron(Matrix.identity(n), B) @ kron(A, Matrix.identity(m)))
# Bind ex1's row and column counts as named terms in the e-graph.
rows = egraph.let("rows", ex1.nrows())
cols = egraph.let("cols", ex1.ncols())

# Run the registered rules for 20 iterations.
egraph.run(20)

# Sanity checks. B has m rows, from the union above. kron(I_n, B) has n * m
# rows, from the Kronecker dimension rule.
egraph.check(eq(B.nrows()).to(m))
egraph.check(eq(kron(Matrix.identity(n), B).nrows()).to(n * m))

# Verify it matches the expected result
simple_ex1 = egraph.let("simple_ex1", kron(A, B))
egraph.check(eq(ex1).to(simple_ex1))

# A variant that should NOT simplify. Merging kron(I_p, C) @ kron(A, I_m)
# would need I_p.ncols() == A.nrows() (p == n) and C.ncols() == I_m.nrows()
# (p == m). Neither equality is ever asserted, so the guarded rewrite cannot fire.
ex2 = egraph.let("ex2", kron(Matrix.identity(p), C) @ kron(A, Matrix.identity(m)))

egraph.run(10)
# Verify it is not simplified
egraph.check_fail(eq(ex2).to(kron(A, C)))
# Bare trailing expression, presumably so Sphinx-Gallery renders the e-graph.
egraph

Total running time of the script: (0 minutes 0.066 seconds)

Gallery generated by Sphinx-Gallery