在PyTorch分布式训练中,故障定位是性能优化的关键环节。本文将分享几种实用的故障诊断技巧。
1. 网络通信异常排查 当出现NCCL相关错误时,首先检查网络配置:
import torch.distributed as dist
import os
def init_distributed():
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
print(f"Rank {rank} initialized")
使用torchrun --nproc_per_node=4 --nnodes=2启动时,确保所有节点间网络连通性。
2. 内存泄漏检测 通过nvidia-smi监控各GPU内存使用情况,配合以下代码定位问题:
import torch
import gc
def check_memory():
for i in range(torch.cuda.device_count()):
print(f"Device {i}: {torch.cuda.memory_allocated(i)/1024**2:.1f} MB")
gc.collect()
torch.cuda.empty_cache()
3. 同步问题诊断 使用dist.barrier()进行同步测试:
# 在关键位置添加屏障
print(f"Rank {rank} before barrier")
dist.barrier()
print(f"Rank {rank} after barrier")
若某些进程卡住,说明存在死锁或通信阻塞。
4. 性能瓶颈定位 使用torch.profiler分析训练性能:
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
record_shapes=True
) as prof:
# 训练代码
pass
通过profiler报告定位计算密集型操作和通信开销。

讨论