LLM训练时模型梯度计算异常排查
在大模型训练过程中,梯度计算异常是常见的问题之一,可能导致训练失败或模型性能下降。本文将介绍如何系统性地排查此类问题。
常见异常表现
- 梯度值变为NaN或inf
- 梯度消失(接近0)
- 梯度爆炸(数值过大)
- 训练loss不收敛或震荡
排查步骤
- 检查数据预处理:确保输入数据没有异常值
import numpy as np
# 检查输入数据的统计信息
print(f"Data mean: {np.mean(data)}")
print(f"Data std: {np.std(data)}")
print(f"NaN count: {np.sum(np.isnan(data))}")
- 梯度监控:在训练循环中添加梯度检查
for batch in dataloader:
optimizer.zero_grad()
outputs = model(batch)
loss = criterion(outputs, targets)
loss.backward()
# 梯度异常检测
for name, param in model.named_parameters():
if param.grad is not None:
grad_norm = param.grad.norm().item()
if np.isnan(grad_norm) or np.isinf(grad_norm):
print(f"Gradient NaN/Inf detected in {name}")
# 可以在此处添加断点或日志记录
- 学习率调整:异常梯度可能需要减小学习率
# 检查学习率设置
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
预防措施
- 使用梯度裁剪(Gradient Clipping)
- 监控训练过程中的loss和梯度变化
- 定期保存检查点
此排查方法适用于各类大模型训练场景,建议安全工程师在日常测试中集成相关监控机制。

讨论