Exporting a PyTorch Model to TensorRT: Test Data
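Environment Setup
```shell
pip install torch torchvision tensorrt torch-tensorrt
```
TensorRT requires an NVIDIA GPU with a working CUDA driver. Each torch-tensorrt release is built against a specific PyTorch version, so install versions that match.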
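A version mismatch between the packages above is a common cause of import errors. The sketch below reports what is installed using only the standard-library `importlib.metadata`, without importing the packages themselves. The helper name `installed_versions` is illustrative and not part of any library.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("torch", "torchvision", "tensorrt", "torch-tensorrt")):
    """Return {distribution name: version string, or None if it is not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None  # package missing from this environment
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:15s} {ver or 'not installed'}")
```

Run it before the export script. A `None` entry means the package is missing from the current environment.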
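Model Export Code
```python
import torch
import torch_tensorrt


class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 64, 3, padding=1)  # 3x3 conv; padding=1 keeps 32x32
        self.relu = torch.nn.ReLU()
        self.fc = torch.nn.Linear(64 * 32 * 32, 10)  # flattened 64x32x32 feature map -> 10 outputs

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = x.view(x.size(0), -1)  # flatten to (batch, 64*32*32)
        x = self.fc(x)
        return x


# Create the model and export it (TensorRT runs on NVIDIA GPUs, so model and input go to CUDA)
model = SimpleModel().eval().cuda()
input_tensor = torch.randn(1, 3, 32, 32).cuda()
trt_model = torch_tensorrt.compile(
    model,
    inputs=[input_tensor],
    enabled_precisions={torch.float32},
    device=torch_tensorrt.Device("cuda:0"),
    min_block_size=1,  # convert even single-op blocks instead of falling back to PyTorch
)

# Serialize for deployment: a TorchScript file with the TensorRT engines embedded
torch_tensorrt.save(trt_model, "model.trt", output_format="torchscript", inputs=[input_tensor])
```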
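`Linear(64 * 32 * 32, 10)` hard-codes the flattened feature count, which is only correct because a 3x3 convolution with `padding=1` and stride 1 preserves the 32x32 spatial size. The sketch below recomputes that number with the output-size formula from the PyTorch `Conv2d` documentation, H_out = floor((H_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1). The helper names are hypothetical.

```python
def conv2d_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output size of Conv2d along one spatial dimension (formula from the PyTorch docs)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def flattened_features(channels, height, width, kernel_size, stride=1, padding=0):
    """Feature count after Conv2d + flatten, i.e. the in_features the Linear layer needs."""
    h = conv2d_out_size(height, kernel_size, stride, padding)
    w = conv2d_out_size(width, kernel_size, stride, padding)
    return channels * h * w


# SimpleModel: Conv2d(3, 64, 3, padding=1) applied to a 32x32 input
print(flattened_features(64, 32, 32, kernel_size=3, padding=1))  # 65536 == 64 * 32 * 32
```

Any other input size would need a different `in_features`. Passing a concrete tensor as `inputs` to `torch_tensorrt.compile` also builds the engine for that static 1x3x32x32 shape.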
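Performance Test Data
- PyTorch FP32 inference time: 2.1 ms
- TensorRT FP32 inference time: 0.8 ms
- Speedup: about 2.6× (2.1 ms / 0.8 ms), i.e. roughly 62% lower latency, or a throughput improvement of about 162%
- Memory usage: the TensorRT version uses about 40% less memory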
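The figures above do not state how latency was measured. Below is a minimal sketch of one common approach, assuming median wall-clock latency over repeated calls after a warm-up phase. The function names are hypothetical, and `compare` also reproduces the arithmetic behind the speedup line.

```python
import statistics
import time


def measure_latency_ms(fn, warmup=20, iters=200, sync=None):
    """Median latency of fn() in milliseconds.

    sync: optional callable invoked before reading the clock (e.g. torch.cuda.synchronize),
    needed because CUDA kernels are launched asynchronously.
    """
    for _ in range(warmup):  # warm-up: lazy initialization, kernel selection, caches
        fn()
    if sync:
        sync()  # make sure warm-up work has finished before timing starts
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync:
            sync()
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)


def compare(baseline_ms, optimized_ms):
    """Speedup, throughput gain (%) and latency reduction (%) of optimized vs. baseline."""
    speedup = baseline_ms / optimized_ms
    return {
        "speedup": speedup,
        "throughput_gain_pct": (speedup - 1.0) * 100.0,
        "latency_reduction_pct": (1.0 - optimized_ms / baseline_ms) * 100.0,
    }


# The figures reported above: 2.1 ms (PyTorch) vs. 0.8 ms (TensorRT)
print(compare(2.1, 0.8))  # speedup ~2.625, throughput gain ~162.5%, latency reduction ~61.9%
```

With PyTorch, pass a closure together with `torch.cuda.synchronize` inside `with torch.no_grad():`, for example `measure_latency_ms(lambda: trt_model(input_tensor), sync=torch.cuda.synchronize)`. Without the synchronization, the timer mostly measures kernel launch overhead rather than GPU execution.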
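Deployment Verification
```python
import torch
import torch_tensorrt  # provides load() and registers the TensorRT runtime ops

# Load the TensorRT model
trt_model = torch_tensorrt.load("model.trt")
input_tensor = torch.randn(1, 3, 32, 32).cuda()  # same shape the engine was built for, on the GPU
output = trt_model(input_tensor)
print(f"Output shape: {output.shape}")  # torch.Size([1, 10])
```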
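Printing the output shape shows that the saved module loads and runs, but not that it computes the same result as the original model. The usual check is `torch.allclose(trt_output, torch_output, rtol=..., atol=...)` against the PyTorch model on the same input. The pure-Python sketch below spells out the element-wise criterion that `torch.allclose` applies, so it can be used on outputs converted with `.flatten().tolist()`. The `check_outputs` name and the 1e-3 tolerances are assumptions; TensorRT selects different kernels, so bitwise equality should not be expected.

```python
def check_outputs(reference, candidate, rtol=1e-3, atol=1e-3):
    """Compare two flat sequences of floats.

    Applies |candidate - reference| <= atol + rtol * |reference| element-wise
    (the torch.allclose criterion) and returns (passed, max_abs_diff).
    """
    if len(reference) != len(candidate):
        raise ValueError("output sizes differ")
    diffs = [abs(c - r) for r, c in zip(reference, candidate)]
    passed = all(d <= atol + rtol * abs(r) for d, r in zip(diffs, reference))
    return passed, max(diffs, default=0.0)


# Small illustrative values standing in for PyTorch and TensorRT outputs
print(check_outputs([1.0, -2.0, 0.5], [1.0005, -2.001, 0.5]))
```

In practice, `reference` would come from `model(input_tensor)` and `candidate` from `trt_model(input_tensor)`, both flattened to lists. A large `max_abs_diff` points to a conversion problem rather than ordinary floating-point noise.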
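After export to TensorRT, FP32 inference in the measurements above is about 2.6× faster and uses about 40% less memory, which makes the model a good fit for edge devices and high-concurrency serving.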

Discussion