diff --git a/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py b/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
index 9043975f49643f..4289adf974972d 100644
--- a/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
+++ b/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
@@ -241,6 +241,32 @@ def test_shard_tensor(self):
         self.assertEqual(torch.Size([3, 12]), local_shard.size())
         self.assertEqual(torch.narrow(tensor, 0, 3 * self.rank, 3), local_shard)
 
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(4)
+    @requires_nccl()
+    def test_shard_tensor_with_empty_shard(self):
+        spec = ChunkShardingSpec(
+            dim=0,
+            placements=[
+                "rank:0/cuda:0",
+                "rank:1/cuda:1",
+                "rank:2/cuda:2",
+                "rank:3/cuda:3",
+            ],
+        )
+        tensor = torch.rand(9, 12).cuda(self.rank)
+        st = _shard_tensor(tensor, spec)
+
+        # Verify.
+        self.assertTrue(isinstance(st, sharded_tensor.ShardedTensor))
+        local_shard = st.local_tensor()
+        self.assertEqual(1, len(st.local_shards()))
+        if dist.get_rank() < 3:
+            self.assertEqual(torch.Size([3, 12]), local_shard.size())
+            self.assertEqual(torch.narrow(tensor, 0, 3 * self.rank, 3), local_shard)
+        else:
+            self.assertEqual(torch.Size([0, 12]), local_shard.size())
+
     @with_comms(init_rpc=False)
     @skip_if_lt_x_gpu(4)
     @requires_nccl()
@@ -1022,17 +1048,23 @@ def test_insufficient_sharding_dims(self):
             self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
             self.assertEqual((1, 20), local_shard.size())
         else:
-            self.assertEqual(0, len(local_shards))
+            self.assertEqual(1, len(local_shards))
+            local_shard = local_shards[0].tensor
+            self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
+            self.assertEqual(local_shard.numel(), 0)
 
         # Validate global metadata.
         st_metadata = st.metadata()
         shards_metadata = st_metadata.shards_metadata
-        self.assertEqual(2, len(shards_metadata))
+        self.assertEqual(4, len(shards_metadata))
 
         for shard_rank, shard_metadata in enumerate(shards_metadata):
             self.assertEqual([shard_rank, 0], shard_metadata.shard_offsets)
-            self.assertEqual([1, 20], shard_metadata.shard_sizes)
             self.assertEqual(f'rank:{shard_rank}/cuda:{shard_rank}', str(shard_metadata.placement))
+            if shard_rank <= 1:
+                self.assertEqual([1, 20], shard_metadata.shard_sizes)
+            else:
+                self.assertEqual([0, 20], shard_metadata.shard_sizes)
 
     @with_comms
     @skip_if_lt_x_gpu(4)
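
For reference, a minimal standalone sketch (not part of the patch) of the chunk-size arithmetic these assertions imply: with ChunkShardingSpec, each of the 4 placements gets at most ceil(dim_size / num_placements) rows along dim 0, so a (9, 12) tensor yields three (3, 12) shards plus one empty (0, 12) shard, and the (2, 20) tensor in test_insufficient_sharding_dims yields shard sizes [1, 20], [1, 20], [0, 20], [0, 20]. The chunk_sizes helper below is hypothetical, written only to check that arithmetic; it is not a PyTorch API.

import math

def chunk_sizes(dim_size, num_placements):
    # Each placement gets at most ceil(dim_size / num_placements) rows along the
    # sharding dim; placements past the end of the tensor get a size-0 shard.
    split = math.ceil(dim_size / num_placements)
    return [max(min(split, dim_size - split * rank), 0) for rank in range(num_placements)]

assert chunk_sizes(9, 4) == [3, 3, 3, 0]   # test_shard_tensor_with_empty_shard
assert chunk_sizes(2, 4) == [1, 1, 0, 0]   # test_insufficient_sharding_dims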