大模型训练时出现NaN值问题排查和修复过程
最近在使用PyTorch训练一个7B参数的大语言模型时,训练过程中突然出现了NaN值,导致训练中断。这个问题非常棘手,因为NaN值通常会传播到后续计算中,使得整个训练过程无法继续。
问题复现步骤
- 使用HuggingFace Transformers库加载预训练模型
- 配置优化器(AdamW)和学习率调度器
- 开始训练,前几个epoch正常运行
- 第4个epoch开始出现loss为NaN
排查过程
通过调试发现,问题出现在梯度裁剪环节。当使用torch.nn.utils.clip_grad_norm_()进行梯度裁剪时,如果梯度值过大,会导致数值溢出。
修复方案:
# 修改前
optimizer.step()
# 修改后
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# 检查是否出现inf或nan
if torch.isfinite(torch.tensor([p.grad for p in model.parameters() if p.grad is not None]).sum()):
optimizer.step()
预防措施
- 在训练前对输入数据进行清洗
- 合理设置梯度裁剪阈值
- 使用梯度检查点技术
这个坑踩得有点惨,希望和大家共勉!

讨论