
Efficient PyTorch Programming Guide

Author: Adrian, September 16, 2025

Overview

Practical tips to improve memory usage, computation/communication overlap, and kernel efficiency when developing with PyTorch.

1. Prefer all_gather_into_tensor over all_gather when possible

Gathering directly into a single pre-allocated tensor with all_gather_into_tensor (exposed as the private _all_gather_base in older PyTorch releases) reduces memory fragmentation and runtime overhead compared to gathering into a Python list of per-rank tensors and concatenating them afterwards.

output = torch.empty(input.numel() * world_size,
                     dtype=input.dtype, device=input.device)
torch.distributed.all_gather_into_tensor(output, input, group=xxx)

output_list = [
    torch.empty(input.numel(), dtype=input.dtype, device=input.device)
    for _ in range(world_size)
]
torch.distributed.all_gather(output_list, input, group=xxx)
output = torch.cat(output_list, dim=0)

2. Use specialized operators rather than generic ones

For example, prefer F.embedding to index_select. Index-select implementations may involve host-side index expansion and memory handling that reduce efficiency.

3. Share contiguous buffers for long-lived tensors

Allocating a large contiguous buffer and slicing into it for long-lived tensors reduces fragmentation and operator dispatch overhead.

data = torch.zeros(global_size, dtype=xx, device=xx)
start_idx = 0
for i in range(len(item_list)):
    numel = item_list[i].numel()
    data[start_idx:start_idx + numel].copy_(item_list[i].reshape(-1))
    item_list[i] = data[start_idx:start_idx + numel].view(item_list[i].shape)
    start_idx += numel
torch.cuda.empty_cache()  # release the blocks freed by the original item_list tensors

CUDA memory pools allocate aligned blocks; using scattered blocks increases fragmentation. Operating on large contiguous buffers also reduces the number of kernel launches and improves throughput.

4. Use asynchronous communication to increase compute/communication overlap

comm_handle = torch.distributed.all_reduce(data, group=xxx, async_op=True)
# ... perform other computations ...
comm_handle.wait()

With proper topology planning, intermediate computations can overlap with communication.
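The pattern above can be exercised end to end in a hedged, single-process sketch (gloo backend, world size 1; the rendezvous setup is illustrative only):

```python
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29502")
dist.init_process_group("gloo", rank=0, world_size=1)

data = torch.ones(4)
handle = dist.all_reduce(data, async_op=True)  # returns immediately with a work handle

# Independent computation proceeds while the reduction is in flight.
other = torch.randn(4).relu()

handle.wait()                            # block only when the result is needed
assert torch.equal(data, torch.ones(4))  # world size 1: the sum equals the input

dist.destroy_process_group()
```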

5. Use expandable segments for workloads with frequently changing allocation sizes

PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

This mechanism operates on the virtual address space, allowing the driver to map additional physical memory after an allocated block, which helps segments grow and can reduce fragmentation and improve cache matching.

6. Free caches at appropriate times to reduce memory footprint

torch.cuda.empty_cache()

Clear temporary device tensors created during initialization before training starts to avoid pooling fragmentation. Do not call torch.cuda.empty_cache() frequently during training unless switching tasks (e.g., train/eval), because releasing cache blocks triggers stream synchronization and can be expensive.
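A small sketch of the recommended placement (falling back to CPU when no GPU is present; the tensor shapes are arbitrary): release the cache once after initialization, not inside the step loop.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialization phase: temporaries are allocated and then dropped.
tmp = torch.randn(512, 512, device=device)
init_stat = tmp.mean()
del tmp

# One-time cache release before the training loop starts.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

for step in range(3):  # training loop: no empty_cache() calls in here
    x = torch.randn(64, 64, device=device)
    loss = (x * x).sum()
```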

7. Non-blocking host-to-device copies are safe and useful

data = data.cuda(non_blocking=True)

Non-blocking copies let H2D transfers overlap with computation when there is no immediate dependency; note that true asynchrony requires the source tensor to be in pinned host memory. Operations issued later on the same stream are ordered after the copy, so correctness is preserved without manual synchronization.
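A hedged sketch of the full pattern (pinning plus non-blocking transfer; it degrades to a plain CPU copy when no GPU is present):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

batch = torch.randn(1024, 256)
if torch.cuda.is_available():
    # Pinned host memory is required for the copy to be truly asynchronous.
    batch = batch.pin_memory()

data = batch.to(device, non_blocking=True)

# Independent work can overlap with the in-flight copy here.
unrelated = torch.eye(8)

# The first dependent op is ordered after the copy on the stream; no manual sync needed.
total = data.sum()
```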

8. Utilize idle CPU cycles

Move suitable data preprocessing steps onto otherwise idle CPU cores to balance the workload. Existing code paths often leave this potential untapped, for example in Megatron-LM's master branch.

9. Accelerate communication operator memory release

Optimize when communication operators release their temporary buffers, so that this memory does not linger past its last use and inflate the peak footprint.

10. Avoid hitting memory capacity limits during training/inference

If reserved memory fluctuates persistently during steady-state training, the job is likely running at the memory ceiling: the allocator repeatedly reclaims its pool, and each block release triggers a stream synchronization, causing significant performance degradation even when average utilization appears acceptable.
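One way to detect this condition programmatically is the caching allocator's retry counter, a sketch assuming the `num_alloc_retries` key reported by torch.cuda.memory_stats() (the helper name is illustrative):

```python
import torch

def memory_pressure_retries() -> int:
    """Return the caching allocator's alloc-retry count (0 on CPU-only builds).

    A steadily growing count means cudaMalloc failed, the cached pool was
    flushed, and the allocation was retried: a strong sign the job is
    running at the memory ceiling.
    """
    if not torch.cuda.is_available():
        return 0
    return torch.cuda.memory_stats().get("num_alloc_retries", 0)
```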

11. Use NvFuser for chained elementwise operations

Annotate functions with TorchScript to enable NvFuser to generate fused kernels for forward and backward passes. All operations inside a torch.jit.script-decorated function must be compatible with TorchScript.

from typing import Tuple

@torch.jit.script
def bias_dropout_add(x_with_bias: Tuple[torch.Tensor, torch.Tensor],
                     residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    x, bias = x_with_bias  # unpack
    x = x + bias
    out = torch.nn.functional.dropout(x, p=prob, training=training)
    out = residual + out
    return out

torch._C._jit_set_nvfuser_enabled(True)

12. Do not issue stream-synchronizing operators during model execution

Avoid operators that force stream synchronization, as they block kernel dispatch and reduce concurrency.
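A common instance is calling .item() inside a loop, which forces a host-device synchronization on every iteration when the tensors live on the GPU. A minimal sketch of the fix, accumulating on-device and synchronizing once:

```python
import torch

losses = [torch.full((), float(i)) for i in range(5)]

# Sync-heavy: .item() synchronizes the stream on every iteration (on GPU).
total_sync = 0.0
for l in losses:
    total_sync += l.item()

# Better: accumulate as a tensor and read back a single scalar at the end.
total = torch.zeros(())
for l in losses:
    total += l

assert total.item() == total_sync  # 0 + 1 + 2 + 3 + 4 == 10.0
```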

13. Prefer Tensor Cores over CUDA cores

Rewrite computations to leverage matrix operations and Tensor Cores where possible. For example, replacing high-dimensional cumsum sequences with matrix multiplications can yield orders-of-magnitude speedups even if it increases raw FLOPs.

# direct cumsum
out = a.cumsum(dim=-1)

# replace with matrix computation: a is (x, nb * s), triu_matrix is an (s, s)
# upper-triangular matrix of ones
a = torch.matmul(a.view(x, nb, s), triu_matrix)  # per-block prefix sums
c = a[:, :-1, -1].cumsum(-1)                     # running totals of the preceding blocks
a[:, 1:, :] += c.unsqueeze(-1)                   # add each block's offset
a = a.view(x, nb * s)
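The equivalence behind the rewrite can be checked with a small sketch (shapes are illustrative): a cumsum along the last dimension is exactly a multiplication by an upper-triangular matrix of ones.

```python
import torch

s = 8
a = torch.randn(4, s)

# (a @ triu)[i, j] = sum over k <= j of a[i, k], i.e. a cumsum along the last dim.
triu = torch.triu(torch.ones(s, s))
via_matmul = a @ triu

assert torch.allclose(via_matmul, a.cumsum(dim=-1), atol=1e-5)
```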

14. Tune bucket size for collective communication

Optimal bucket size for bucketed communication depends on cluster scale and should be tuned. Smaller is not always better; inappropriate bucket sizes can severely degrade training performance.
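For DistributedDataParallel, the gradient bucket size is exposed as the bucket_cap_mb constructor argument (default 25 MB). A hedged single-process sketch of setting it (gloo backend, world size 1; the value 50 is arbitrary and should come from a tuning sweep):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Illustrative single-process group; real jobs get rank/world size from the launcher.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29503")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(32, 32)
# bucket_cap_mb controls how many gradients are fused per all-reduce; the best
# value depends on model and cluster scale, so sweep it rather than guess.
ddp_model = DDP(model, bucket_cap_mb=50)

out = ddp_model(torch.randn(4, 32))
dist.destroy_process_group()
```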