Skip to content

Latest commit

 

History

History
215 lines (156 loc) · 7.71 KB

ddp_comm_hooks.rst

File metadata and controls

215 lines (156 loc) · 7.71 KB

DDP Communication Hooks

DDP communication hook is a generic interface to control how to communicate gradients across workers by overriding the vanilla allreduce in DistributedDataParallel. A few built-in communication hooks are provided, and users can easily apply any of these hooks to optimize communication. Besides, the hook interface can also support user-defined communication strategies for more advanced use cases.

How to Use a Communication Hook?

To use a communication hook, the user just needs to let the DDP model register the hook before the training loop as below.

:func:`torch.nn.parallel.DistributedDataParallel.register_comm_hook`

What Does a Communication Hook Operate On?

A communication hook provides a flexible way to allreduce gradients. Therefore, it mainly operates on the gradients on each replica before allreduce, which are bucketized to increase the overlap between communication and computation. Particularly, :class:`torch.distributed.GradBucket` represents a bucket of gradient tensors to be allreduced.

.. autoclass:: torch.distributed.GradBucket

.. autofunction:: torch.distributed.GradBucket.index
.. autofunction:: torch.distributed.GradBucket.buffer
.. autofunction:: torch.distributed.GradBucket.gradients
.. autofunction:: torch.distributed.GradBucket.is_last
.. autofunction:: torch.distributed.GradBucket.set_buffer
.. autofunction:: torch.distributed.GradBucket.parameters

Default Communication Hooks

Default communication hooks are simple stateless hooks, so the input state in register_comm_hook is either a process group or None. The input bucket is a :class:`torch.distributed.GradBucket` object.

.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks
.. autofunction:: allreduce_hook
.. autofunction:: fp16_compress_hook
.. autofunction:: bf16_compress_hook

Additionally, a communication hook wrapper is provided to support :meth:`~fp16_compress_hook` or :meth:`~bf16_compress_hook` as a wrapper, which can be combined with other communication hooks.

.. autofunction:: fp16_compress_wrapper
.. autofunction:: bf16_compress_wrapper

PowerSGD Communication Hook

PowerSGD (Vogels et al., NeurIPS 2019) is a gradient compression algorithm, which can provide very high compression rates and accelerate bandwidth-bound distributed training. This algorithm needs to maintain both some hyperparameters and the internal state. Therefore, PowerSGD communication hook is a stateful hook, and the user needs to provide a state object defined as below.

PowerSGD State

.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook
.. autoclass:: PowerSGDState

PowerSGD Hooks

Warning

PowerSGD typically requires extra memory of the same size as the model's gradients to enable error feedback, which can compensate for biased compressed communication and improve accuracy.

Warning

PowerSGD hooks may conflict with Apex automatic mixed precision package. Please use PyTorch native automatic mixed precision package instead.

.. autofunction:: powerSGD_hook
.. autofunction:: batched_powerSGD_hook

Debugging Communication Hooks

As the name implies, debugging communication hooks are only used for debugging and performance optimization purpose.

.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks

Warning

Debugging communication hooks do not necessarily output the correct results.

.. autofunction:: noop_hook

Checkpointing of Communication Hooks

.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook

A stateful communication hook can be saved as a part of model checkpointing to enable trainer restarts. To make a hook serializable, __setstate__ and __getstate__ should be defined.

Warning

__getstate__ should exclude non-serializable attributes from a returned dictionary.

Warning

__setstate__ should properly initialize non-serializable attributes, excluded from a provided state.

:class:`PowerSGDState` has __setstate__ and __getstate__ implemented and can be used as a reference.

.. automethod:: PowerSGDState.__getstate__
.. automethod:: PowerSGDState.__setstate__

Here is a simple, end-to-end example of saving and reloading PowerSGD state and hook.

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(24,24)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(24,12)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def run_demo(demo_fn, world_size):
    mp.spawn(
        demo_fn,
        args=(world_size,),
        nprocs=world_size,
        join=True)

def demo_serialization(rank, world_size):
    setup(rank, world_size)

    CHECKPOINT = tempfile.gettempdir() + "/checkpoint.pt"

    model = SimpleModel().to(rank)
    ddp_model = DistributedDataParallel(model, device_ids=[rank])

    powersgd_hook = powerSGD.powerSGD_hook
    powersgd_state = powerSGD.PowerSGDState(process_group=None)

    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

    state = {
        'state_dict': ddp_model.state_dict(),
        'comm_hook': powersgd_hook,
        'comm_hook_state': powersgd_state}

    if rank == 0:
        torch.save(state, CHECKPOINT)

    dist.barrier()
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    checkpoint = torch.load(CHECKPOINT, map_location=map_location)

    new_ddp_model = DistributedDataParallel(SimpleModel().to(rank), device_ids=[rank])
    new_ddp_model.load_state_dict(checkpoint['state_dict'])
    powersgd_hook = checkpoint['comm_hook']
    powersgd_state = checkpoint['comm_hook_state']

    new_ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

    if rank == 0:
        os.remove(CHECKPOINT)

    cleanup()

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_serialization, world_size)

Acknowledgements

Many thanks to PowerSGD paper author Thijs Vogels for the code review on PowerSGD communication hook, as well as the comparison experiments, which show that the performance of PowerSGD communication hook is on par with the implementation in the original paper.