
Commit 50ce3a5
add sharded tensor test with empty shard
wz337 committed Oct 20, 2023
1 parent 39a66a6 commit 50ce3a5
Showing 1 changed file with 35 additions and 3 deletions.

test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
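The new test shards a 9-row tensor over four ranks with ChunkShardingSpec, so ranks 0-2 each receive 3 rows and rank 3 receives an empty shard. As a minimal sketch of that split, matching the sizes the test asserts (a hypothetical chunk_shard_sizes helper assuming ceil-division chunking, not PyTorch's internal code):

import math

# Each rank gets ceil(N / world_size) rows until the tensor runs out;
# trailing ranks can end up with zero rows when N is not divisible.
def chunk_shard_sizes(dim_size: int, world_size: int) -> list[int]:
    split = math.ceil(dim_size / world_size)
    return [max(min(split, dim_size - rank * split), 0) for rank in range(world_size)]

print(chunk_shard_sizes(9, 4))  # [3, 3, 3, 0] -> rank 3 holds the empty shard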
@@ -241,6 +241,32 @@ def test_shard_tensor(self):
         self.assertEqual(torch.Size([3, 12]), local_shard.size())
         self.assertEqual(torch.narrow(tensor, 0, 3 * self.rank, 3), local_shard)

+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(4)
+    @requires_nccl()
+    def test_shard_tensor_with_empty_shard(self):
+        spec = ChunkShardingSpec(
+            dim=0,
+            placements=[
+                "rank:0/cuda:0",
+                "rank:1/cuda:1",
+                "rank:2/cuda:2",
+                "rank:3/cuda:3",
+            ],
+        )
+        tensor = torch.rand(9, 12).cuda(self.rank)
+        st = _shard_tensor(tensor, spec)
+
+        # Verify.
+        self.assertTrue(isinstance(st, sharded_tensor.ShardedTensor))
+        local_shard = st.local_tensor()
+        self.assertEqual(1, len(st.local_shards()))
+        if dist.get_rank() < 3:
+            self.assertEqual(torch.Size([3, 12]), local_shard.size())
+            self.assertEqual(torch.narrow(tensor, 0, 3 * self.rank, 3), local_shard)
+        else:
+            self.assertEqual(torch.Size([0, 12]), local_shard.size())
+
     @with_comms(init_rpc=False)
     @skip_if_lt_x_gpu(4)
     @requires_nccl()
@@ -1022,17 +1048,23 @@ def test_insufficient_sharding_dims(self):
             self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
             self.assertEqual((1, 20), local_shard.size())
         else:
-            self.assertEqual(0, len(local_shards))
+            self.assertEqual(1, len(local_shards))
+            local_shard = local_shards[0].tensor
+            self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
+            self.assertEqual(local_shard.numel(), 0)

         # Validate global metadata.
         st_metadata = st.metadata()
         shards_metadata = st_metadata.shards_metadata
-        self.assertEqual(2, len(shards_metadata))
+        self.assertEqual(4, len(shards_metadata))

         for shard_rank, shard_metadata in enumerate(shards_metadata):
             self.assertEqual([shard_rank, 0], shard_metadata.shard_offsets)
-            self.assertEqual([1, 20], shard_metadata.shard_sizes)
             self.assertEqual(f'rank:{shard_rank}/cuda:{shard_rank}', str(shard_metadata.placement))
+            if shard_rank <= 1:
+                self.assertEqual([1, 20], shard_metadata.shard_sizes)
+            else:
+                self.assertEqual([0, 20], shard_metadata.shard_sizes)

     @with_comms
     @skip_if_lt_x_gpu(4)
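The second hunk applies the same idea to metadata: after the change, every rank owns exactly one local shard (possibly zero-sized) and the global metadata lists one entry per rank, with dim-0 offsets advancing by the full split even past the tensor's extent. Continuing the sketch above (same hypothetical helper names and assumptions):

import math

# Offsets step by ceil(N / world_size) for every rank, including empty
# trailing shards, which is what the new [shard_rank, 0] assertions check.
def chunk_shard_offsets(dim_size: int, world_size: int) -> list[int]:
    split = math.ceil(dim_size / world_size)
    return [rank * split for rank in range(world_size)]

print(chunk_shard_offsets(2, 4))  # [0, 1, 2, 3], with sizes [1, 1, 0, 0]

Per the decorators, running the test itself requires 4 GPUs and an NCCL build.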
