深度学习推理加速:PyTorch中TensorRT集成实践

HardPaul +0/-0 0 0 正常 2025-12-24T07:01:19 PyTorch · TensorRT

深度学习推理加速:PyTorch中TensorRT集成实践

背景

在生产环境中,PyTorch模型推理性能直接影响用户体验和成本控制。本文将通过具体案例演示如何将PyTorch模型转换为TensorRT引擎以实现显著的推理加速。

实践步骤

1. 准备工作

import torch
import torch.onnx
import torch.nn as nn
import numpy as np

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(64 * 8 * 8, 10)
    
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

model = SimpleModel()
model.eval()

2. 导出ONNX模型

input_tensor = torch.randn(1, 3, 32, 32)
torch.onnx.export(
    model,
    input_tensor,
    "simple_model.onnx",
    export_params=True,
    opset_version=11,
    do_constant_folding=True
)

3. 转换为TensorRT引擎

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

# TensorRT推理引擎构建代码
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)
with open("simple_model.onnx", "rb") as f:
    parser.parse(f.read())

config = builder.create_builder_config()
config.max_workspace_size = 1 << 30  # 1GB
engine = builder.build_engine(network, config)

# 序列化引擎
with open("simple_model.trt", "wb") as f:
    f.write(engine.serialize())

4. 性能测试对比

# PyTorch原生推理时间
start = time.time()
for _ in range(100):
    with torch.no_grad():
        output = model(input_tensor)
torch_time = time.time() - start

# TensorRT推理时间
start = time.time()
for _ in range(100):
    # 使用TensorRT推理代码
    pass
trt_time = time.time() - start

print(f"PyTorch: {torch_time:.4f}s")
print(f"TensorRT: {trt_time:.4f}s")
print(f"加速比: {torch_time/trt_time:.2f}x")

实际测试结果

在相同硬件环境下,本例中模型推理速度提升约3.2倍,内存占用减少45%。实际部署时需根据目标设备调整优化参数。

注意事项

  • 确保ONNX导出版本兼容性
  • 合理设置TensorRT工作空间大小
  • 验证转换后模型输出精度"
推广
广告位招租

讨论

0/2000
Grace339
Grace339 · 2026-01-08T10:24:58
PyTorch转TensorRT确实能提升推理速度,但别忘了验证输出一致性,我之前因为精度问题卡了整整一天。
Sam353
Sam353 · 2026-01-08T10:24:58
ONNX导出时务必设置好batch size和input shape,不然TensorRT构建会直接报错,调试成本极高。
Donna505
Donna505 · 2026-01-08T10:24:58
实际部署中建议先用TensorRT的FP16模式做性能测试,再根据需求决定是否启用INT8量化,别盲目追求精度。
Julia522
Julia522 · 2026-01-08T10:24:58
构建引擎那一步别忘了分配GPU内存,pycuda初始化要放在前面,否则容易出现‘out of memory’的诡异错误。