Variable cost to variant #399

Open
AzizZayed opened this issue Jul 26, 2024 · 9 comments

Comments

@AzizZayed

AzizZayed commented Jul 26, 2024

How do we assign a variable cost to a variant? For example

(datatype Expr
    (Var String)
    (Num i64) ; I want to set the cost of a Num to be its input
    (Abs Expr)
    (Add Expr Expr)
    (Mul Expr Expr))

(Num 10) ; I want this to have cost of 10
(Num 55) ; I want this to have cost of 55
@saulshanabrook
Member

It's not currently possible. We had two PRs with different ways of adding custom costs, #355 and #353, but punted on them because we first wanted a better idea of the space of user needs here.

If you could give a larger example how you would use this, that would be helpful in designing a solution around it.

@AzizZayed
Author

Sure, here is my use case: I want to optimize tensor multiplications according to the shapes of the tensors. Assume you have 3 tensors $X$ with shape $a \times b$, $Y$ with shape $b \times c$ and $Z$ with shape $c \times d$. The number of multiplications (which I define as the cost) of $(XY)Z$ is different from that of $X(YZ)$, even though the two expressions are equivalent. The cost of $(XY)Z$ is $abc+acd$, and the cost of $X(YZ)$ is $bcd+abd$. I want to implement a way, within egglog, to extract the lowest-cost tensor expression.
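To make the arithmetic concrete, here is a small plain-Python sketch (not egglog) that counts scalar multiplications for both parenthesizations. It assumes the standard count of $p \cdot q \cdot r$ multiplications for a $(p \times q)(q \times r)$ product. The function names and example shapes are hypothetical.

```python
def matmul_cost(p: int, q: int, r: int) -> int:
    """Scalar multiplications for (p x q) @ (q x r): p*r dot products of length q."""
    return p * q * r


def chain3_costs(a: int, b: int, c: int, d: int) -> tuple[int, int]:
    """Return (cost of (XY)Z, cost of X(YZ)) for X: a x b, Y: b x c, Z: c x d."""
    # (XY)Z: XY is a x c and costs abc, then (a x c) @ (c x d) costs acd.
    left = matmul_cost(a, b, c) + matmul_cost(a, c, d)
    # X(YZ): YZ is b x d and costs bcd, then (a x b) @ (b x d) costs abd.
    right = matmul_cost(b, c, d) + matmul_cost(a, b, d)
    return left, right


# Hypothetical shapes: X is 10 x 100, Y is 100 x 5, Z is 5 x 50.
print(chain3_costs(10, 100, 5, 50))  # prints (7500, 75000)
```

With these shapes the two equivalent expressions differ in cost by a factor of ten, which is why extraction needs a shape-aware cost rather than term size.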

@AzizZayed
Author

I guess in this case I would need a dimension analysis, then set the cost of each operation using that analysis. Do you have an example with dimension analysis?

@saulshanabrook
Member

Thanks!

> Do you have an example with dimension analysis?

Sort of: https://github.com/egraphs-good/egglog/blob/main/tests/matrix.egg

@AzizZayed
Author

AzizZayed commented Jul 28, 2024

Let's say I have the following datatype and rewrite rule:

(datatype MatrixOp
   (Matrix String i64 i64)  ; (<name> <nrows> <ncols>)
   (MatMul MatrixOp MatrixOp)
)

(birewrite
    (MatMul ?a (MatMul ?b ?c))
    (MatMul (MatMul ?a ?b) ?c)
)

;  below here I create some ops, run the rewrite rules and extract
...

Using egglog's current features, how would I tell the extract command to return the MatMul expression with the fewest multiplications? Here is how to calculate the number of multiplications:

Assume you have 3 tensors X with shape a×b, Y with shape b×c and Z with shape c×d. The number of multiplications (I define as the cost) of (XY)Z is different from the cost of X(YZ), even though they are equivalent expressions. The cost of (XY)Z is abc+acd, and the cost of X(YZ) is bcd+abd.

If I can't do this with the extract function, how can I achieve this goal with the current egglog commands?

@saulshanabrook
Member

The only command that can influence extraction on a per-node basis is subsume (#301), or I guess delete. But I am not sure it's possible to do what you want with those commands; they weren't made for this kind of situation.

So am I understanding correctly that you want to model the cost of a MatMul of two matrices with shapes a x b and b x c as a * b * c?

So I think with #355 your case could maybe be supported like this?

(datatype MatrixOp
   (Matrix String i64 i64)  ; (<name> <nrows> <ncols>)
   (MatMul MatrixOp MatrixOp)
)


(birewrite
    (MatMul ?a (MatMul ?b ?c))
    (MatMul (MatMul ?a ?b) ?c)
)

(function nrows (MatrixOp) i64)
(function ncols (MatrixOp) i64)

(rule ((= ?m (Matrix ?s ?r ?c)))
       ((set (nrows ?m) ?r)
        (set (ncols ?m) ?c)))

(rule ((= ?m (MatMul ?a ?b))
       (= ?r (nrows ?a))
       (= ?k (ncols ?a))
       (= ?k (nrows ?b))
       (= ?c (ncols ?b)))
      ((set (nrows ?m) ?r)
       (set (ncols ?m) ?c)
       ; `cost` action as proposed in #355: rows * inner dim * cols
       (cost (MatMul ?a ?b) (* (* ?r ?k) ?c))))

Does that seem right?
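Outside egglog, the computation that the nrows/ncols rules plus a per-MatMul cost of (rows × inner dimension × cols) would drive can be sketched in plain Python. This assumes the extractor sums node costs over the term tree. The Python classes only mirror the egglog datatype and are hypothetical.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Matrix:
    name: str
    nrows: int
    ncols: int


@dataclass(frozen=True)
class MatMul:
    left: "Expr"
    right: "Expr"


Expr = Union[Matrix, MatMul]


def shape(e: Expr) -> tuple[int, int]:
    """Dimension analysis, playing the role of the nrows/ncols functions."""
    if isinstance(e, Matrix):
        return e.nrows, e.ncols
    r, k = shape(e.left)
    k2, c = shape(e.right)
    if k != k2:
        raise ValueError(f"inner dimensions differ: {k} vs {k2}")
    return r, c


def cost(e: Expr) -> int:
    """Total multiplications: leaves are free, each MatMul node adds r * k * c."""
    if isinstance(e, Matrix):
        return 0
    r, c = shape(e)  # also validates the inner dimensions
    k = shape(e.left)[1]
    return cost(e.left) + cost(e.right) + r * k * c


X = Matrix("X", 10, 100)
Y = Matrix("Y", 100, 5)
Z = Matrix("Z", 5, 50)
print(cost(MatMul(MatMul(X, Y), Z)))  # prints 7500
print(cost(MatMul(X, MatMul(Y, Z))))  # prints 75000
```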

@AzizZayed
Author

AzizZayed commented Jul 30, 2024

> Does that seem right?

Yes, this looks right.

@AzizZayed
Author

AzizZayed commented Jul 30, 2024

Maybe a way to do it with the current commands is a conditional rewrite? Something like:

(function nrows (MatrixOp) i64)
(function ncols (MatrixOp) i64)

; X(YZ) costs bd(a+c) = bcd + abd and (XY)Z costs ac(b+d) = abc + acd.
; If X(YZ) is more expensive, union it with (XY)Z.
(rule ((= ?m (MatMul ?x (MatMul ?y ?z)))
       (= ?a (nrows ?x))
       (= ?b (ncols ?x))
       (= ?b (nrows ?y))
       (= ?c (ncols ?y))
       (= ?c (nrows ?z))
       (= ?d (ncols ?z))
       (> (* (* ?b ?d) (+ ?a ?c)) (* (* ?a ?c) (+ ?b ?d))))
      ((union ?m (MatMul (MatMul ?x ?y) ?z))))

; If (XY)Z is more expensive, union it with X(YZ).
(rule ((= ?m (MatMul (MatMul ?x ?y) ?z))
       (= ?a (nrows ?x))
       (= ?b (ncols ?x))
       (= ?b (nrows ?y))
       (= ?c (ncols ?y))
       (= ?c (nrows ?z))
       (= ?d (ncols ?z))
       (> (* (* ?a ?c) (+ ?b ?d)) (* (* ?b ?d) (+ ?a ?c))))
      ((union ?m (MatMul ?x (MatMul ?y ?z)))))

Assuming

  • cost of $(XY)Z$ is $abc+acd$
  • cost of $X(YZ)$ is $bcd+abd$

Then

  • rewrite $(XY)Z$ to $X(YZ)$ if $bcd+abd$ < $abc+acd$
  • rewrite $X(YZ)$ to $(XY)Z$ if $abc+acd$ < $bcd+abd$
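The rule conditions compare factored forms of these costs: $bd(a+c) = bcd+abd$ is the cost of $X(YZ)$, and $ac(b+d) = abc+acd$ is the cost of $(XY)Z$. Here is a minimal Python sketch of the decision the conditions encode. The helper name `cheaper_form` is hypothetical. On a tie, neither strict inequality holds, so neither direction fires.

```python
def cheaper_form(a: int, b: int, c: int, d: int) -> str:
    """Preferred parenthesization for X: a x b, Y: b x c, Z: c x d."""
    cost_x_yz = b * d * (a + c)  # bcd + abd, the cost of X(YZ)
    cost_xy_z = a * c * (b + d)  # abc + acd, the cost of (XY)Z
    if cost_x_yz > cost_xy_z:
        return "(XY)Z"  # X(YZ) is the costlier form, so (XY)Z should be kept
    if cost_xy_z > cost_x_yz:
        return "X(YZ)"  # (XY)Z is the costlier form, so X(YZ) should be kept
    return "tie"  # equal costs: no rewrite is triggered


print(cheaper_form(10, 100, 5, 50))  # prints (XY)Z
print(cheaper_form(50, 5, 100, 10))  # prints X(YZ)
```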

@saulshanabrook
Member

saulshanabrook commented Jul 31, 2024

Oh yeah, good idea! I think you can do something like that with the :subsume keyword on rewrite. It ends up desugaring to something like the rule you wrote, plus "subsuming" the LHS (meaning it can no longer be extracted or matched, like a permanent delete), but it is a bit more succinct to write:

(datatype MatrixOp
   (Matrix String i64 i64)  ; (<name> <nrows> <ncols>)
   (MatMul MatrixOp MatrixOp)
)


(function nrows (MatrixOp) i64)
(function ncols (MatrixOp) i64)

(rule ((= ?m (Matrix ?s ?r ?c)))
       ((set (nrows ?m) ?r)
        (set (ncols ?m) ?c)))

(rule ((= ?m (MatMul ?a ?b))
       (= ?r (nrows ?a))
       (= ?c (ncols ?b)))
      ((set (nrows ?m) ?r)
       (set (ncols ?m) ?c)))


; X(YZ) -> (XY)Z when (XY)Z (abc + acd) is cheaper than X(YZ) (bcd + abd)
(rewrite
    (MatMul ?x (MatMul ?y ?z))
    (MatMul (MatMul ?x ?y) ?z)
    :when (
        (<
            (+ (* (nrows ?x) (* (ncols ?x) (ncols ?y))) (* (nrows ?x) (* (ncols ?y) (ncols ?z))))
            (+ (* (nrows ?y) (* (ncols ?y) (ncols ?z))) (* (nrows ?x) (* (ncols ?x) (ncols ?z))))
        )
    )
    :subsume
)


; (XY)Z -> X(YZ) when X(YZ) (bcd + abd) is cheaper than (XY)Z (abc + acd)
(rewrite
    (MatMul (MatMul ?x ?y) ?z)
    (MatMul ?x (MatMul ?y ?z))
    :when (
        (<
            (+ (* (nrows ?y) (* (ncols ?y) (ncols ?z))) (* (nrows ?x) (* (ncols ?x) (ncols ?z))))
            (+ (* (nrows ?x) (* (ncols ?x) (ncols ?y))) (* (nrows ?x) (* (ncols ?y) (ncols ?z))))
        )
    )
    :subsume
)
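Each of the subsuming rewrites above compares only the two parenthesizations of a single three-matrix subchain. To check what extraction returns on longer chains, the true minimum can be computed with the classic matrix-chain-order dynamic program. This is a standard textbook algorithm, not something egglog provides, and the Python sketch below is only a reference. `dims` is a hypothetical input in which matrix i has shape `dims[i] x dims[i + 1]`.

```python
from functools import lru_cache


def min_chain_cost(dims: list[int]) -> int:
    """Minimum scalar multiplications for M0 M1 ... M(n-1),
    where matrix Mi has shape dims[i] x dims[i + 1]."""
    n = len(dims) - 1  # number of matrices in the chain
    if n < 1:
        raise ValueError("need at least one matrix (two dimensions)")

    @lru_cache(maxsize=None)
    def best(i: int, j: int) -> int:
        # Cheapest way to multiply Mi ... Mj (inclusive).
        if i == j:
            return 0
        # Split as (Mi..Mk)(Mk+1..Mj): shapes dims[i] x dims[k+1] and dims[k+1] x dims[j+1].
        return min(
            best(i, k) + best(k + 1, j) + dims[i] * dims[k + 1] * dims[j + 1]
            for k in range(i, j)
        )

    return best(0, n - 1)


print(min_chain_cost([10, 100, 5, 50]))  # prints 7500
```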
