分布式训练性能测试:PyTorch DDP在多GPU环境下的表现
最近在优化一个图像分类模型时,尝试了PyTorch的DDP(DistributedDataParallel)进行多GPU训练,结果发现了一些值得记录的坑。
测试环境
- 4块RTX 3090显卡
- PyTorch 2.0.1
- 数据集:CIFAR-10(batch_size=128)
代码实现
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
class SimpleCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 32, 3)
self.fc = nn.Linear(32 * 6 * 6, 10)
def forward(self, x):
x = self.conv1(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
# 初始化DDP
def setup(rank, world_size):
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def train(rank, world_size):
setup(rank, world_size)
model = SimpleCNN().to(rank)
model = DDP(model, device_ids=[rank])
# 数据加载器
dataset = torch.utils.data.TensorDataset(
torch.randn(10000, 3, 32, 32),
torch.randint(0, 10, (10000,))
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
# 训练
for epoch in range(3):
for batch in dataloader:
optimizer.zero_grad()
output = model(batch[0].to(rank))
loss = criterion(output, batch[1].to(rank))
loss.backward()
optimizer.step()
dist.destroy_process_group()
if __name__ == "__main__":
world_size = 4
torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size)
性能测试结果
| GPU数量 | 单轮训练时间 | 加速比 |
|---|---|---|
| 1 | 1.2s | 1x |
| 2 | 0.75s | 1.6x |
| 4 | 0.5s | 2.4x |
踩坑记录
- 内存泄漏:在使用DDP时,必须确保每个进程都正确调用
dist.destroy_process_group() - 梯度同步问题:发现模型未收敛时,检查发现是
optimizer.step()后没有正确同步梯度 - 数据加载瓶颈:将数据预处理移到GPU上后,性能提升明显
总结
DDP确实能有效加速训练,但需要小心处理进程间通信和内存管理。在实际项目中,建议使用torchrun启动而不是multiprocessing.spawn。

讨论