pytorch ImageFolder的覆写实例

在为数据分类训练分类器的时候,比如猫狗分类时,我们经常会使用pytorch的ImageFolder:

CLASS torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)

具体用法可参见《pytorch torchvision.ImageFolder的用法介绍》一文。

这里想实现的是:通过继承并覆写该类,既能使用它原有的特性,又可以实现自己的功能。

首先分析一下其源代码:

# File-name suffixes accepted as images. Every entry is lower case and includes
# the leading dot, because the entries are matched with str.endswith: a bare
# 'webp' would also accept names such as 'notes_webp'.
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp']

class ImageFolder(DatasetFolder):
    """Image dataset read from a directory tree with one sub-folder per class.

    The images are expected to be arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in a
            loaded image and returns a transformed version.
            E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes
            in the target (class index) and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples; this is the
            same list object as ``samples``.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        # Delegate discovery of classes and files to DatasetFolder, restricting
        # the accepted files to the known image extensions.
        super(ImageFolder, self).__init__(
            root=root,
            loader=loader,
            extensions=IMG_EXTENSIONS,
            transform=transform,
            target_transform=target_transform
        )
        # Image-specific alias for the sample list built by the parent class.
        self.imgs = self.samples

ImageFolder的代码很简单,主要是继承了DatasetFolder:

def has_file_allowed_extension(filename, extensions):
    """Check whether a file name ends with one of the allowed extensions.

    The file name is lower-cased before the comparison, so the case of
    ``filename`` does not matter. The entries of ``extensions`` are compared
    exactly as given and should therefore be lower case.

    Args:
        filename (string): Path to a file.
        extensions (iterable of strings): Accepted file-name suffixes,
            e.g. ``['.jpg', '.png']``.

    Returns:
        bool: True if the filename ends with one of given extensions.
    """
    # str.endswith accepts a tuple of suffixes and yields a single bool
    # (not a list); an empty tuple never matches.
    return filename.lower().endswith(tuple(extensions))

def make_dataset(dir, class_to_idx, extensions):
    """Collect every accepted file found under the class folders of ``dir``.

    Args:
        dir (string): Root directory path; a leading ``~`` is expanded.
        class_to_idx (dict): Mapping from class (sub-folder) name to class index.
        extensions (iterable of strings): Accepted file-name suffixes.

    Returns:
        list: ``(file path, class index)`` tuples, one per accepted file, e.g.
        ``[(path, class_index), ...]``. Class names that have no matching
        directory under ``dir`` are skipped.
    """
    images = []
    dir = os.path.expanduser(dir)
    # Sorting class names, walked directories and file names keeps the result
    # order independent of the order in which the OS lists directory entries.
    for target in sorted(class_to_idx.keys()):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
        # os.walk descends through every nested folder, yielding the current
        # folder path, its sub-folder names and its file names.
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                # Keep only files whose extension is in the accepted list.
                if has_file_allowed_extension(fname, extensions):
                    path = os.path.join(root, fname)
                    item = (path, class_to_idx[target])
                    images.append(item)
    return images

class DatasetFolder(data.Dataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function that loads a sample given its path.
        extensions (list[string]): Accepted file extensions.
        transform (callable, optional): A function/transform that takes in a
            sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target (class index) and transforms it.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index),
            e.g. ``{'cat': 0, 'dog': 1}``.
        samples (list): List of (sample path, class_index) tuples.
        targets (list): The class_index value of every sample, in the same
            order as ``samples``.
    """

    def __init__(self, root, loader, extensions, transform=None, target_transform=None):
        # Class names and their indices, e.g. ['cat', 'dog'] and
        # {'cat': 0, 'dog': 1}.
        classes, class_to_idx = self._find_classes(root)
        # Label every accepted file: [(path, class_index), ...].
        samples = make_dataset(root, class_to_idx, extensions)
        if len(samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root + "\n"
                               "Supported extensions are: " + ",".join(extensions))

        self.root = root
        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        # Class index of every sample, in sample order.
        self.targets = [s[1] for s in samples]

        self.transform = transform
        self.target_transform = target_transform

    def _find_classes(self, dir):
        """Find the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are the names of the
            first-level sub-directories of ``dir`` in sorted order, e.g.
            ``['cat', 'dog']``, and class_to_idx is a dictionary mapping each
            name to its position in that list, e.g. ``{'cat': 0, 'dog': 1}``.

        Ensures:
            No class is a subdirectory of another (only first-level
            directories are listed).
        """
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above
            classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        else:
            # Same result for older interpreters that lack os.scandir.
            classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        # Sort so that the index assigned to each class is deterministic.
        classes.sort()
        class_to_idx = {name: i for i, name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        # Samples are loaded on access, not when the dataset is constructed.
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.samples)

    def __repr__(self):
        """Return a multi-line summary: class name, size, root and transforms."""
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += ' Root Location: {}\n'.format(self.root)
        # Continuation lines of a multi-line transform repr are padded so they
        # line up under the first line, after the label.
        tmp = ' Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, repr(self.transform).replace('\n', '\n' + ' ' * len(tmp)))
        tmp = ' Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, repr(self.target_transform).replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

此时想要覆写ImageFolder,代码为:

class CustomImageFolder(ImageFolder):
    """ImageFolder variant whose items are pairs of images.

    Each item combines the image at the requested index with a second image
    drawn at random from the same dataset, together with the class index of
    each of the two images.
    """

    def __init__(self, root, transform=None):
        super(CustomImageFolder, self).__init__(root, transform)
        # Every valid sample index, 0 .. len(self) - 1; the random partner
        # image is drawn from these.
        self.indices = range(len(self))

    def __getitem__(self, index1):
        """Return the image at ``index1`` paired with a randomly chosen image.

        Args:
            index1 (int): Index of the first image.

        Returns:
            tuple: (img1, img2, label1, label2) where the labels are the class
            indices of the two images. The random pick is made over all
            indices, so the second image may be the same as the first.
        """
        index2 = random.choice(self.indices)

        # self.imgs is the same list as self.samples: (path, class_index) tuples.
        path1, label1 = self.imgs[index1]
        path2, label2 = self.imgs[index2]

        img1, img2 = self.loader(path1), self.loader(path2)
        if self.transform is not None:
            img1, img2 = self.transform(img1), self.transform(img2)

        return img1, img2, label1, label2

以上这篇pytorch ImageFolder的覆写实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。

以上是 pytorch ImageFolder的覆写实例 的全部内容, 来源链接: utcz.com/z/339195.html

回到顶部