PyTorch DDP训练测试方法

DarkHero +0/-0 0 0 正常 2025-12-24T07:01:19 PyTorch · distributed

PyTorch DDP训练测试方法

在分布式训练中,PyTorch Distributed (DDP) 是主流的多机多卡训练框架。本文将介绍一套完整的DDP训练测试方法论。

环境准备

首先确保安装了PyTorch 1.8+版本,并配置好NCCL环境:

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

基础测试脚本

创建一个基础的DDP训练脚本:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    # 初始化分布式环境
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)
    
    # 创建模型并移动到GPU
    model = torch.nn.Linear(1000, 10).to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    
    # 创建数据集和dataloader
    dataset = torch.utils.data.TensorDataset(
        torch.randn(1000, 1000),
        torch.randint(0, 10, (1000,))
    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
    
    # 定义损失函数和优化器
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)
    
    # 训练循环
    for epoch in range(5):
        for batch_idx, (data, target) in enumerate(dataloader):
            data, target = data.to(rank), target.to(rank)
            optimizer.zero_grad()
            output = ddp_model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    
    cleanup()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)

性能测试步骤

  1. 单机多卡测试:使用torchrun --nproc_per_node=4 train.py运行
  2. 多机测试:配置多个节点的IP地址和端口
  3. 性能指标监控:记录训练时间、GPU利用率等

优化建议

  • 调整--bucket_cap_mb参数优化梯度同步
  • 使用torch.compile()加速模型
  • 合理设置batch size避免内存溢出
推广
广告位招租

讨论

0/2000
紫色玫瑰
紫色玫瑰 · 2026-01-08T10:24:58
DDP训练确实能大幅提升多卡训练效率,但别忘了先在单卡上把模型跑通再上分布式,不然调参成本高得吓人。
柔情密语
柔情密语 · 2026-01-08T10:24:58
setup函数里用nccl初始化要确保所有节点网络连通,我之前就因为防火墙问题卡了整整一天,建议提前测试网络。
Carl450
Carl450 · 2026-01-08T10:24:58
数据加载器的shuffle参数和batch_size设置很关键,尤其在多机场景下,不注意容易出现梯度同步不一致的问题。
健身生活志
健身生活志 · 2026-01-08T10:24:58
实际项目中建议用torchrun替代mp.spawn,更稳定且支持更多分布式配置选项,尤其在Slurm环境下特别好用。