在分布式大模型训练中,训练中断恢复机制是保障计算资源不被浪费的关键环节。本文分享一个基于PyTorch Distributed的训练恢复方案。
核心思路: 使用torch.save()保存检查点,包含模型权重、优化器状态、学习率调度器以及全局训练步数。在训练中断后,通过加载该检查点继续训练。
关键代码实现:
# 保存检查点
if rank == 0:
checkpoint = {
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'epoch': epoch,
'global_step': global_step
}
torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pt')
# 恢复训练
def load_checkpoint(model, optimizer, scheduler, checkpoint_path):
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
return checkpoint['epoch'], checkpoint['global_step']
# 训练循环中添加恢复逻辑
if resume_from_checkpoint:
start_epoch, start_step = load_checkpoint(model, optimizer, scheduler, checkpoint_path)
# 跳过已训练的步数
for step in range(start_step + 1, total_steps):
# 训练代码...
实战建议:
- 每隔固定epoch保存一次检查点,避免频繁IO操作
- 使用分布式文件系统如HDFS或S3存储检查点
- 在训练开始前验证检查点完整性
该方案已在多个百亿参数模型训练中稳定运行。

讨论