模型训练过程断点续跑机制
在机器学习项目中,模型训练往往需要数小时甚至数天时间,网络中断、资源不足或人为干预都可能导致训练中断。为保障训练任务的连续性,需要实现断点续跑机制。
核心监控指标
- 训练进度:记录已训练epoch数和batch数,通过
torch.save()保存模型状态 - 内存使用率:监控训练过程中GPU/CPU内存变化,避免OOM
- 学习率衰减:跟踪学习率变化曲线,确保优化器正常工作
- loss值波动:实时监控训练loss变化趋势
告警配置方案
# prometheus告警规则示例
ALERT TrainingStuck
IF increase(model_loss[5m]) == 0
FOR 10m
ANNOTATIONS {
summary = "模型训练停滞超过10分钟"
}
ALERT MemoryExceeded
IF (node_memory_used_bytes / node_memory_total_bytes) > 0.9
FOR 5m
ANNOTATIONS {
summary = "内存使用率超过90%"
}
实现步骤
- 保存检查点:每epoch结束后保存模型权重和优化器状态
- 恢复训练:启动时检测是否存在checkpoint文件,自动加载
- 资源监控:通过
nvidia-smi或nvidia-ml-py监控GPU使用率
关键代码示例
# 检查点保存
if epoch % 5 == 0:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}, f'checkpoint_epoch_{epoch}.pth')
# 恢复训练
if os.path.exists('latest_checkpoint.pth'):
checkpoint = torch.load('latest_checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
该机制确保了训练任务的可靠性,避免因意外中断导致的大量重复计算。

讨论