在LLM训练过程中,梯度计算异常是常见的问题之一,可能导致模型收敛困难或训练不稳定。本文将介绍如何排查和解决梯度计算异常问题。
常见梯度异常类型
- 梯度爆炸:梯度值异常增大,导致参数更新过大
- 梯度消失:梯度值接近零,模型无法学习
- NaN/Inf梯度:计算过程中出现无效数值
排查步骤
1. 梯度监控
# 使用PyTorch监控梯度
for name, param in model.named_parameters():
if param.grad is not None:
grad_norm = param.grad.norm().item()
print(f'{name}: {grad_norm}')
if grad_norm > 100: # 设置阈值
print(f'警告:{name}梯度异常')
2. 梯度裁剪
# 梯度裁剪防止爆炸
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
3. 数值检查
# 检查参数和梯度是否为NaN/Inf
for name, param in model.named_parameters():
if torch.isnan(param).any() or torch.isinf(param).any():
print(f'{name} contains NaN/Inf')
解决方案
- 调整学习率
- 使用梯度裁剪
- 检查数据预处理
- 使用混合精度训练
此方法论适用于各类大模型训练场景,可有效识别和解决训练过程中的梯度异常问题。

讨论