分布式训练中optimizer状态同步失败问题的解决方法
最近在进行分布式大模型训练时,遇到了一个令人头疼的问题:optimizer状态同步失败。这个问题在单机训练时完全不存在,但在多机分布式环境下就频繁报错。
问题现象
使用PyTorch DDP训练时,出现如下错误信息:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1
或者
AssertionError: optimizer state not synchronized across processes
复现步骤
- 启动多机训练脚本:
torchrun --nproc_per_node=2 train.py - 使用Adam优化器,batch size设置为64
- 模型参数初始化后立即同步optimizer状态
解决方案
经过多次踩坑,最终定位到问题根源在于optimizer state的初始化时机。正确的做法是:
# 错误方式 - 直接初始化
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# 正确方式 - 等待DDP初始化完成后再初始化
model = MyModel().cuda()
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# 确保所有进程都完成同步后,再进行状态同步
关键要点
- DDP初始化需要先完成模型创建
- optimizer必须在DDP包装之后再初始化
- 适当增大
--timeout参数避免超时 - 调整batch size到合适值,避免显存溢出
这个问题在大规模训练中非常常见,希望给同样踩坑的朋友们一些参考。

讨论