PyTorch模型导出性能测试:不同格式转换时间对比
在实际部署场景中,PyTorch模型的导出格式选择直接影响推理效率。本文通过实测不同导出方式的时间开销,为工程师提供决策依据。
测试环境
- Python 3.8
- PyTorch 2.0.1
- NVIDIA RTX 4090 GPU
- 模型:ResNet50(batch_size=32)
测试代码
import torch
import time
from torchvision import models
# 加载模型并设置为评估模式
model = models.resnet50(pretrained=True)
model.eval()
# 准备输入数据
input_tensor = torch.randn(32, 3, 224, 224)
# 测试不同导出方式的时间开销
def benchmark_export(model, input_tensor, export_func, name):
start_time = time.time()
exported_model = export_func(model, input_tensor)
end_time = time.time()
print(f"{name}: {end_time - start_time:.4f} 秒")
return exported_model
# ONNX导出
onnx_model = benchmark_export(
model, input_tensor,
lambda m, x: torch.onnx.export(m, x, "resnet50.onnx", export_params=True),
"ONNX"
)
# TorchScript(trace)
torchscript_trace = benchmark_export(
model, input_tensor,
lambda m, x: torch.jit.trace(m, x),
"TorchScript Trace"
)
# TorchScript(script)
torchscript_script = benchmark_export(
model, input_tensor,
lambda m, x: torch.jit.script(m),
"TorchScript Script"
)
测试结果
| 导出格式 | 平均耗时(秒) |
|---|---|
| ONNX | 0.8241 |
| TorchScript Trace | 0.3567 |
| TorchScript Script | 0.4123 |
结论
TorchScript trace方式在本测试环境下耗时最短,适合对速度要求高的场景;ONNX导出时间较长但兼容性更好。建议根据实际部署环境选择合适的导出格式。

讨论