在大模型训练过程中,显存使用率异常升高是一个常见但棘手的问题。本文将结合实际案例,分享一套系统性的排查方法。
问题现象
训练过程中,显存使用率突然飙升至90%以上,甚至出现OOM(Out of Memory)错误。这种现象通常在模型训练的中后期出现,且难以复现。
排查步骤
1. 检查梯度累积
# 使用torch.cuda.memory_summary()监控显存
import torch
print(torch.cuda.memory_summary())
# 检查每个参数的梯度是否正常
for name, param in model.named_parameters():
if param.grad is not None:
print(f"{name}: {param.grad.norm().item()}")
2. 分析模型结构
使用PyTorch Profiler分析各层显存占用:
from torch.profiler import profile, record_function
with profile(activities=[torch.profiler.ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(wait=1, warmup=1, active=3)) as prof:
output = model(input)
loss = criterion(output, target)
loss.backward()
3. 检查优化器状态
# 查看优化器缓存的显存
optimizer_state = optimizer.state_dict()
for param_group in optimizer_state['param_groups']:
print(f"Param group: {param_group['lr']}")
根本原因
常见原因包括:梯度爆炸、优化器状态累积、模型并行策略不当等。建议通过逐步注释代码的方式定位问题源头。

讨论