JAX sparse-sparse matrix multiplication memory usage #17251

Answered by jakevdp
ivirshup asked this question in Q&A
Yes, this is a known issue: sparse-sparse matmul uses memory proportional to nse_1 * nse_2, where nse_1 and nse_2 are the numbers of specified (stored) elements in the two operands. Unfortunately, there are no sparse matrix primitives in XLA, so it's hard to do much better than this in general. It's one of the reasons these tools have not graduated from the jax.experimental namespace.
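To see where the nse_1 * nse_2 figure can come from, here is a minimal NumPy sketch of an all-pairs COO x COO product. It is not JAX's actual implementation; the helpers `coo_matmul_all_pairs` and `coo_todense` are hypothetical. It assumes, as JAX does, that output sizes must not depend on the data: every stored element of A is paired with every stored element of B, and pairs whose inner indices don't match are kept as explicit zeros rather than filtered out.

```python
import numpy as np


def coo_matmul_all_pairs(a_idx, a_vals, b_idx, b_vals):
    """Hypothetical all-pairs COO x COO product with a fixed output size.

    a_idx: (nse_a, 2) int array of (row, col); a_vals: (nse_a,) values.
    b_idx: (nse_b, 2) int array of (row, col); b_vals: (nse_b,) values.
    Returns nse_a * nse_b (index, value) entries; non-matching pairs hold 0.
    """
    # Pair every stored element of A with every stored element of B:
    # all intermediates below have shape (nse_a, nse_b).
    match = a_idx[:, 1][:, None] == b_idx[:, 0][None, :]
    vals = np.where(match, a_vals[:, None] * b_vals[None, :], 0)
    rows = np.broadcast_to(a_idx[:, 0][:, None], match.shape)
    cols = np.broadcast_to(b_idx[:, 1][None, :], match.shape)
    # Keep every pair instead of filtering, so the output size does not
    # depend on the data (mismatched pairs become explicit zeros).
    idx = np.stack([rows.ravel(), cols.ravel()], axis=1)
    return idx, vals.ravel()


def coo_todense(idx, vals, shape):
    """Scatter-add COO entries into a dense array (duplicates are summed)."""
    out = np.zeros(shape, dtype=vals.dtype)
    np.add.at(out, (idx[:, 0], idx[:, 1]), vals)
    return out


A = np.array([[1.0, 0.0, 2.0],
              [0.0, 3.0, 0.0]])
B = np.array([[0.0, 4.0],
              [5.0, 0.0],
              [6.0, 7.0]])
a_idx, a_vals = np.argwhere(A != 0), A[A != 0]  # nse_a = 3
b_idx, b_vals = np.argwhere(B != 0), B[B != 0]  # nse_b = 4

idx, vals = coo_matmul_all_pairs(a_idx, a_vals, b_idx, b_vals)
print(vals.shape)  # (12,): 3 * 4 entries, though A @ B has only 3 nonzeros
print(np.allclose(coo_todense(idx, vals, (A.shape[0], B.shape[1])), A @ B))  # True
```

A smarter kernel could compact the matching pairs, but that needs an output whose size depends on the data, which is the part that is hard to express efficiently without dedicated sparse primitives.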

I've actually thought about removing sparse-sparse matmul completely, because its performance tends to surprise people. What do you think?

Replies: 4 comments, 14 replies

Answer selected by ivirshup
Category: Q&A
Labels: None yet
5 participants