When using PyTorch, how can I reuse a DataLoader instead of instantiating it over and over?
import torch
from torch.utils.data import DataLoader, Dataset
from typing import List, Tuple, Union

import numpy as np
from numpy import ndarray
from PIL import Image
from torchvision import transforms

# Standard ImageNet-style preprocessing pipeline
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

class PreprocessImageDataset(Dataset):
    def __init__(self, images: Union[List[ndarray], Tuple[ndarray]]):
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        image = Image.fromarray(image)
        preprocessed_image: torch.Tensor = preprocess(image)
        return preprocessed_image

if __name__ == '__main__':
    # Create some example data: random uint8 images so Image.fromarray works
    data = [np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8) for _ in range(1000)]
    batch_size = 10
    num_workers = 16
    dataset = PreprocessImageDataset(data)
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            num_workers=num_workers)
    # Iterate over the data batches in the training loop
    for batch_data in dataloader:
        print("Batch data:", batch_data)
        print("Batch data type :", type(batch_data))
        print("Batch data shape:", batch_data.shape)
Every time a new chunk of data arrives, I have to call DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) again, which repeatedly creates and destroys the worker process pool. How can I reuse the DataLoader?
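Concretely, the pattern being described looks roughly like the sketch below. Note that incoming_data_chunks and the per-chunk loop are hypothetical names added for illustration; they are not part of the code above.

for new_chunk in incoming_data_chunks:
    dataset = PreprocessImageDataset(new_chunk)
    # A brand-new DataLoader per chunk: with num_workers > 0 this forks a fresh
    # pool of worker processes and later tears it down again, every single time.
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
    for batch_data in dataloader:
        pass  # consume the batch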
Answer:
Create the DataLoader once, outside the loop, and iterate over it repeatedly (for example once per epoch) instead of constructing a new one each time:
import torch
from torch.utils.data import DataLoader, Dataset
from typing import List, Tuple, Union

import numpy as np
from numpy import ndarray
from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

class PreprocessImageDataset(Dataset):
    def __init__(self, images: Union[List[ndarray], Tuple[ndarray]]):
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        image = Image.fromarray(image)
        preprocessed_image: torch.Tensor = preprocess(image)
        return preprocessed_image

if __name__ == '__main__':
    # Example data: random uint8 images so Image.fromarray works
    data = [np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8) for _ in range(1000)]
    batch_size = 10
    num_workers = 16
    dataset = PreprocessImageDataset(data)
    # Instantiate the DataLoader once ...
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            num_workers=num_workers)
    # ... and reuse it: iterating it again starts a fresh pass over the dataset
    for epoch in range(5):
        print(f"Epoch {epoch + 1}:")
        for batch_data in dataloader:
            print("Batch data:", batch_data)
            print("Batch data type :", type(batch_data))
            print("Batch data shape:", batch_data.shape)