How do I do this rather trivial parallelization? #23613

Answered by yashk2810
sokol11 asked this question in Q&A

Mesh defines an n-dimensional grid of devices that you can shard your arrays over. With pmap, you could only shard one dimension at a time; to shard across multiple dimensions, you had to nest pmaps, which quickly became very complicated.

With LLMs, you often shard multiple dimensions in different ways, which is why you need an abstraction that can express such a sharding in a straightforward way. This is where Mesh comes into play.
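
As a minimal sketch (not part of the original answer): multi-dimensional sharding boils down to arranging devices in a grid and naming each axis. The axis names `data` and `model` and the 2x2 arrangement below are illustrative assumptions, not anything prescribed by the question.

import numpy as np
import jax
from jax.sharding import Mesh

# Arrange 4 devices as a 2x2 grid: one axis for data parallelism, one for model parallelism.
devices = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = Mesh(devices, ('data', 'model'))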

PartitionSpec defines how an array should be sharded along the named axes of the mesh. For example:

import jax
from jax.sharding import Mesh

# Here you have a 1D grid of devices with an axis name `x`. Imagine there are 8 devices.
mesh = Mesh(jax.devices(), ('x',))

# Here you are sharding the numpy array on mesh axis `x`.…
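
The answer preview is cut off above; here is a sketch of how the sharding step likely continues, assuming 8 devices and an illustrative 8x4 numpy array. PartitionSpec('x') maps the array's leading dimension to mesh axis `x`, and NamedSharding ties that spec to the mesh.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices(), ('x',))
arr = np.arange(32).reshape(8, 4)                   # a plain numpy array
sharding = NamedSharding(mesh, PartitionSpec('x'))  # shard dim 0 across mesh axis `x`
sharded = jax.device_put(arr, sharding)             # with 8 devices, each holds one (1, 4) slice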
