梯度检查点在LLM微调中的踩坑记录
最近在做LLM微调项目时,尝试应用了梯度检查点技术来优化显存使用,结果却踩了不少坑。
背景与目标
使用LoRA微调方案,希望在有限GPU内存下完成7B模型的训练。标准配置下训练会直接OOM。
实践过程
按照官方文档配置了梯度检查点:
from transformers import Trainer, TrainingArguments
testing_args = TrainingArguments(
gradient_checkpointing=True,
gradient_checkpointing_kwargs={'use_reentrant': False}, # 关键参数
per_device_train_batch_size=1,
gradient_accumulation_steps=4,
fp16=True,
logging_steps=10,
)
真坑来了
- 参数设置错误:最初忘记设置
use_reentrant=False,导致训练报错:RuntimeError: 'use_reentrant' is not supported with torch.compile - 性能下降:虽然显存从16GB降到8GB,但训练速度慢了约30%
可复现步骤
- 确保使用transformers>=4.35.0版本
- 配置TrainingArguments时务必添加
gradient_checkpointing_kwargs={'use_reentrant': False} - 同时开启fp16和梯度累积以平衡性能
最终效果
成功在8GB显存下完成训练,虽然速度稍慢但总算能跑起来了。
注意:梯度检查点虽能节省内存,但会牺牲部分训练效率,需根据实际情况权衡。

讨论