大模型训练中的训练中断恢复策略
最近在做LLM训练时遇到了一个血泪史。训练进行到第1200步时,由于节点掉电导致整个训练任务中断,当时心里一万只羊驼奔腾而过。
问题重现
# 原始训练代码片段
for step in range(10000):
# 训练逻辑...
if step % 1000 == 0:
# checkpoint保存
torch.save(model.state_dict(), f'checkpoint_{step}.pt')
结果:保存的checkpoint文件在恢复时发现tensor shape不匹配,训练中断后又重启了两次,每次都是从头开始。
解决方案
- 使用PyTorch内置的分布式检查点恢复机制:
from torch.distributed.checkpoint import save_state_dict, load_state_dict
# 保存时
save_state_dict(model.state_dict(), f'checkpoint_{step}.pt')
# 恢复时
load_state_dict(model.state_dict(), f'checkpoint_{step}.pt')
-
设置检查点间隔:建议每500-1000步保存一次,避免训练中断后损失过多数据。
-
使用wandb或tensorboard的自动恢复功能:通过日志记录当前训练状态,实现断点续训。
实践建议
- 使用
--resume参数配合训练脚本,可快速定位到最近检查点 - 避免在训练过程中频繁修改checkpoint路径和文件名格式
- 建议在训练开始前进行一次完整的检查点保存和恢复测试
目前这套方案已经跑了5轮完整训练,成功率接近100%。

讨论