如何将下面的 resnet50 模型导出为 onnx 格式呢?
如何将下面的 resnet50 模型导出为 onnx 格式呢?batch_size 要是动态值
import torchimport torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torchvision.models as models
def gem(x: Tensor, p: int = 3, eps: float = 1e-6) -> Tensor:
input = x.clamp(min=eps)
_input = input.pow(p)
kernel_size = (x.size(-2), x.size(-1))
t = F.avg_pool2d(_input, kernel_size).pow(1./p)
return t
def l2n(x: Tensor, eps: float = 1e-6) -> Tensor:
return x / (torch.norm(x, p=2, dim=1, keepdim=True) + eps).expand_as(x)
class L2N(nn.Module):
def __init__(self, eps=1e-6):
super(L2N, self).__init__()
self.eps = eps
def forward(self, x):
return l2n(x, eps=self.eps)
def __repr__(self):
return self.__class__.__name__ + '(' + 'eps=' + str(self.eps) + ')'
class GeM(nn.Module):
def __init__(self, p=3, eps=1e-6):
super(GeM, self).__init__()
self.p = Parameter(torch.ones(1)*p)
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
return gem(x, p=self.p, eps=self.eps)
class ImageRetrievalNet(nn.Module):
def __init__(self, dim: int = 512):
super(ImageRetrievalNet, self).__init__()
resnet50_model = models.resnet50()
features = list(resnet50_model.children())[:-2]
self.features = nn.Sequential(*features)
self.lwhiten = None
self.pool = GeM()
self.whiten = nn.Linear(2048, dim, bias=True)
self.norm = L2N()
def forward(self, x: Tensor):
o: Tensor = self.features(x)
# features -> pool -> norm
p = 3
eps = 1e-6
pooled_t = gem(o, p, eps)
normed_t: Tensor = self.norm(pooled_t)
o: Tensor = normed_t.squeeze(-1).squeeze(-1)
# 启用白化,则: pooled features -> whiten -> norm
if self.whiten is not None:
whitened_t = self.whiten(o)
normed_t: Tensor = self.norm(whitened_t)
o = normed_t
# 使每个图像为Dx1列向量(如果有许多图像,则为DxN)
return o.permute(1, 0)
# 创建 PyTorch ResNet50 模型实例
model = ImageRetrievalNet()
# # 定义一个 PyTorch 张量来模拟输入数据
# batch_size = 4 # 定义批处理大小
# input_shape = (batch_size, 3, 224, 224)
# input_data = torch.randn(input_shape)
# # 将模型转换为 ONNX 格式
# output_path = "resnet50.onnx"
# torch.onnx.export(
# model,
# input_data,
# output_path,
# input_names=["input"], output_names=["output"],
# dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
# )
batch_size = 4 # 定义批处理大小
input_shape = (batch_size, 3, 224, 224)
input_data = torch.randn(input_shape)
# 指定所有张量的静态形状
input_shape = ["batch_size", "channels", "height", "width"]
output_shape = ["batch_size", "features"]
torch.onnx.export(
model,
input_data,
"resnet50.onnx",
input_names=["input"], output_names=["output"],
opset_version=12,
dynamic_axes={
"input": {0: "batch_size"},
"output": {0: "batch_size"}
},
)
运行之后,报错
/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/_internal/jit_utils.py:258: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.) _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
Traceback (most recent call last):
File "/home/ponponon/code/torch_example/resnet50_export_onnx copy.py", line 123, in <module>
torch.onnx.export(
File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/utils.py", line 504, in export
_export(
File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/utils.py", line 1529, in _export
graph, params_dict, torch_out = _model_to_graph(
File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/utils.py", line 1115, in _model_to_graph
graph = _optimize_graph(
File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/utils.py", line 663, in _optimize_graph
graph = _C._jit_pass_onnx(graph, operator_export_type)
File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/utils.py", line 1899, in _run_symbolic_function
return symbolic_fn(graph_context, *inputs, **attrs)
File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 380, in wrapper
return fn(g, *args, **kwargs)
File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 286, in wrapper
args = [
File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 287, in <listcomp>
_parse_arg(arg, arg_desc, arg_name, fn_name) # type: ignore[assignment]
File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 104, in _parse_arg
raise errors.SymbolicValueError(
torch.onnx.errors.SymbolicValueError: Failed to export a node '%518 : Long(device=cpu) = onnx::Squeeze[axes=[0]](%517), scope: __main__.ImageRetrievalNet:: # /home/ponponon/code/torch_example/resnet50_export_onnx copy.py:12:0
' (in list node %527 : int[] = prim::ListConstruct(%518, %526), scope: __main__.ImageRetrievalNet::
) because it is not constant. Please try to make things (e.g. kernel sizes) static if possible. [Caused by the value '527 defined in (%527 : int[] = prim::ListConstruct(%518, %526), scope: __main__.ImageRetrievalNet::
)' (type 'List[int]') in the TorchScript graph. The containing node has kind 'prim::ListConstruct'.]
Inputs:
#0: 518 defined in (%518 : Long(device=cpu) = onnx::Squeeze[axes=[0]](%517), scope: __main__.ImageRetrievalNet:: # /home/ponponon/code/torch_example/resnet50_export_onnx copy.py:12:0
) (type 'Tensor')
#1: 526 defined in (%526 : Long(device=cpu) = onnx::Squeeze[axes=[0]](%525), scope: __main__.ImageRetrievalNet:: # /home/ponponon/code/torch_example/resnet50_export_onnx copy.py:12:0
) (type 'Tensor')
Outputs:
#0: 527 defined in (%527 : int[] = prim::ListConstruct(%518, %526), scope: __main__.ImageRetrievalNet::
) (type 'List[int]')
回答:
问题是由于 ONNX 导出过程中无法处理某些动态值。请试以下更改以解决此问题:从 ImageRetrievalNet 类中删除 self.lwhiten 属性,因为它没有被使用。
更改 GeM 类的 forward 方法,使其不再使用 Parameter 类型的属性 self.p,而是直接使用一个常量值。
以下是修改后的 ImageRetrievalNet 和 GeM 类:
class GeM(nn.Module): def __init__(self, p=3, eps=1e-6):
super(GeM, self).__init__()
self.p = p
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
return gem(x, p=self.p, eps=self.eps)
class ImageRetrievalNet(nn.Module):
def __init__(self, dim: int = 512):
super(ImageRetrievalNet, self).__init__()
resnet50_model = models.resnet50()
features = list(resnet50_model.children())[:-2]
self.features = nn.Sequential(*features)
self.pool = GeM()
self.whiten = nn.Linear(2048, dim, bias=True)
self.norm = L2N()
def forward(self, x: Tensor):
o: Tensor = self.features(x)
# features -> pool -> norm
pooled_t = self.pool(o)
normed_t: Tensor = self.norm(pooled_t)
o: Tensor = normed_t.squeeze(-1).squeeze(-1)
# 启用白化,则: pooled features -> whiten -> norm
if self.whiten is not None:
whitened_t = self.whiten(o)
normed_t: Tensor = self.norm(whitened_t)
o = normed_t
# 使每个图像为Dx1列向量(如果有许多图像,则为DxN)
return o.permute(1, 0)
现在,您应该可以正常导出 ONNX 模型,如下所示:
# 创建 PyTorch ResNet50 模型实例model = ImageRetrievalNet()
# 定义一个 PyTorch 张量来模拟输入数据
batch_size = 4 # 定义批处理大小
input_shape = (batch_size, 3, 224, 224)
input_data = torch.randn(input_shape)
# 将模型转换为 ONNX 格式
output_path = "resnet50.onnx"
torch.onnx.export(
model,
input_data,
output_path,
input_names=["input"], output_names=["output"],
opset_version=12,
dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
)
这些更改应解决在将模型导出为 ONNX 时遇到的错误。
以上是 如何将下面的 resnet50 模型导出为 onnx 格式呢? 的全部内容, 来源链接: utcz.com/p/938830.html