分布式训练中Checkpoint保存策略优化踩坑记录
最近在做分布式大模型训练时,遇到了一个让人头疼的问题:Checkpoint保存时间过长,严重影响训练效率。经过一周的排查和优化,终于找到了有效的解决方案。
问题现象
在使用PyTorch Lightning进行分布式训练时,发现每次save_checkpoint耗时高达15-20秒,原本30分钟一次的checkpoint频率变成了300分钟一次,训练进度严重滞后。
踩坑过程
起初怀疑是存储性能问题,但排查后发现是代码层面的问题。通过profile工具定位,问题出在以下几点:
- 默认的save_top_k=-1 - 保存了所有checkpoint,导致累积文件越来越多
- 未设置async_save=True - 同步保存阻塞训练进程
- 存储路径为NFS挂载 - 网络IO瓶颈明显
解决方案
# 优化后的配置
trainer = Trainer(
callbacks=[
ModelCheckpoint(
dirpath='local_path', # 改用本地SSD
filename='model-{epoch}-{step}',
save_top_k=3, # 限制保存数量
save_last=True,
every_n_train_steps=1000, # 控制频率
enable_version_counter=False,
sync_dist=False, # 关闭同步
)
],
# 关键优化项
accelerator='gpu',
strategy='ddp',
num_sanity_val_steps=0,
)
效果对比
- 优化前:单次checkpoint耗时20s
- 优化后:单次checkpoint耗时3s
实践建议
- 优先使用本地存储而非网络存储
- 合理设置save_top_k参数
- 考虑异步保存机制
- 定期清理过期checkpoint
对于大规模分布式训练,这些优化能显著提升训练效率。
可复现步骤:
- 使用分布式训练环境
- 配置默认checkpoint策略
- 运行训练并记录耗时
- 应用上述优化后再次测试

讨论