模型训练过程中的梯度爆炸检测与告警机制
梯度爆炸监控指标
在模型训练过程中,需要实时监控以下关键指标:
- 梯度范数(Gradient Norm):计算所有参数梯度的L2范数,当值超过阈值100时触发预警
- 损失函数变化率:每批次损失值变化率超过50%时标记异常
- 参数更新幅度:参数更新量超过初始值10倍时异常
具体实现步骤
- 梯度监控代码:
import torch
import numpy as np
# 在训练循环中添加
for batch in dataloader:
optimizer.zero_grad()
outputs = model(batch)
loss = criterion(outputs, targets)
loss.backward()
# 梯度检查
total_norm = torch.norm(torch.stack([torch.norm(p.grad.data) for p in model.parameters() if p.grad is not None]))
if total_norm > 100:
print(f"梯度爆炸:{total_norm}")
send_alert("梯度爆炸", f"梯度范数:{total_norm}")
- 告警配置:
- 阈值设置:梯度范数>100时发送邮件告警
- 告警级别:严重级别,自动通知运维团队
- 重试机制:连续3次触发后启动降级策略
告警触发条件
当梯度范数连续5个批次超过阈值,或损失函数急剧上升时,立即触发告警并暂停训练。

讨论