How do you reuse a DataLoader in PyTorch and avoid repeatedly instantiating it?

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from typing import List, Tuple, Union
from numpy import ndarray
from PIL import Image
from torchvision import transforms

# Standard ImageNet preprocessing pipeline
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

class PreprocessImageDataset(Dataset):
    """Wraps a list/tuple of HxWxC uint8 ndarrays and preprocesses them lazily."""

    def __init__(self, images: Union[List[ndarray], Tuple[ndarray]]):
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        image = Image.fromarray(image)
        preprocessed_image: torch.Tensor = preprocess(image)
        return preprocessed_image

if __name__ == '__main__':
    # Create some sample data: random uint8 frames standing in for the real
    # decoded video frames (the original placeholder, list(range(10000000)),
    # would crash Image.fromarray, which expects ndarrays)
    data = [np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
            for _ in range(100)]
    batch_size = 10
    num_workers = 16

    dataset = PreprocessImageDataset(data)
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            num_workers=num_workers)

    # Iterate over data batches in the training loop
    for batch_data in dataloader:
        print("Batch data:", batch_data)
        print("Batch data type :", type(batch_data))
        print("Batch data shape:", batch_data.shape)

Every time a new chunk of data arrives, a fresh DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) is created, so the worker process pool is repeatedly spawned and torn down (see the sketch below).

How can the DataLoader be reused?
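In other words, the usage pattern being described is roughly the following anti-pattern (a hypothetical sketch; incoming_chunks is not from the original post and just stands for whatever produces the data):

# Anti-pattern: a new DataLoader (and a new pool of 16 workers) per chunk
for chunk in incoming_chunks:
    dataloader = DataLoader(PreprocessImageDataset(chunk),
                            batch_size=batch_size,
                            num_workers=num_workers)
    for batch_data in dataloader:
        ...  # workers are spawned here and torn down when this loop ends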


Answer:

Instantiate the DataLoader once, outside the loop, and simply iterate over it again whenever you need another pass; the same object can be traversed any number of times:

(The imports, preprocess pipeline, and PreprocessImageDataset class are unchanged from the question; only the main block differs.)

if __name__ == '__main__':
    # Random uint8 frames standing in for the real decoded video frames
    data = [np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
            for _ in range(100)]
    batch_size = 10
    num_workers = 16

    dataset = PreprocessImageDataset(data)
    # Instantiate the DataLoader once...
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            num_workers=num_workers)

    # ...and iterate over it as many times as needed; no new DataLoader
    # is constructed inside the loop
    for epoch in range(5):
        print(f"Epoch {epoch + 1}:")
        for batch_data in dataloader:
            print("Batch data:", batch_data)
            print("Batch data type :", type(batch_data))
            print("Batch data shape:", batch_data.shape)

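If the real scenario is that data arrives chunk by chunk at runtime (as the question suggests), rebuilding the Dataset per chunk can be avoided entirely by feeding the chunks through a queue into an IterableDataset, so that a single long-lived DataLoader and its worker pool serve all incoming data. The sketch below is one possible pattern, not from the original post; QueueDataset and the random stand-in frames are illustrative, and the end of the stream is signalled with one None sentinel per worker:

import numpy as np
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, IterableDataset

class QueueDataset(IterableDataset):
    """Yields items from a shared queue until a None sentinel arrives."""

    def __init__(self, queue):
        self.queue = queue

    def __iter__(self):
        while True:
            item = self.queue.get()
            if item is None:          # sentinel: this worker is done
                return
            yield item

if __name__ == '__main__':
    num_workers = 2
    queue = mp.Queue()
    # One DataLoader, one worker pool, for the whole stream of chunks
    loader = DataLoader(QueueDataset(queue), batch_size=4,
                        num_workers=num_workers)

    # Producer side: push chunks as they "arrive" (random stand-in frames)
    for _ in range(3):
        for _ in range(8):
            queue.put(np.random.rand(3, 224, 224).astype(np.float32))
    for _ in range(num_workers):      # one sentinel per worker
        queue.put(None)

    for batch in loader:
        print(batch.shape)

The trade-off is that batch ordering across workers is no longer deterministic, and the producer must know how many workers to signal when the stream ends.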