How do I replace PyTorch's transforms.Compose?
During training, images are preprocessed with torchvision's transforms.Compose, for example like this:
preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
But once training is done and the model needs to be deployed, the trained PyTorch model gets converted to ONNX, and at that point the dependency on PyTorch has to be removed.
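For reference, that conversion step usually looks something like the sketch below; the resnet18 model, input shape, and file name are placeholders of mine, not from the original post:

import torch
from torchvision.models import resnet18

model = resnet18()  # placeholder; substitute your own trained network
model.eval()

# A dummy input fixes the (batch, channels, height, width) shape the
# exported graph will expect -- 224x224 to match the preprocessing above.
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, 'model.onnx',
                  input_names=['input'], output_names=['output'])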
So how can this transforms.Compose be replaced with an equivalent built from numpy, PIL, and the like?
I wrote the code below, but it runs about 50% slower than transforms.Compose, and I don't understand why.
from PIL import Image
import numpy as np
from numpy import ndarray

def preprocess(image: Image.Image) -> ndarray:
    # Resize directly to 224x224; note this squashes the image rather than
    # scaling the shorter edge and center-cropping like the Compose pipeline.
    resized_image = image.resize((224, 224))
    resized_image_ndarray = np.array(resized_image)
    # HWC -> CHW, the layout transforms.ToTensor produces.
    transposed_image_ndarray = resized_image_ndarray.transpose((2, 0, 1))
    transposed_image_ndarray_float32 = transposed_image_ndarray.astype(np.float32)
    # Scale pixel values from [0, 255] to [0, 1].
    transposed_image_ndarray_float32 /= 255.0
    mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
    std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
    # mean/std are float64, so the division promotes the result to float64;
    # cast back to float32 at the end.
    normalized_image_ndarray = (transposed_image_ndarray_float32 - mean) / std
    return normalized_image_ndarray.astype(np.float32)
The transforms.Compose version finishes 3000 iterations in 9.727 seconds:
from torchvision import transforms
from PIL import Image
from torch import Tensor

preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

image = Image.open('bh.jpg')
for i in range(3000):
    tensor: Tensor = preprocess(image)
The PIL + numpy version takes 16.093 seconds for 3000 iterations:
from PIL import Image
import numpy as np
from numpy import ndarray

def preprocess(image: Image.Image) -> ndarray:
    resized_image = image.resize((224, 224))
    resized_image_ndarray = np.array(resized_image)
    transposed_image_ndarray = resized_image_ndarray.transpose((2, 0, 1))
    transposed_image_ndarray_float32 = transposed_image_ndarray.astype(np.float32)
    transposed_image_ndarray_float32 /= 255.0
    mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
    std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
    normalized_image_ndarray = (transposed_image_ndarray_float32 - mean) / std
    return normalized_image_ndarray.astype(np.float32)

image = Image.open('bh.jpg')
for i in range(3000):
    preprocessed_ndarray: ndarray = preprocess(image)
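The post does not show how the timings were collected; a minimal harness along these lines (time.perf_counter is my assumption, not from the original post) would reproduce the measurement:

import time

start = time.perf_counter()
for i in range(3000):
    preprocess(image)
print(f'3000 iterations took {time.perf_counter() - start:.3f} s')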
Answer:
Reason one:
The two pipelines do not resize the same way. transforms.Resize(224) scales the shorter edge to 224 while keeping the aspect ratio (bilinear interpolation by default), and CenterCrop(224) then cuts out the center; image.resize((224, 224)) instead squashes the whole image to exactly 224x224 using PIL's own default resampling filter (a sketch of a PIL-only equivalent follows this list). For details, see:
- What is the default interpolation of PIL's resize in Python?
- What is the difference between transforms.Resize and PIL's resize?
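A PIL-only equivalent of Resize(224) + CenterCrop(224) might look like this sketch; Image.BILINEAR is chosen to match torchvision's default interpolation, and the helper name is mine, not from the original post:

from PIL import Image

def resize_and_center_crop(image: Image.Image, size: int = 224) -> Image.Image:
    # Scale the shorter edge to `size`, preserving the aspect ratio,
    # like transforms.Resize(size).
    w, h = image.size
    if w <= h:
        new_w, new_h = size, max(size, round(h * size / w))
    else:
        new_w, new_h = max(size, round(w * size / h)), size
    image = image.resize((new_w, new_h), Image.BILINEAR)
    # Crop the central size-by-size region, like transforms.CenterCrop(size).
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return image.crop((left, top, left + size, top + size))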
Reason two:
CPU utilization differs. The transforms.Compose pipeline reaches about 117% CPU, while the plain PIL + numpy version tops out around 101%, presumably because PyTorch's tensor kernels can run on more than one thread. That extra parallelism accounts for transforms.Compose being roughly 16% faster than PIL + numpy; one way to test this is sketched below.
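To check how much of the gap comes from that parallelism, you could pin PyTorch to a single thread before rerunning the benchmark (my suggestion, not part of the original answer):

import torch

# Restrict PyTorch's intra-op parallelism to a single thread so both
# pipelines are compared on equal, single-core footing.
torch.set_num_threads(1)
print(torch.get_num_threads())  # should now report 1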