如何调用 torch.onnx.export 导出的模型?

https://pytorch.org/docs/stable/onnx.html?highlight=torch%20o...

如何调用 torch.onnx.export 导出的模型?

这里有一份 torch.onnx.export 的 demo 代码,运行后保存一个 onnx.pd 文件

import torch

class SumModule(torch.nn.Module):

def forward(self, x):

return torch.sum(x, dim=1)

torch.onnx.export(

SumModule(),

(torch.ones(2, 2),),

"onnx.pb",

input_names=["x"],

output_names=["sum"]

)

我该如何使用这个 onnx.pb 文件呢?

我写了下面的 demo 代码:

python">import onnxruntime

import numpy

import torch

resnet_onnx = onnxruntime.InferenceSession(

'onnx.pb')

x = torch.ones(2, 2)

inputs = {resnet_onnx.get_inputs()[0].name: x}

resnet_onnx.run(None, inputs)

然而报错了:

─➤  python -u "/Users/ponponon/Desktop/code/me/onnx_example/use.py"                                 1 ↵

Traceback (most recent call last):

File "/Users/ponponon/Desktop/code/me/onnx_example/use.py", line 12, in <module>

resnet_onnx.run(None, inputs)

File "/Users/ponponon/.local/share/virtualenvs/onnx_example-mTitCDnh/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 200, in run

return self._sess.run(output_names, input_feed, run_options)

RuntimeError: Input must be a list of dictionaries or a single numpy array for input 'x'.

我该如何解决呢?


回答:

问题已解决:PyTorch模型的输入为tensor,而ONNX的输入为array

转成 onnx 之后,输入类型不应该是 torch.Tensor,而是 numpy.ndarray

所以,把调用的 demo 修改为如下就行调通:

import onnxruntime

import numpy

import torch

resnet_onnx = onnxruntime.InferenceSession(

'onnx.pb')

# x = torch.ones(2, 2) 输入的类型不应该是 torch.Tensor

x = numpy.ones((2, 2), dtype=numpy.float32) # 而是 numpy.ndarray

inputs = {resnet_onnx.get_inputs()[0].name: x}

print(resnet_onnx.run(None, inputs))


解决方案出处:

如何调用 torch.onnx.export 导出的模型?

9.1.3 使用ONNX Runtime进行推理


下面的 chatGPT 给的回答:

你可以使用 ONNX Runtime 来加载和执行保存在 onnx.pb 文件中的模型。

以下是一个使用 ONNX Runtime 加载并运行模型的示例代码:

import onnxruntime as ort

import numpy as np

# 加载模型

sess = ort.InferenceSession("onnx.pb")

# 准备输入数据

input_data = np.ones((2, 2)).astype("float32")

# 运行模型

output_data = sess.run(None, {"x": input_data})[0]

# 输出模型的预测结果

print(output_data)

在这个示例代码中,首先通过 ort.InferenceSession 加载 onnx.pb 文件中的模型,并使用 np.ones 函数生成了一个输入数据的示例,然后使用 sess.run 方法运行模型,并将输入数据传递给模型。最后,通过 print 输出模型的预测结果。

你可以根据自己的需要修改输入数据的形状和数值,以及输出结果的处理方式。

以上是 如何调用 torch.onnx.export 导出的模型? 的全部内容, 来源链接: utcz.com/p/938768.html

回到顶部