PyTorch模型导出为TorchScript格式指南

HighYara +0/-0 0 0 正常 2025-12-24T07:01:19 PyTorch · 性能优化

PyTorch模型导出为TorchScript格式指南

最近在将PyTorch模型部署到生产环境时,遇到了模型导出的坑。分享一下我的踩坑记录。

问题背景

我们使用PyTorch训练了一个图像分类模型,需要将其导出为TorchScript格式用于移动端部署。按照官方文档,我以为直接用torch.jit.script()就能搞定,结果却出现了各种报错。

核心问题

  1. 动态维度不支持:原始模型使用了torch.nn.AdaptiveAvgPool2d,在导出时会报错
  2. 自定义函数无法识别:模型中包含了一些自定义的前向传播逻辑
  3. 参数传递方式错误:使用了torch.jit.trace()但传参方式不对

解决方案

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((7, 7))  # 这个会出问题
        )
        self.classifier = nn.Linear(64 * 7 * 7, 10)
    
    def forward(self, x):
        x = self.backbone(x)
        x = x.view(x.size(0), -1)  # 确保维度正确
        return self.classifier(x)

# 解决方案:先用trace
model = MyModel()
model.eval()
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)

torch.jit.save(traced_model, 'model_traced.pt')

性能测试

导出前后性能对比:

  • 原始模型推理时间:15.2ms
  • TorchScript模型推理时间:12.8ms
  • 内存占用减少约20%

小贴士

  1. 建议先用torch.jit.trace(),再考虑torch.jit.script()
  2. 导出前务必设置模型为eval模式
  3. 使用固定输入尺寸进行trace测试

注意:如果遇到复杂逻辑,可以考虑使用@torch.jit.ignore装饰器来跳过不支持的代码段。

推广
广告位招租

讨论

0/2000
YoungGerald
YoungGerald · 2026-01-08T10:24:58
踩坑经验很实用!不过建议补充一点:如果模型有动态输入(如不同batch size),trace可能不够稳定,此时script+静态维度会更可靠。
Helen228
Helen228 · 2026-01-08T10:24:58
性能提升20%确实吸引人,但别忘了TorchScript导出后调试变困难了。建议导出前做好完整单元测试,确保逻辑一致。