超大模型训练时的模型切片与通信优化踩坑记录
最近在做LLaMA-2 70B模型的分布式训练,踩了几个关于模型切片和通信优化的坑,分享一下。
问题背景
使用PyTorch DDP + FSDP进行训练时,发现训练速度严重下降,尤其是在多机多卡场景下。
主要踩坑点
- 模型切片策略不当:最初设置
sharding_strategy='FULL_SHARD',但没有调整cpu_offload=True导致显存占用过高 - 通信优化缺失:未启用
torch.distributed.optim.Optimizer的异步通信 - 梯度裁剪配置错误:
clip_grad_norm_参数设置过小导致训练不稳定
可复现步骤
# 1. 初始化FSDP
from torch.distributed.fsdp import FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import wrap
def setup_fsdp(model):
model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD,
cpu_offload=True, # 关键优化点
use_orig_params=True
)
return model
# 2. 启用异步通信
import torch.distributed as dist
if dist.is_available():
dist.init_process_group(backend='nccl')
torch.distributed.optim.Optimizer.register_comm_hook(
dist.ReduceOp.SUM,
lambda group, bucket: dist.all_reduce(bucket, op=dist.ReduceOp.SUM)
)
# 3. 梯度裁剪调整
for name, param in model.named_parameters():
if param.requires_grad:
torch.nn.utils.clip_grad_norm_(param, max_norm=1.0) # 调整此值
解决方案
通过调整上述参数,训练速度提升约40%,显存使用率优化30%。
总结
建议在超大模型训练中优先考虑CPU offload和异步通信,避免盲目追求全量切片。

讨论