Skip to content

Commit

Permalink
Added support for all_gather object (pytorch#3047)
Browse files Browse the repository at this point in the history
* Added support for all_gather object

* Apply suggestions from code review

Co-authored-by: Sadra Barikbin <sadraqazvin1@yahoo.com>

* Added new test in _test_distrib_all_gather_group

* Handling pytorch old versions

---------

Co-authored-by: Sadra Barikbin <sadraqazvin1@yahoo.com>
  • Loading branch information
vfdev-5 and sadra-barikbin authored Aug 31, 2023
1 parent e3c625a commit 34a707e
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 15 deletions.
13 changes: 10 additions & 3 deletions ignite/distributed/comp_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _apply_op(
return tensor

def _collective_op(
self, tensor: Union[torch.Tensor, float, str], fn: Callable, *args: Any, **kwargs: Any
self, tensor: Union[torch.Tensor, Number, str], fn: Callable, *args: Any, **kwargs: Any
) -> Union[torch.Tensor, float, List[float], List[str]]:
tensor_to_number = tensor_to_str = False
device = self.device()
Expand Down Expand Up @@ -216,10 +216,10 @@ def all_reduce(
return cast(Union[torch.Tensor, float], self._collective_op(tensor, self._do_all_reduce, op, group=group))

def all_gather(
self, tensor: Union[torch.Tensor, float, str], group: Optional[Any] = None
self, tensor: Union[torch.Tensor, float, str, Any], group: Optional[Any] = None
) -> Union[torch.Tensor, float, List[float], List[str]]:
if not isinstance(tensor, (torch.Tensor, Number, str)):
raise TypeError(f"Unhandled input type {type(tensor)}")
return self._do_all_gather_object(tensor, group=group)

return self._collective_op(tensor, self._do_all_gather, group=group)

Expand Down Expand Up @@ -282,6 +282,10 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
pass

@abstractmethod
def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
pass

@abstractmethod
def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
pass
Expand Down Expand Up @@ -373,6 +377,9 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
return tensor

def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> Any:
return tensor

def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
return ranks

Expand Down
6 changes: 6 additions & 0 deletions ignite/distributed/comp_models/horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t
tensor = tensor.unsqueeze(0)
return hvd.allgather(tensor)

def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
if group is not None:
raise NotImplementedError("all_gather with group for horovod is not implemented")

return hvd.allgather_object(tensor)

def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
return hvd.ProcessSet(ranks)

Expand Down
30 changes: 29 additions & 1 deletion ignite/distributed/comp_models/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
if group is not None and not isinstance(group, dist.ProcessGroup):
raise ValueError("Argument group should be list of int or ProcessGroup")
reduce_op = self._reduce_op_map[op]
# We do if/else here for compatibility with older pytorch versions
if group is not None:
dist.all_reduce(tensor, reduce_op, group=group)
else:
Expand All @@ -432,7 +433,8 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
if group == dist.GroupMember.NON_GROUP_MEMBER:
return tensor
elif group is None:

if group is None:
group_size = self.get_world_size()
elif isinstance(group, dist.ProcessGroup):
group_size = group.size()
Expand All @@ -441,12 +443,38 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t
if tensor.ndimension() == 0:
tensor = tensor.unsqueeze(0)
output = [torch.zeros_like(tensor) for _ in range(group_size)]
# We do if/else here for compatibility with older pytorch versions
if group is not None:
dist.all_gather(output, tensor, group=group)
else:
dist.all_gather(output, tensor)
return torch.cat(output, dim=0)

def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
if Version(torch.__version__) < Version("1.7.0"):
raise RuntimeError(
"Current torch version does not implement dist.all_gather_object. "
"Required version should be >=1.7.0"
)

if group == dist.GroupMember.NON_GROUP_MEMBER:
return tensor

if group is None:
group_size = self.get_world_size()
elif isinstance(group, dist.ProcessGroup):
group_size = group.size()
else:
raise ValueError("Argument group should be list of int or ProcessGroup")
output = [None for _ in range(group_size)]
# We do if/else here for compatibility with older pytorch versions
if group is not None:
dist.all_gather_object(output, tensor, group=group)
else:
dist.all_gather_object(output, tensor)

return output

def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
return dist.new_group(ranks=ranks, **kwargs)

Expand Down
3 changes: 3 additions & 0 deletions ignite/distributed/comp_models/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t
xm.all_reduce("sum", [output], groups=group)
return output.reshape(-1, *output.shape[2:])

def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
raise NotImplementedError("all_gather on object is not implemented for xla")

def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
return [ranks]

Expand Down
76 changes: 65 additions & 11 deletions tests/ignite/distributed/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,21 +156,22 @@ def _test_distrib_all_reduce_group(device):

def _test_distrib_all_gather(device):
rank = idist.get_rank()
ws = idist.get_world_size()

res = torch.tensor(idist.all_gather(10), device=device)
true_res = torch.tensor([10] * idist.get_world_size(), device=device)
true_res = torch.tensor([10] * ws, device=device)
assert (res == true_res).all()

t = torch.tensor(rank, device=device)
res = idist.all_gather(t)
true_res = torch.tensor([i for i in range(idist.get_world_size())], device=device)
true_res = torch.tensor([i for i in range(ws)], device=device)
assert (res == true_res).all()

x = "test-test"
if rank == 0:
x = "abc"
res = idist.all_gather(x)
true_res = ["abc"] + ["test-test"] * (idist.get_world_size() - 1)
true_res = ["abc"] + ["test-test"] * (ws - 1)
assert res == true_res

base_x = "tests/ignite/distributed/utils/test_native.py" * 2000
Expand All @@ -179,27 +180,46 @@ def _test_distrib_all_gather(device):
x = "abc"

res = idist.all_gather(x)
true_res = ["abc"] + [base_x] * (idist.get_world_size() - 1)
true_res = ["abc"] + [base_x] * (ws - 1)
assert res == true_res

t = torch.arange(100, device=device).reshape(4, 25) * (rank + 1)
in_dtype = t.dtype
res = idist.all_gather(t)
assert res.shape == (idist.get_world_size() * 4, 25)
assert res.shape == (ws * 4, 25)
assert res.dtype == in_dtype
true_res = torch.zeros(idist.get_world_size() * 4, 25, device=device)
for i in range(idist.get_world_size()):
true_res = torch.zeros(ws * 4, 25, device=device)
for i in range(ws):
true_res[i * 4 : (i + 1) * 4, ...] = torch.arange(100, device=device).reshape(4, 25) * (i + 1)
assert (res == true_res).all()

if idist.get_world_size() > 1:
with pytest.raises(TypeError, match=r"Unhandled input type"):
idist.all_reduce([0, 1, 2])
if ws > 1 and idist.backend() != "xla-tpu":
t = {
"a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
"b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device),
"c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8, device=device)},
}
res = idist.all_gather(t)
assert isinstance(res, list) and len(res) == ws
for i, obj in enumerate(res):
assert isinstance(obj, dict)
assert list(obj.keys()) == ["a", "b", "c"], obj
expected_device = (
device if torch.device(device).type == "cpu" else torch.device(f"{torch.device(device).type}:{i}")
)
expected = {
"a": [i + 1, i + 2, torch.tensor(i + 3, device=expected_device)],
"b": torch.tensor([[i + 1, i + 2, i + 3]], device=expected_device),
"c": {"abcd": i, "cdfg": torch.tensor(i, dtype=torch.uint8, device=expected_device)},
}
assert obj["a"] == expected["a"]
assert (obj["b"] == expected["b"]).all()
assert obj["c"] == expected["c"]


def _test_distrib_all_gather_group(device):
if idist.get_world_size() > 1:
ranks = [0, 1]
ranks = list(range(idist.get_world_size() - 1, 0, -1)) # [0, 1, 2, 3] -> [3, 2, 1]
rank = idist.get_rank()
bnd = idist.backend()

Expand All @@ -226,6 +246,40 @@ def _test_distrib_all_gather_group(device):
else:
assert res == t

t = {
"a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
"b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device),
"c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8, device=device)},
}
if bnd in ("xla-tpu"):
with pytest.raises(NotImplementedError, match=r"all_gather on object is not implemented for xla"):
res = idist.all_gather(t, group=ranks)
elif bnd in ("horovod"):
with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
res = idist.all_gather(t, group=ranks)
else:
res = idist.all_gather(t, group=ranks)
if rank in ranks:
assert isinstance(res, list) and len(res) == len(ranks)
for i, obj in zip(ranks, res):
assert isinstance(obj, dict)
assert list(obj.keys()) == ["a", "b", "c"], obj
expected_device = (
device
if torch.device(device).type == "cpu"
else torch.device(f"{torch.device(device).type}:{i}")
)
expected = {
"a": [i + 1, i + 2, torch.tensor(i + 3, device=expected_device)],
"b": torch.tensor([[i + 1, i + 2, i + 3]], device=expected_device),
"c": {"abcd": i, "cdfg": torch.tensor(i, dtype=torch.uint8, device=expected_device)},
}
assert obj["a"] == expected["a"], (obj, expected)
assert (obj["b"] == expected["b"]).all(), (obj, expected)
assert obj["c"] == expected["c"], (obj, expected)
else:
assert res == t

if bnd in ("nccl", "gloo", "mpi"):
with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
res = idist.all_gather(t, group="abc")
Expand Down
3 changes: 3 additions & 0 deletions tests/ignite/distributed/utils/test_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
import torch
import torch.distributed as dist
from packaging.version import Version

import ignite.distributed as idist
from ignite.distributed.utils import has_native_dist_support
Expand Down Expand Up @@ -236,6 +237,7 @@ def test_idist_all_reduce_gloo(distributed_context_single_node_gloo):
@pytest.mark.distributed
@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
@pytest.mark.skipif(Version(torch.__version__) < Version("1.7.0"), reason="dist.all_gather_object is not implemented")
def test_idist_all_gather_nccl(distributed_context_single_node_nccl):
device = idist.device()
_test_distrib_all_gather(device)
Expand All @@ -244,6 +246,7 @@ def test_idist_all_gather_nccl(distributed_context_single_node_nccl):

@pytest.mark.distributed
@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(Version(torch.__version__) < Version("1.7.0"), reason="dist.all_gather_object is not implemented")
def test_idist_all_gather_gloo(distributed_context_single_node_gloo):
device = idist.device()
_test_distrib_all_gather(device)
Expand Down

0 comments on commit 34a707e

Please sign in to comment.