如何调用 torch.onnx.export 导出的模型?
https://pytorch.org/docs/stable/onnx.html?highlight=torch%20o...
这里有一份 torch.onnx.export 的 demo 代码,运行后保存一个 onnx.pd 文件
import torchclass SumModule(torch.nn.Module):
def forward(self, x):
return torch.sum(x, dim=1)
torch.onnx.export(
SumModule(),
(torch.ones(2, 2),),
"onnx.pb",
input_names=["x"],
output_names=["sum"]
)
我该如何使用这个 onnx.pb 文件呢?
我写了下面的 demo 代码:
python">import onnxruntimeimport numpy
import torch
resnet_onnx = onnxruntime.InferenceSession(
'onnx.pb')
x = torch.ones(2, 2)
inputs = {resnet_onnx.get_inputs()[0].name: x}
resnet_onnx.run(None, inputs)
然而报错了:
─➤ python -u "/Users/ponponon/Desktop/code/me/onnx_example/use.py" 1 ↵Traceback (most recent call last):
File "/Users/ponponon/Desktop/code/me/onnx_example/use.py", line 12, in <module>
resnet_onnx.run(None, inputs)
File "/Users/ponponon/.local/share/virtualenvs/onnx_example-mTitCDnh/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 200, in run
return self._sess.run(output_names, input_feed, run_options)
RuntimeError: Input must be a list of dictionaries or a single numpy array for input 'x'.
我该如何解决呢?
回答:
问题已解决:PyTorch模型的输入为tensor,而ONNX的输入为array
转成 onnx 之后,输入类型不应该是 torch.Tensor,而是 numpy.ndarray
所以,把调用的 demo 修改为如下就行调通:
import onnxruntimeimport numpy
import torch
resnet_onnx = onnxruntime.InferenceSession(
'onnx.pb')
# x = torch.ones(2, 2) 输入的类型不应该是 torch.Tensor
x = numpy.ones((2, 2), dtype=numpy.float32) # 而是 numpy.ndarray
inputs = {resnet_onnx.get_inputs()[0].name: x}
print(resnet_onnx.run(None, inputs))
解决方案出处:
9.1.3 使用ONNX Runtime进行推理
下面的 chatGPT 给的回答:
你可以使用 ONNX Runtime 来加载和执行保存在 onnx.pb 文件中的模型。
以下是一个使用 ONNX Runtime 加载并运行模型的示例代码:
import onnxruntime as ortimport numpy as np
# 加载模型
sess = ort.InferenceSession("onnx.pb")
# 准备输入数据
input_data = np.ones((2, 2)).astype("float32")
# 运行模型
output_data = sess.run(None, {"x": input_data})[0]
# 输出模型的预测结果
print(output_data)
在这个示例代码中,首先通过 ort.InferenceSession
加载 onnx.pb 文件中的模型,并使用 np.ones
函数生成了一个输入数据的示例,然后使用 sess.run
方法运行模型,并将输入数据传递给模型。最后,通过 print
输出模型的预测结果。
你可以根据自己的需要修改输入数据的形状和数值,以及输出结果的处理方式。
以上是 如何调用 torch.onnx.export 导出的模型? 的全部内容, 来源链接: utcz.com/p/938768.html