最近在使用PyTorch进行分布式训练时,被一个诡异的性能问题困扰了整整一周。今天必须来踩个坑。
问题背景:在5节点集群上训练一个大型Transformer模型,理论峰值吞吐量应该达到2000样本/秒,但实际只能跑到800样本/秒,怀疑是分布式通信瓶颈。
排查过程:
- 首先用
torch.profiler.profile进行基础profiling,发现GPU利用率高达95%,但CPU利用率只有30%。 - 通过
nvidia-smi观察到GPU显存使用正常,但网络带宽占用率异常。 - 使用
torch.distributed.launch的--log_level DEBUG参数启动后,发现nccl通信时间占比高达70%。
关键发现:在torchrun --nproc_per_node=4 --nnodes=5启动时,添加了环境变量NCCL_IB_DISABLE=0和NCCL_SOCKET_IFNAME=eth0,但性能提升不明显。
最终解决方案(超实用):
- 使用
torchrun --nproc_per_node=8 --nnodes=5 --master_port=12345 - 在启动脚本中添加:
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME=eth0
export NCCL_NET=IB
export NCCL_IB_HCA=mlx5_0
- 通过
watch -n 1 nvidia-smi pmon -c 1监控到GPU利用率稳定在95%,网络带宽使用率恢复正常。
关键点:必须确保所有节点的网卡驱动和内核参数一致,否则会因为通信协议不匹配导致性能骤降。
建议大家在分布式训练前,先用上述工具进行基础profiling再动手调参。

讨论