在大模型训练过程中,梯度消失(Vanishing Gradient)是一个常见但棘手的问题。本文将结合实际案例,系统性地排查和解决该问题。
问题现象
在使用Transformer架构训练70B参数模型时,损失值在训练初期快速下降后趋于平稳,甚至出现震荡。通过可视化梯度发现,大部分层的梯度接近于0,尤其在前几层表现明显。
排查步骤
-
检查初始化方式:
- 默认使用Xavier初始化,尝试更换为He初始化
import torch.nn as nn layer = nn.Linear(512, 512) nn.init.kaiming_uniform_(layer.weight, mode='fan_in', nonlinearity='relu') -
检查激活函数:
- 将ReLU替换为GELU,观察梯度变化
-
调整学习率:
- 使用学习率预热策略,避免初始阶段学习率过高
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=1000) -
梯度裁剪与归一化:
- 添加梯度裁剪防止梯度爆炸,同时保持梯度稳定性
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
解决方案
最终通过组合使用He初始化、GELU激活函数和学习率预热策略,成功缓解了梯度消失问题。在训练1000步后,各层梯度均保持稳定。
小结
梯度消失往往由多个因素叠加导致,需要从初始化、激活函数、优化器参数等维度综合排查。

讨论