大规模模型训练中的数据预处理加速踩坑记
最近在做大规模模型训练时,发现数据预处理成了性能瓶颈。本来以为只是简单的读取和转换,结果调优过程一波三折。
初始问题
使用PyTorch DataLoader加载ImageNet数据集时,单卡epoch耗时超过30分钟,明显高于预期。通过torch.utils.data.DataLoader的prefetch_factor参数设置为2后,情况有所改善但依然不够理想。
优化尝试1:多进程数据加载
from torch.utils.data import DataLoader
# 原始设置
loader = DataLoader(dataset, batch_size=256, num_workers=0)
# 优化后
loader = DataLoader(dataset, batch_size=256, num_workers=8, pin_memory=True, prefetch_factor=2)
结果:性能提升约35%,但仍不满足需求。
优化尝试2:混合精度预处理
使用torchvision.transforms的组合,并结合了PIL和numpy转换的性能差异。在数据加载阶段就将图片转为tensor格式,避免后续频繁转换。
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(), # 提前转tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
优化尝试3:使用tf.data进行预处理(踩坑重点)
最初想用tf.data加速预处理,结果发现与PyTorch训练流程整合复杂,最终放弃。但这个过程让我意识到数据管道的统一性很重要。
最终方案:自定义Dataset + 缓存策略
import torch
from torch.utils.data import Dataset
class CachedDataset(Dataset):
def __init__(self, data_list):
self.data_list = data_list
self.cache = {}
def __getitem__(self, idx):
if idx in self.cache:
return self.cache[idx]
# 预处理逻辑
item = self.process_item(self.data_list[idx])
self.cache[idx] = item
return item
最终效果:数据加载时间从30分钟缩短到12分钟,训练效率显著提升。
关键教训:预处理阶段的优化不能忽视,建议在实际训练前先做性能基准测试。

讨论