深度学习推理优化技巧:从静态图到动态图推理转换
在大模型推理阶段,性能优化一直是工程师们关注的重点。最近在项目中尝试将PyTorch模型从静态图转换为动态图推理,踩了不少坑,分享一下经验。
问题背景
原本使用torch.jit.script进行静态图编译,但发现对于动态输入shape的模型效果不佳。决定尝试torch.export + torch.compile来实现更灵活的动态图推理。
踩坑过程
第一步:简单转换失败
# 原始代码
import torch
model = MyModel()
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)
使用trace方式后,模型在batch size变化时出现维度错误。
第二步:尝试torch.export
# 错误尝试
import torch
model = MyModel()
example_input = torch.randn(1, 3, 224, 224)
exported = torch.export.export(model, (example_input,))
# 这里没有考虑动态维度设置
导出后依然无法处理不同batch size的输入。
第三步:正确配置动态 shapes
import torch
model = MyModel()
example_input = torch.randn(1, 3, 224, 224)
# 正确做法
dynamic_shapes = {
"input": {0: "batch_size"}
}
exported = torch.export.export(
model,
args=(example_input,),
dynamic_shapes=dynamic_shapes
)
第四步:部署阶段的优化 使用torch.compile进一步优化:
# 将导出模型转换为torch.compile形式
optimized_model = torch.compile(exported)
总结
- 静态图trace适用于输入shape固定的情况
- 动态图export需要正确设置dynamic_shapes
- 最终结合torch.compile能获得更好的性能
这个过程确实浪费了不少时间,但对大模型推理优化很有帮助!

讨论