diff --git a/docs/composability.md b/docs/composability.md index 87c0b022..bc90ee91 100644 --- a/docs/composability.md +++ b/docs/composability.md @@ -24,4 +24,12 @@ One issue with seed checkpoints is that we rely on initializing _every_ model st We intentionally upcast the final output tensor to fp32 inside the loss function rather in the `Transformer.forward()` so that forward and backward casts can be fused with the loss forward and backward respectively when we `torch.compile()` the loss function. This can improve both throughput and memory usage. ## Setting `TORCH_NCCL_AVOID_RECORD_STREAMS=1` for TP -Users should set the environemnt variable `TORCH_NCCL_AVOID_RECORD_STREAMS=1` when using tensor parallelism (TP) to avoid unexpectedly high memory usage. `Tensor.record_stream` is a legacy approach for ensuring that a tensor allocated in one stream (e.g. default stream) and used in another stream (e.g. process group stream) is not freed before its usage completes. In particular, `record_stream` gets called for the `async_op=True` collectives used in TP. `Tensor.record_stream(stream)` records a CUDA event in the consumer stream `stream`, and the CUDA caching allocator queries this event for completion upon each future allocation to check if the tensor can be freed. This means that the tensor is not considered free in the caching allocator until the last GPU kernel before the recorded event finishes, which could be arbitrarily far in the future. For example, if the CPU is running far ahead of the GPU and issues `K` allocations before that last GPU kernel finishes, then none of those `K` allocations can reuse the memory from the tensor on which `record_stream` was called. Setting the environment variable `TORCH_NCCL_AVOID_RECORD_STREAMS=1` uses a simple alternative approach. The process group stashes references to the collective tensors until the user calls `wait()` on the collective. This should be intuitive: the collective input/output tensors cannot be freed until after the user calls `wait()`. +Users should set the environment variable `TORCH_NCCL_AVOID_RECORD_STREAMS=1` when using tensor parallelism (TP) to avoid unexpectedly high memory usage. + +TP uses async collectives (i.e. with `async_op=True`), such as all-gather, reduce-scatter, and all-reduce, to overlap communication with compute. Under the hood, an async collective runs the NCCL communication kernel in a separate CUDA stream, owned by the process group. Calling `wait()` on the returned work object has the current stream wait for the process group's stream, allowing the current stream to correctly use the result of the collective. + +This represents a producer-consumer pattern across streams: the collective tensors are produced in a compute stream (usually the default stream), and they are consumed in a communication stream (from the process group). Under such producer-consumer patterns across streams, we must ensure that the tensors are not freed before their usage in the consumer stream. + +[`Tensor.record_stream`](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) is a legacy approach for ensuring this. The process group will call `record_stream(comm_stream)` on the collective input and output tensors after issuing the collective kernel in the process group's `comm_stream`. This records a CUDA event in `comm_stream`, and the CUDA caching allocator that manages CUDA tensor memory in PyTorch will query this recorded event upon future allocations. Only once the event has completed, meaning that the collective has finished running, can the tensor memory be freed and considered for future reuse. This couples the caching allocator's memory reuse with _GPU kernel timing_, which does not otherwise happen without `record_stream`. While the collective kernel runs on GPU, any allocations made from the CPU for future ops cannot reuse that memory, even if we know that those future ops must run after the current collective. This inability to reuse leads to unexpected memory stacking. + +By setting `TORCH_NCCL_AVOID_RECORD_STREAMS=1`, the process group avoids calling `record_stream` on the collective tensors and instead uses a different approach. It simply stashes references to the collective tensors until the user calls `wait()` on the work object. Holding references ensures that the collective tensors will not be freed by the caching allocator. This can only lead to a memory regression if the user never calls `wait()`, where with `record_stream`, the caching allocator would still eventually free the collective tensors once the collective finishes on the GPU. Since this is not common or an expected usage, we recommend setting this environment variable.