PyTorch模型导出为TorchScript格式指南
最近在将PyTorch模型部署到生产环境时,遇到了模型导出的坑。分享一下我的踩坑记录。
问题背景
我们使用PyTorch训练了一个图像分类模型,需要将其导出为TorchScript格式用于移动端部署。按照官方文档,我以为直接用torch.jit.script()就能搞定,结果却出现了各种报错。
核心问题
- 动态维度不支持:原始模型使用了
torch.nn.AdaptiveAvgPool2d,在导出时会报错 - 自定义函数无法识别:模型中包含了一些自定义的前向传播逻辑
- 参数传递方式错误:使用了
torch.jit.trace()但传参方式不对
解决方案
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.backbone = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d((7, 7)) # 这个会出问题
)
self.classifier = nn.Linear(64 * 7 * 7, 10)
def forward(self, x):
x = self.backbone(x)
x = x.view(x.size(0), -1) # 确保维度正确
return self.classifier(x)
# 解决方案:先用trace
model = MyModel()
model.eval()
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)
torch.jit.save(traced_model, 'model_traced.pt')
性能测试
导出前后性能对比:
- 原始模型推理时间:15.2ms
- TorchScript模型推理时间:12.8ms
- 内存占用减少约20%
小贴士
- 建议先用
torch.jit.trace(),再考虑torch.jit.script() - 导出前务必设置模型为eval模式
- 使用固定输入尺寸进行trace测试
注意:如果遇到复杂逻辑,可以考虑使用@torch.jit.ignore装饰器来跳过不支持的代码段。

讨论