From 0c978e9ce1194ad122d4f15ea743c3ff354fc0a9 Mon Sep 17 00:00:00 2001 From: wangraying <45031995+wangraying@users.noreply.github.com> Date: Wed, 21 Jul 2021 17:19:51 +0800 Subject: [PATCH] feat: make full precision decentralized op stateless (#126) BREAKING CHANGE: `BaguaBucket.append_decentralized_synchronous_op` now only supports full precision decentralized communication. --- bagua/torch_api/algorithms/decentralized.py | 51 ++- bagua/torch_api/bucket.py | 132 ++++-- bagua/torch_api/distributed.py | 10 +- tests/torch_api/test_decentralized.py | 390 ++++++++++++++++++ .../test_low_precision_decentralized.py | 9 +- 5 files changed, 529 insertions(+), 63 deletions(-) create mode 100644 tests/torch_api/test_decentralized.py diff --git a/bagua/torch_api/algorithms/decentralized.py b/bagua/torch_api/algorithms/decentralized.py index 7de3500a3..ce45d9786 100644 --- a/bagua/torch_api/algorithms/decentralized.py +++ b/bagua/torch_api/algorithms/decentralized.py @@ -25,11 +25,16 @@ def __init__( weights are averaged in each communication step. "shift_one" means each worker selects a different peer to do weights average in each communication step. communication_interval (int): Number of iterations between two communication steps. + """ self.hierarchical = hierarchical self.peer_selection_mode = peer_selection_mode self.communication_interval = communication_interval + def _should_communicate(self, bagua_module: BaguaModule) -> bool: + cur_step = bagua_module.bagua_train_step_counter - 1 + return cur_step % self.communication_interval == 0 + def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]: parameters = bagua_module.bagua_build_params() self.tensors = [ @@ -40,8 +45,9 @@ def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]: def init_forward_pre_hook(self, bagua_module: BaguaModule): def hook(input): - for tensor in self.tensors: - tensor.bagua_mark_communication_ready() + if self._should_communicate(bagua_module): + for tensor in self.tensors: + tensor.bagua_mark_communication_ready() return hook @@ -53,23 +59,31 @@ def hook(parameter_name, parameter): def init_post_backward_hook(self, bagua_module: BaguaModule): def hook(): - bagua_module._bagua_backend.wait_pending_comm_ops() - torch.cuda.synchronize() - bagua_module._bagua_backend.execute_post_backward_comm_ops() - bagua_module._bagua_backend.wait_pending_post_backward_comm_ops() + if self._should_communicate(bagua_module): + bagua_module._bagua_backend.wait_pending_comm_ops() + for bucket in bagua_module.bagua_buckets: + bucket.decentralized_synchronous_op_copy_back_peer_weight( + hierarchical=self.hierarchical, peer_weight=bucket._peer_weight + ) return hook + def _init_states(self, bucket: BaguaBucket): + weight_tensor = bucket.flattened_tensor() + bucket._peer_weight = weight_tensor.to_bagua_tensor("peer_weight") + def init_operations( self, bagua_module: BaguaModule, bucket: BaguaBucket, ): + self._init_states(bucket) + torch.cuda.synchronize() bucket.clear_ops() bucket.append_decentralized_synchronous_op( + peer_weight=bucket._peer_weight, hierarchical=self.hierarchical, peer_selection_mode=self.peer_selection_mode, - communication_interval=self.communication_interval, ) @@ -87,6 +101,10 @@ def __init__(self, hierarchical: bool = True, communication_interval: int = 1): self.hierarchical = hierarchical self.communication_interval = communication_interval + def _should_communicate(self, bagua_module: BaguaModule) -> bool: + cur_step = bagua_module.bagua_train_step_counter - 1 + return cur_step % self.communication_interval == 0 + def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]: parameters = bagua_module.bagua_build_params() self.tensors = [ @@ -123,12 +141,13 @@ def hook(): def init_post_optimizer_step_hook(self, bagua_module: BaguaModule): def hook(optimizer: torch.optim.Optimizer): - for group in optimizer.param_groups: - for param in group["params"]: - if param.is_bagua_tensor(): - param.bagua_mark_communication_ready() + if self._should_communicate(bagua_module): + for group in optimizer.param_groups: + for param in group["params"]: + if param.is_bagua_tensor(): + param.bagua_mark_communication_ready() - bagua_module._bagua_backend.wait_pending_comm_ops() + bagua_module._bagua_backend.wait_pending_comm_ops() return hook @@ -153,12 +172,10 @@ def init_operations( self._init_states(bucket) torch.cuda.synchronize() bucket.clear_ops() - bucket.append_decentralized_synchronous_op( - hierarchical=self.hierarchical, - peer_selection_mode="ring", - communication_interval=self.communication_interval, - compression="MinMaxUInt8", + bucket.append_low_precision_decentralized_synchronous_op( weight=bucket._weight, left_peer_weight=bucket._left_peer_weight, right_peer_weight=bucket._right_peer_weight, + hierarchical=self.hierarchical, + compression="MinMaxUInt8", ) diff --git a/bagua/torch_api/bucket.py b/bagua/torch_api/bucket.py index 7388c14f2..7efad7d3c 100644 --- a/bagua/torch_api/bucket.py +++ b/bagua/torch_api/bucket.py @@ -9,6 +9,7 @@ from bagua.torch_api.tensor import BaguaTensor from bagua.torch_api.utils import check_contiguous +from bagua.torch_api.communication import broadcast class BaguaBucket: @@ -203,55 +204,40 @@ def append_centralized_synchronous_op( def append_decentralized_synchronous_op( self, + peer_weight: BaguaTensor, hierarchical: bool = True, peer_selection_mode: str = "all", - communication_interval: int = 1, - compression: Optional[str] = None, - weight: Optional[BaguaTensor] = None, - left_peer_weight: Optional[BaguaTensor] = None, - right_peer_weight: Optional[BaguaTensor] = None, ) -> BaguaBucket: """ Append a decentralized synchronous operation to a bucket. It will do gossipy style model averaging among workers. - The operations will be executed by the Bagua backend in the order they are appended - when all the tensors within the bucket are marked ready. + This operation is not inplace, which means the bucket weights is first copied to `peer_weight`, and the result of + decentralized averaging will be in `peer_weight`. To copy `peer_weight` back to `self`, call + :func:`decentralized_synchronous_op_copy_back_peer_weight`. + + This operation will be executed by the Bagua backend in + the order they are appended when all the tensors within the bucket are marked ready. Args: + peer_weight (BaguaTensor): A tensor used for averaging model with peers, should be of the same size + with the bucket tensors total size. Use ``self.flattened_tensor().to_bagua_tensor(...)`` to create such a tensor. hierarchical (bool): Enable hierarchical communication. Which means the GPUs on the same machine will communicate will each other first. After that, machines do inter-node communication. This can boost performance when the inter-node communication cost is high. - peer_selection_mode (str): Can be "all" or "shift_one" for full precision decentralized operation, "ring" for - low precision decentralized operation. "all" means all workers' weights are averaged in each communication step. - "shift_one" means each worker selects a different peer to do weights average in each communication step. - "ring" means all workers are connected into a ring, and each worker communicate with its neighbors. - communication_interval (int): Number of iterations between two communication steps. - compression: If not ``None``, the tensors will be compressed for communication. Currently "MinMaxUInt8" is - supported. - weight (BaguaTensor): Local model of current worker, a flattened tensor containing the same data as the local model - weights of current worker, required for low precision decentralized operation. - left_peer_weight (BaguaTensor): Model replica of current worker's connected left peer, a flattened tensor containing - the same data as model weights of left peer, required for low precision decentralized operation. - right_peer_weight (BaguaTensor): Model replica of current worker's connected right peer, similarly as `left_peer_weight`, - required for low precision decentralized operation. + peer_selection_mode (str): Can be "all" or "shift_one". "all" means all workers' weights are averaged + in each communication step. "shift_one" means each worker selects a different peer to do weights average + in each communication step. Returns: The bucket itself. """ + if hierarchical: self.backend_bucket.append_decentralized_synchronous_op( self._bagua_backend.internode_communicator, self._bagua_backend.intranode_communicator, hierarchical=hierarchical, peer_selection_mode=peer_selection_mode, - communication_interval=communication_interval, - compression=compression, - weight=weight._bagua_backend_tensor if weight is not None else None, - left_peer_weight=left_peer_weight._bagua_backend_tensor - if left_peer_weight is not None - else None, - right_peer_weight=right_peer_weight._bagua_backend_tensor - if right_peer_weight is not None - else None, + peer_weight=peer_weight._bagua_backend_tensor, ) else: self.backend_bucket.append_decentralized_synchronous_op( @@ -259,15 +245,87 @@ def append_decentralized_synchronous_op( None, hierarchical=hierarchical, peer_selection_mode=peer_selection_mode, - communication_interval=communication_interval, + peer_weight=peer_weight._bagua_backend_tensor, + ) + return self + + def decentralized_synchronous_op_copy_back_peer_weight( + self, peer_weight: BaguaTensor, hierarchical: bool = True + ): + """ + Copy `peer_weight` back to bucket weights to end a decentralized synchronous operation. + See :func:`append_decentralized_synchronous_op` for more information. + + Args: + peer_weight (BaguaTensor): A tensor used for averaging model with peers, should be of the same size + with the bucket tensors total size. Use ``self.flattened_tensor().to_bagua_tensor(...)`` to create such a tensor. + hierarchical (bool): Enable hierarchical communication. Which means the GPUs on the same machine + will communicate will each other first. After that, machines do inter-node communication. This can + boost performance when the inter-node communication cost is high. Must be the same with `hierarchical` argument in + :func:`append_decentralized_synchronous_op`. + """ + intra_comm = self._bagua_backend.intranode_communicator + inter_comm = self._bagua_backend.internode_communicator + + if not hierarchical or (inter_comm is not None): + self.backend_tensor.copy_(peer_weight) + + if hierarchical: + broadcast(self.backend_tensor, 0, intra_comm) + + def append_low_precision_decentralized_synchronous_op( + self, + weight: BaguaTensor, + left_peer_weight: BaguaTensor, + right_peer_weight: BaguaTensor, + hierarchical: bool = True, + compression: str = "MinMaxUInt8", + ) -> BaguaBucket: + """ + Append a low precision decentralized synchronous operation to a bucket. It will compress the difference + of local models between two successive iterations and exchange them among workers. + + The operations will be executed by the Bagua backend in the order they are appended + when all the tensors within the bucket are marked ready. + + Args: + weight (BaguaTensor): Model replica of current worker's local model. It should be of the same size + with the bucket tensors total size. Use ``self.flattened_tensor().to_bagua_tensor(...)`` to create such a tensor. + left_peer_weight (BaguaTensor): Model replica of current worker's left peer. It should be of the same size + with the bucket tensors total size. Use ``self.flattened_tensor().to_bagua_tensor(...)`` to create such a tensor, + then copy the initializing weights of current worker's left peer to the tensor. + right_peer_weight (BaguaTensor): Model replica of current worker's right peer. It should be of the same size + with the bucket tensors total size. Use ``self.flattened_tensor().to_bagua_tensor(...)`` to create such a tensor. + then copy the initializing weights of current worker's right peer to the tensor. + hierarchical (bool): Enable hierarchical communication. Which means the GPUs on the same machine + will communicate will each other first. After that, machines do inter-node communication. This can + boost performance when the inter-node communication cost is high. + compression (str): The way how tensors are compressed for communication. Currently "MinMaxUInt8" is supported. + Returns: + The bucket itself. + """ + + if hierarchical: + self.backend_bucket.append_low_precision_decentralized_synchronous_op( + self._bagua_backend.internode_communicator, + self._bagua_backend.intranode_communicator, + hierarchical=hierarchical, + peer_selection_mode="ring", + compression=compression, + weight=weight._bagua_backend_tensor, + left_peer_weight=left_peer_weight._bagua_backend_tensor, + right_peer_weight=right_peer_weight._bagua_backend_tensor, + ) + else: + self.backend_bucket.append_low_precision_decentralized_synchronous_op( + self._bagua_backend.global_communicator, + None, + hierarchical=hierarchical, + peer_selection_mode="ring", compression=compression, - weight=weight._bagua_backend_tensor if weight is not None else None, - left_peer_weight=left_peer_weight._bagua_backend_tensor - if left_peer_weight is not None - else None, - right_peer_weight=right_peer_weight._bagua_backend_tensor - if right_peer_weight is not None - else None, + weight=weight._bagua_backend_tensor, + left_peer_weight=left_peer_weight._bagua_backend_tensor, + right_peer_weight=right_peer_weight._bagua_backend_tensor, ) return self diff --git a/bagua/torch_api/distributed.py b/bagua/torch_api/distributed.py index 26a8d79c2..a5867a930 100644 --- a/bagua/torch_api/distributed.py +++ b/bagua/torch_api/distributed.py @@ -198,6 +198,7 @@ def with_bagua( # pytype: disable=module-attr self, "_ddp_params_and_buffers_to_ignore" ): # for compatibility with PyTorch DDP self.parameters_to_ignore.extend(self._ddp_params_and_buffers_to_ignore) + self.bagua_train_step_counter = 0 """ Number of iterations in training mode. @@ -271,12 +272,8 @@ def record_speed_metrics_event(self, _): ) # get communicators - self._bagua_inter_node_communicator = ( - self._bagua_backend.internode_communicator - ) - self._bagua_intra_node_communicator = ( - self._bagua_backend.intranode_communicator - ) + self._bagua_inter_node_communicator = self._bagua_backend.internode_communicator + self._bagua_intra_node_communicator = self._bagua_backend.intranode_communicator self._bagua_global_communicator = self._bagua_backend.global_communicator self.bagua_communication_stream = self._bagua_backend.stream @@ -388,6 +385,7 @@ def real_post_backward_hook(*unused): def new_step_factory(optimizer): def new_step(self, *args, **kwargs): result = self._bagua_original_step(*args, **kwargs) + optimizer_hook(self) return result diff --git a/tests/torch_api/test_decentralized.py b/tests/torch_api/test_decentralized.py new file mode 100644 index 000000000..46a50e4f2 --- /dev/null +++ b/tests/torch_api/test_decentralized.py @@ -0,0 +1,390 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from tests.internal.common_utils import find_free_port +import unittest +import torch.multiprocessing as mp +import os +from bagua.torch_api.utils import flatten, unflatten +import bagua.torch_api as bagua + + +N_EPOCHS = 10 + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.fc1 = nn.Linear(2, 10, bias=False) + self.fc2 = nn.Linear(10, 50, bias=True) + self.fc3 = nn.Linear(50, 4, bias=False) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + x = self.fc3(x) + return F.softmax(x, dim=1) + + +def _init_env(rank): + # set deterministic + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.manual_seed(rank) + # initialize subprocess env + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + + +def run_model( + rank, nprocs, hierarchical, peer_selection_mode, communication_interval, results +): + _init_env(rank) + + # init bagua distributed process group + torch.cuda.set_device(rank) + bagua.init_process_group() + + # construct model and optimizer, etc. + model = Net().cuda() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + loss_fn = nn.MSELoss() + + # wrap model + model = model.with_bagua( + [optimizer], + bagua.algorithms.decentralized.DecentralizedAlgorithm( + hierarchical=hierarchical, + peer_selection_mode=peer_selection_mode, + communication_interval=communication_interval, + ), + ) + + ret = results[rank] + + ret.init_weight.copy_(flatten([param.data for param in model.parameters()])) + + for epoch in range(N_EPOCHS): + data = torch.randn(4, 2).cuda() + target = torch.randn(4, 4).cuda() + + optimizer.zero_grad() + output = model(data) + loss = loss_fn(output, target) + + loss.backward() + optimizer.step() + + ret.bucket_weight.copy_(model.bagua_buckets[0]._peer_weight) + + +def run_torch_model( + rank, + nprocs, + hierarchical, + peer_selection_mode, + communication_interval, + results, + backend, +): + _init_env(rank) + + # init torch distributed process group + torch.cuda.set_device(rank) + store = torch.distributed.FileStore("/tmp/filestore", nprocs) + torch.distributed.init_process_group( + world_size=nprocs, rank=rank, store=store, backend=backend + ) + + # construct model and optimizer, etc. + model = Net().cuda() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + loss_fn = nn.MSELoss() + + # wrap model + model = DecentralizedAlgor( + model, optimizer, hierarchical, peer_selection_mode, communication_interval + ) + + ret = results[rank] + ret.init_weight.copy_(flatten([param.data for param in model.parameters()])) + + for epoch in range(N_EPOCHS): + data = torch.randn(4, 2).cuda() + target = torch.randn(4, 4).cuda() + + optimizer.zero_grad() + output = model(data) + loss = loss_fn(output, target) + + loss.backward() + model.step() + + ret.bucket_weight.copy_(model.peer_weight) + + +class Result(object): + def __init__(self): + model = Net() + self.init_weight = flatten( + [torch.zeros_like(param.data) for param in model.parameters()] + ) + self.bucket_weight = flatten( + [torch.zeros_like(param.data) for param in model.parameters()] + ) + + +class DecentralizedAlgor(nn.Module): + def __init__( + self, + module, + optimizer, + hierarchical, + peer_selection_mode, + communication_interval, + ): + super(DecentralizedAlgor, self).__init__() + self.module = module + self.optimizer = optimizer + self.hierarchical = hierarchical + self.peer_selection_mode = peer_selection_mode + self.communication_interval = communication_interval + self.step_count = 0 + + assert torch.distributed.is_initialized() + + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + + # broadcast parameters + for param in self.module.parameters(): + torch.distributed.broadcast(param.data, src=0) + + def _build_params(self): + return [param.data for param in list(self.module.parameters()).__reversed__()] + + def communicate_with_peer(self): + if self.peer_selection_mode == "all": + torch.distributed.all_reduce(self.peer_weight) + self.peer_weight /= self.world_size + elif self.peer_selection_mode == "shift_one": + peer_rank = get_peer_rank( + self.peer_selection_mode, + self.rank, + self.world_size, + self.step_count, + self.communication_interval, + ) + + weight = self.weight.cpu() + peer_weight = self.peer_weight.cpu() + + requests = [] + requests.append(torch.distributed.isend(weight, peer_rank)) + requests.append(torch.distributed.irecv(peer_weight, peer_rank)) + + for req in requests: + req.wait() + + self.peer_weight = peer_weight.cuda() + self.weight = weight.cuda() + + self.peer_weight += self.weight + self.peer_weight /= 2 + else: + raise ValueError("Unsupported `peer_selection_mode`") + + def _should_communicate(self): + return self.step_count % self.communication_interval == 0 + + def forward(self, *inputs, **kwargs): + if self._should_communicate(): + self.weight = flatten(self._build_params()) + self.peer_weight = flatten(self._build_params()) + self.communicate_with_peer() + + result = self.module(*inputs, **kwargs) + return result + + def step(self): + if self._should_communicate(): + params = self._build_params() + for buf, synced in zip(params, unflatten(self.peer_weight, params)): + buf.copy_(synced) + + self.optimizer.step() + self.step_count += 1 + + +def get_peer_rank(peer_selection_mode, rank, nranks, step, communication_interval): + comm_step = step // communication_interval + if peer_selection_mode == "shift_one": + if rank < nranks // 2: + return ((comm_step + rank) % ((nranks + 1) // 2)) + (nranks // 2) + else: + return (rank - (nranks // 2) - comm_step) % (nranks // 2) + else: + ValueError("Unsupported `peer_selection_mode`") + + +class TestLowPrecisionDecentralized(unittest.TestCase): + def run_test_locally( + self, nprocs, hierarchical, peer_selection_mode, communication_interval + ): + if not torch.cuda.is_available(): + print("skip tests since cuda is not available") + return + + nprocs = torch.cuda.device_count() + os.environ["WORLD_SIZE"] = str(nprocs) + os.environ["LOCAL_WORLD_SIZE"] = str(nprocs) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(find_free_port()) + os.environ["BAGUA_SERVICE_PORT"] = str(find_free_port()) + + results = [Result() for _ in range(nprocs)] + mp.spawn( + run_model, + nprocs=nprocs, + args=( + nprocs, + hierarchical, + peer_selection_mode, + communication_interval, + results, + ), + ) + + for rank in range(nprocs): + if peer_selection_mode == "all": + peer_rank = (rank + 1) % nprocs + # all workers have equal weights + self.assertTrue( + torch.equal( + results[rank].bucket_weight, + results[peer_rank].bucket_weight, + ) + ) + elif peer_selection_mode == "shift_one": + peer_rank = get_peer_rank( + peer_selection_mode, + rank, + nprocs, + N_EPOCHS - 1, + communication_interval, + ) + + self.assertTrue( + torch.equal( + results[rank].bucket_weight, results[peer_rank].bucket_weight + ) + ) + else: + raise ValueError("illegal `peer_selection_mode`!") + + def run_diff_locally( + self, nprocs, hierarchical, peer_selection_mode, communication_interval, backend + ): + if not torch.cuda.is_available(): + print("skip tests since cuda is not available") + return + + os.environ["WORLD_SIZE"] = str(nprocs) + os.environ["LOCAL_WORLD_SIZE"] = str(nprocs) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(find_free_port()) + os.environ["BAGUA_SERVICE_PORT"] = str(find_free_port()) + + torch_results = [Result() for _ in range(nprocs)] + mp.spawn( + run_torch_model, + nprocs=nprocs, + args=( + nprocs, + hierarchical, + peer_selection_mode, + communication_interval, + torch_results, + backend, + ), + ) + + bagua_results = [Result() for _ in range(nprocs)] + mp.spawn( + run_model, + nprocs=nprocs, + args=( + nprocs, + hierarchical, + peer_selection_mode, + communication_interval, + bagua_results, + ), + ) + + for rank in range(nprocs): + self.assertTrue( + torch.all( + torch.isclose( + bagua_results[rank].init_weight, + torch_results[rank].init_weight, + ) + ).item() + ) + + self.assertTrue( + torch.all( + torch.isclose( + bagua_results[rank].bucket_weight, + torch_results[rank].bucket_weight, + ) + ).item() + ) + + def test_algorithm(self): + self.run_test_locally( + nprocs=8, + hierarchical=False, + peer_selection_mode="all", + communication_interval=1, + ) + self.run_test_locally( + nprocs=8, + hierarchical=False, + peer_selection_mode="shift_one", + communication_interval=1, + ) + self.run_test_locally( + nprocs=8, + hierarchical=False, + peer_selection_mode="shift_one", + communication_interval=2, + ) + + def test_compare(self): + self.run_diff_locally( + nprocs=8, + hierarchical=False, + peer_selection_mode="all", + communication_interval=1, + backend="gloo", + ) + self.run_diff_locally( + nprocs=8, + hierarchical=False, + peer_selection_mode="shift_one", + communication_interval=1, + backend="gloo", + ) + self.run_diff_locally( + nprocs=8, + hierarchical=False, + peer_selection_mode="shift_one", + communication_interval=2, + backend="gloo", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/torch_api/test_low_precision_decentralized.py b/tests/torch_api/test_low_precision_decentralized.py index 06fb20e56..bcd96ec58 100644 --- a/tests/torch_api/test_low_precision_decentralized.py +++ b/tests/torch_api/test_low_precision_decentralized.py @@ -51,7 +51,8 @@ def run_model(rank, nprocs, hierarchical, communication_interval, results): model = model.with_bagua( [optimizer], bagua.algorithms.decentralized.LowPrecisionDecentralizedAlgorithm( - hierarchical=hierarchical, communication_interval=communication_interval + hierarchical=hierarchical, + communication_interval=communication_interval, ), ) @@ -156,6 +157,9 @@ def __init__(self, module, optimizer, hierarchical, communication_interval): def _build_params(self): return [param.data for param in list(self.module.parameters()).__reversed__()] + def _should_communicate(self): + return self.step_count % self.communication_interval == 0 + def forward(self, *inputs, **kwargs): result = self.module(*inputs, **kwargs) return result @@ -217,7 +221,7 @@ def hierarchical_update_weight_fn(x): torch.distributed.broadcast(x, 0) - if self.step_count % self.communication_interval == 0: + if self._should_communicate(): weights = self._build_params() if self.hierarchical: apply_flattened_call(weights, hierarchical_update_weight_fn) @@ -225,7 +229,6 @@ def hierarchical_update_weight_fn(x): apply_flattened_call( weights, lambda x: update_weight_fn(x, self.world_size) ) - self.step_count += 1