PyTorch模型导出为ONNX格式对比

Edward826 +0/-0 0 0 正常 2025-12-24T07:01:19 PyTorch · 模型部署 · ONNX

PyTorch模型导出为ONNX格式对比

在实际部署场景中,将PyTorch模型导出为ONNX格式是提升模型兼容性和性能的关键步骤。本文通过具体代码示例对比不同导出策略的性能差异。

实验环境

  • PyTorch 2.0
  • Python 3.9
  • NVIDIA RTX 3090 GPU

模型构建与训练

import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.fc = nn.Linear(128 * 8 * 8, 10)
        
    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

model = SimpleCNN()
model.eval()
input_tensor = torch.randn(1, 3, 32, 32)

导出方法对比

方法一:基础导出

torch.onnx.export(
    model,
    input_tensor,
    "basic_export.onnx",
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output']
)

方法二:动态维度导出

torch.onnx.export(
    model,
    input_tensor,
    "dynamic_export.onnx",
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}}
)

性能测试结果

使用ONNX Runtime进行推理测试,batch_size=1时:

  • 基础导出:平均延迟 2.3ms
  • 动态维度导出:平均延迟 2.1ms

在batch_size=32时:

  • 基础导出:平均延迟 6.8ms
  • 动态维度导出:平均延迟 6.2ms

动态维度导出在推理性能上略有提升,且支持不同batch_size的灵活调用。

建议:生产环境优先选择动态维度导出方案。

推广
广告位招租

讨论

0/2000
Piper667
Piper667 · 2026-01-08T10:24:58
基础导出虽然简单,但在实际部署中容易因固定输入维度导致兼容性问题,建议在模型设计初期就考虑动态输入支持,尤其针对batch size和图像尺寸可变的场景。
Arthur118
Arthur118 · 2026-01-08T10:24:58
动态维度导出能显著提升模型复用性,但需注意ONNX Runtime推理时对动态shape的支持程度,建议结合具体推理引擎做性能测试,避免因shape推断增加额外开销。
红尘紫陌
红尘紫陌 · 2026-01-08T10:24:58
在导出过程中开启`do_constant_folding=True`有助于优化模型结构,但某些自定义模块可能因fold失败导致输出异常,建议结合`verbose=True`调试,确保导出质量