Overview
Practical tips to improve memory usage, computation/communication overlap, and kernel efficiency when developing with PyTorch.
1. Prefer all_gather_into_tensor over all_gather when possible
Gathering directly into a single preallocated tensor with torch.distributed.all_gather_into_tensor (formerly the private _all_gather_base) reduces memory fragmentation and runtime overhead compared with gathering into a Python list and concatenating.
# recommended: gather into one contiguous output tensor
output = torch.empty(input.numel() * world_size, dtype=input.dtype, device=input.device)
torch.distributed.all_gather_into_tensor(output, input, group=xxx)
# avoid: gathering into a list of per-rank tensors and concatenating
output_list = [
    torch.empty(input.numel(), dtype=input.dtype, device=input.device)
    for _ in range(world_size)
]
torch.distributed.all_gather(output_list, input, group=xxx)
output = torch.cat(output_list, dim=0)
2. Use specialized operators rather than generic ones
For example, prefer F.embedding to index_select. Index-select implementations may involve host-side index expansion and memory handling that reduce efficiency.
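As a quick illustration (table and index sizes below are arbitrary), both calls return the same rows of the table, but F.embedding dispatches to a dedicated lookup kernel instead of a generic index-select path:

```python
import torch
import torch.nn.functional as F

weight = torch.randn(1000, 64)        # e.g. an embedding table
idx = torch.tensor([3, 17, 3, 999])   # rows to look up

via_embedding = F.embedding(idx, weight)        # specialized lookup
via_index_select = weight.index_select(0, idx)  # generic gather

assert torch.equal(via_embedding, via_index_select)
```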
3. Share contiguous buffers for long-lived tensors
Allocating a large contiguous buffer and slicing into it for long-lived tensors reduces fragmentation and operator dispatch overhead.
data = torch.zeros(global_size, dtype=xx, device=xx)
start_idx = 0
for i in range(len(item_list)):
    numel = item_list[i].numel()
    buf = data[start_idx:start_idx + numel]
    buf.copy_(item_list[i].view(-1))             # preserve existing values
    item_list[i] = buf.view(item_list[i].shape)  # rebind to a view of the shared buffer
    start_idx += numel                           # advance past this item
torch.cuda.empty_cache()  # free the original, now-released item tensors
CUDA memory pools allocate aligned blocks; using scattered blocks increases fragmentation. Operating on large contiguous buffers also reduces the number of kernel launches and improves throughput.
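A minimal CPU-runnable sketch of the packing pattern above (tensor shapes are illustrative): each long-lived tensor is copied into a slice of one contiguous buffer and rebound to a view of it, so all items alias the same allocation.

```python
import torch

# Illustrative long-lived tensors to pack together.
item_list = [torch.randn(4, 8), torch.randn(16), torch.randn(2, 3)]
originals = [t.clone() for t in item_list]

global_size = sum(t.numel() for t in item_list)
data = torch.zeros(global_size, dtype=item_list[0].dtype)

start_idx = 0
for i, t in enumerate(item_list):
    n = t.numel()
    chunk = data[start_idx:start_idx + n]
    chunk.copy_(t.view(-1))             # preserve existing values
    item_list[i] = chunk.view(t.shape)  # rebind to a view of the buffer
    start_idx += n

# Every item now aliases the shared buffer and keeps its original data.
assert all(t.data_ptr() >= data.data_ptr() for t in item_list)
assert all(torch.equal(t, o) for t, o in zip(item_list, originals))
```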
4. Use asynchronous communication to increase compute/communication overlap
comm_handle = torch.distributed.all_reduce(data, group=xxx, async_op=True)
# ... perform other computations ...
comm_handle.wait()
With proper topology planning, intermediate computations can overlap with communication.
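A self-contained single-process sketch of the async pattern, using the gloo backend so it runs on CPU (in real training the process group is set up by the launcher; the address/port values here are placeholders):

```python
import os
import torch
import torch.distributed as dist

# Illustrative single-rank setup; normally provided by the launch environment.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

data = torch.ones(4)
handle = dist.all_reduce(data, async_op=True)  # returns immediately

# ... other computation can overlap with the communication here ...
_ = torch.randn(4).sum()

handle.wait()                            # block until the reduce completes
assert torch.equal(data, torch.ones(4))  # world_size == 1, so data is unchanged

dist.destroy_process_group()
```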
5. Use expandable segments for workloads with frequently changing allocation sizes
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
This mechanism reserves virtual address space and lets the driver map additional physical memory onto the end of an already-allocated segment, so segments can grow in place instead of forcing fresh allocations; this reduces fragmentation and improves reuse of cached blocks when allocation sizes change frequently.
6. Free caches at appropriate times to reduce memory footprint
torch.cuda.empty_cache()
Clear temporary device tensors created during initialization before training starts to avoid pooling fragmentation. Do not call torch.cuda.empty_cache() frequently during training unless switching tasks (e.g., train/eval), because releasing cache blocks triggers stream synchronization and can be expensive.
7. Non-blocking host-to-device copies are safe and useful
data = data.cuda(non_blocking=True)
PyTorch orders subsequent GPU operations on the same stream after the copy, so correctness is preserved automatically, while non-blocking copies let H2D transfers overlap with computation when there is no immediate dependency. Note that the copy is only truly asynchronous when the source host tensor lives in pinned (page-locked) memory, e.g., allocated via pin_memory().
8. Utilize idle CPU cycles
Move suitable data-preprocessing steps to the CPU to balance the workload between host and device. Such opportunities often remain even in mature code paths, for example in Megatron-LM's master branch.
9. Accelerate communication operator memory release
Release the temporary buffers that communication operators allocate as soon as the operation completes, rather than holding them for the rest of the step; prompt release reduces peak memory and allocator pressure.
10. Avoid hitting memory capacity limits during training/inference
Persistent fluctuation in observed memory usage usually indicates the process is bumping against the memory ceiling: the caching allocator repeatedly reclaims its pool, and each block release triggers a stream synchronization, causing significant performance degradation even when average utilization looks acceptable.
11. Use NvFuser for chained elementwise operations
Annotate functions with TorchScript to enable NvFuser to generate fused kernels for forward and backward passes. All operations inside a torch.jit.script-decorated function must be compatible with TorchScript.
import torch
from typing import Tuple

torch._C._jit_set_nvfuser_enabled(True)  # enable the NvFuser fusion backend

@torch.jit.script
def bias_dropout_add(x_with_bias: Tuple[torch.Tensor, torch.Tensor],
                     residual: torch.Tensor,
                     prob: float,
                     training: bool) -> torch.Tensor:
    x, bias = x_with_bias  # unpack
    x = x + bias
    out = torch.nn.functional.dropout(x, p=prob, training=training)
    out = residual + out
    return out
Non-Tensor arguments such as prob and training must carry type annotations, since TorchScript assumes Tensor for unannotated parameters.
12. Do not issue stream-synchronizing operators during model execution
Avoid operators that force a stream synchronization during model execution, such as tensor.item(), device-to-host copies like .cpu()/.numpy(), and torch.nonzero; they stall kernel dispatch on the host and reduce concurrency.
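One common fix, sketched below with a stand-in loss tensor: accumulate per-step metrics on-device and read them back once, instead of calling .item() (a synchronization) on every step.

```python
import torch

losses = []
for step in range(5):
    loss = (torch.randn(8) ** 2).mean()  # stand-in for a training loss
    losses.append(loss.detach())         # stays a tensor; no sync here

total = torch.stack(losses).sum().item()  # a single sync at the end
```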
13. Prefer Tensor Cores over CUDA cores
Rewrite computations to leverage matrix operations and Tensor Cores where possible. For example, replacing high-dimensional cumsum sequences with matrix multiplications can yield orders-of-magnitude speedups even if it increases raw FLOPs.
# direct cumsum over a long last dimension of length n_blocks * s
out = a.cumsum(dim=-1)
# equivalent blocked computation that runs as matrix multiplications;
# triu_matrix is an (s, s) upper-triangular matrix of ones
a = torch.matmul(a.view(x, n_blocks, s), triu_matrix)  # cumsum within each block
c = a[:, :-1, -1].cumsum(-1)       # running totals of the preceding blocks
a[:, 1:, :] += c.unsqueeze(-1)     # add each block's offset
a = a.view(x, n_blocks * s)
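The equivalence can be checked on concrete sizes (all names and sizes below are illustrative): per-block cumsum via a matmul with an upper-triangular ones matrix, then cumulative block totals added as offsets.

```python
import torch

x, n_blocks, s = 2, 4, 8
a = torch.randn(x, n_blocks * s)
expected = a.cumsum(dim=-1)

# (s, s) upper-triangular ones: column i sums elements 0..i of each block.
triu_matrix = torch.triu(torch.ones(s, s))

blocks = torch.matmul(a.view(x, n_blocks, s), triu_matrix)  # per-block cumsum
offsets = blocks[:, :-1, -1].cumsum(-1)    # totals of the preceding blocks
blocks[:, 1:, :] += offsets.unsqueeze(-1)  # add block offsets
result = blocks.reshape(x, n_blocks * s)

assert torch.allclose(result, expected, atol=1e-5)
```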
14. Tune bucket size for collective communication
Optimal bucket size for bucketed communication depends on cluster scale and should be tuned. Smaller is not always better; inappropriate bucket sizes can severely degrade training performance.
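For DDP, the bucket size is exposed as the bucket_cap_mb argument (default 25 MB). A single-rank gloo sketch of setting it (rank/port values are placeholders; in practice they come from the launcher, and the value itself must be tuned per cluster):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Illustrative single-rank setup; normally provided by the launch environment.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

# bucket_cap_mb controls how gradients are grouped per all_reduce call.
model = DDP(torch.nn.Linear(8, 8), bucket_cap_mb=50)
loss = model(torch.randn(2, 8)).sum()
loss.backward()
assert model.module.weight.grad is not None

dist.destroy_process_group()
```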