pytorch制作自己的LMDB数据操作示例

本文实例讲述了pytorch制作自己的LMDB数据操作。分享给大家供大家参考,具体如下:

前言

记录下pytorch里如何使用lmdb的code,自用

制作部分的Code

code就是ASTER里数据制作部分的代码改了点。aster_train.txt里面就是图片的完整路径,每行一个;图片同目录下有同名的txt,里面记着该jpg的标签。

import os

import lmdb # install lmdb by "pip install lmdb"

import cv2

import numpy as np

from tqdm import tqdm

import six

from PIL import Image

import scipy.io as sio

from tqdm import tqdm

import re

def checkImageIsValid(imageBin):
    """Return True when ``imageBin`` decodes to a non-empty image.

    Args:
        imageBin: Raw encoded image bytes (for example the content of a
            JPEG file read in binary mode), or ``None``.

    Returns:
        ``False`` if ``imageBin`` is ``None`` or empty, if OpenCV cannot
        decode it, or if the decoded image has zero height or width;
        ``True`` otherwise.
    """
    # Nothing to decode: treat missing or empty data as invalid up front.
    if imageBin is None or len(imageBin) == 0:
        return False

    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)

    # imdecode yields None when the buffer is not a decodable image; without
    # this guard the shape lookup below would raise AttributeError.
    if img is None:
        return False

    imgH, imgW = img.shape[0], img.shape[1]
    return imgH * imgW != 0

def writeCache(env, cache):
    """Write every entry of ``cache`` to ``env`` inside one write transaction.

    Args:
        env: An opened LMDB environment; ``env.begin(write=True)`` must
            return a transaction usable as a context manager with a
            ``put(key, value)`` method.
        cache: Mapping of key to value. Keys and values may each be ``str``
            or ``bytes``; ``str`` items are encoded with ``str.encode()``
            before being stored, ``bytes`` items are stored unchanged.
    """
    def to_bytes(item):
        # The transaction is fed bytes only, so text is encoded here rather
        # than failing inside put().
        return item.encode() if isinstance(item, str) else item

    with env.begin(write=True) as txn:
        for k, v in cache.items():
            txn.put(to_bytes(k), to_bytes(v))

def _is_difficult(word):

assert isinstance(word, str)

return not re.match('^[\w]+$', word)

def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """Create an LMDB dataset for CRNN training.

    Samples are numbered from 1 in the order they are kept. Sample ``n`` is
    stored under ``image-%09d`` (raw file bytes), ``label-%09d`` (encoded
    label text) and, when ``lexiconList`` is given, ``lexicon-%09d`` (the
    space-joined lexicon, encoded). The number of kept samples is stored
    under ``num-samples``. Samples with an empty label, a missing image file
    or (when ``checkValid`` is true) an invalid image are skipped.

    Args:
        outputPath: LMDB output path.
        imagePathList: List of image paths.
        labelList: List of the corresponding ground-truth texts.
        lexiconList: Optional list of lexicon lists, indexed like
            ``imagePathList``.
        checkValid: If true, check the validity of every image.

    Raises:
        AssertionError: If ``imagePathList`` and ``labelList`` differ in
            length.
    """
    assert len(imagePathList) == len(labelList)
    nSamples = len(imagePathList)

    # map_size is given in bytes: 1099511627776 B == 1 TiB.
    env = lmdb.open(outputPath, map_size=1099511627776)
    try:
        cache = {}
        cnt = 1
        for i, (imagePath, label) in enumerate(zip(imagePathList, labelList)):
            if len(label) == 0:
                continue
            if not os.path.exists(imagePath):
                print('%s does not exist' % imagePath)
                continue

            with open(imagePath, 'rb') as f:
                imageBin = f.read()
            if checkValid and not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue

            # Every value stored in the database is binary data.
            imageKey = 'image-%09d' % cnt  # zero-padded to nine digits
            labelKey = 'label-%09d' % cnt
            cache[imageKey] = imageBin
            cache[labelKey] = label.encode()
            if lexiconList:
                lexiconKey = 'lexicon-%09d' % cnt
                # Encode like the label so that only bytes reach the database.
                cache[lexiconKey] = ' '.join(lexiconList[i]).encode()

            # Flush in batches of 1000 samples to bound memory use.
            if cnt % 1000 == 0:
                writeCache(env, cache)
                cache = {}
                print('Written %d / %d' % (cnt, nSamples))
            cnt += 1

        nSamples = cnt - 1
        cache['num-samples'] = str(nSamples).encode()
        writeCache(env, cache)
    finally:
        # Release the environment even when writing fails part-way.
        env.close()
    print('Created dataset with %d samples' % nSamples)

def get_sample_list(txt_path: str, encoding=None):
    """Collect image paths and their labels from a list file.

    ``txt_path`` holds one image path per line. The label of an image is the
    first line of the text file that has the same path with the extension
    replaced by ``.txt``. Images without such a label file, and blank lines,
    are skipped.

    Args:
        txt_path: Path of the list file.
        encoding: Text encoding used for the list file and the label files;
            ``None`` (the default) keeps Python's default encoding.

    Returns:
        A tuple ``(jpg_list, txt_content_list)`` of equal-length lists: the
        kept image paths and their stripped label strings.

    Raises:
        UnicodeDecodeError: If a label file cannot be decoded; the offending
            label path is printed before the error propagates.
    """
    def label_path_for(image_path):
        # Swap only the final extension; str.replace would also rewrite a
        # '.jpg' occurring elsewhere in the path and ignore other extensions.
        return os.path.splitext(image_path)[0] + '.txt'

    with open(txt_path, 'r', encoding=encoding) as list_file:
        candidates = (line.strip() for line in list_file)
        jpg_list = [
            path for path in candidates
            if path and os.path.exists(label_path_for(path))
        ]

    txt_content_list = []
    for jpg in jpg_list:
        label_path = label_path_for(jpg)
        with open(label_path, 'r', encoding=encoding) as label_file:
            try:
                first_line = label_file.readline()
            except UnicodeDecodeError:
                # Report which label file is broken, then propagate.
                print(label_path)
                raise
        txt_content_list.append(first_line.strip())
    return jpg_list, txt_content_list

if __name__ == "__main__":
    # List file with one full image path per line; each image has a
    # same-named .txt file next to it that holds its label.
    txt_path = '/home/gpu-server/disk/disk1/NumberData/8NumberSample/aster_train.txt'
    # Directory in which the LMDB database is created.
    lmdb_output_path = '/home/gpu-server/project/aster/dataset/train'

    imagePathList, labelList = get_sample_list(txt_path)
    createDataset(
        outputPath=lmdb_output_path,
        imagePathList=imagePathList,
        labelList=labelList,
    )

读取部分

这里用的pytorch的dataloader,简单记录一下,人比较懒,代码就直接抄过来,不整理拆分了,重点看__getitem__

from __future__ import absolute_import

# import sys

# sys.path.append('./')

import os

# import moxing as mox

import pickle

from tqdm import tqdm

from PIL import Image, ImageFile

import numpy as np

import random

import cv2

import lmdb

import sys

import six

import torch

from torch.utils import data

from torch.utils.data import sampler

from torchvision import transforms

from lib.utils.labelmaps import get_vocabulary, labels2strs

from lib.utils import to_numpy

# Ask PIL to load truncated image files.
ImageFile.LOAD_TRUNCATED_IMAGES = True

from config import get_args

# The project's get_args receives the command-line arguments (program name
# excluded) once, when this module is imported.
global_args = get_args(sys.argv[1:])

# moxing is a distributed framework that is imported only for remote runs;
# it can be ignored otherwise.
if global_args.run_on_remote:
    import moxing as mox

class LmdbDataset(data.Dataset):
    """Text-recognition dataset that reads its samples from an LMDB database.

    The database must hold a ``num-samples`` entry and, for every 1-based
    sample number ``n``, an ``image-%09d`` entry (encoded image bytes) and a
    ``label-%09d`` entry (encoded label text). Each item is a tuple
    ``(img, label, label_len)``.
    """

    def __init__(self, root, voc_type, max_len, num_samples, transform=None):
        """Open the database and build the vocabulary tables.

        Args:
            root: Path of the LMDB database. When
                ``global_args.run_on_remote`` is set, it is first copied to
                ``/cache/<basename of root>`` and opened from there.
            voc_type: One of ``'LOWERCASE'``, ``'ALLCASES'``,
                ``'ALLCASES_SYMBOLS'`` or ``'DIGITS'``.
            max_len: Length of every label array, end-of-sequence token
                included.
            num_samples: Upper bound on the number of samples exposed.
            transform: Optional callable applied to each loaded image.

        Raises:
            ValueError: In remote mode, if ``root`` does not exist.
            AssertionError: If the environment cannot be opened or
                ``voc_type`` is not one of the supported names.
        """
        super(LmdbDataset, self).__init__()

        if global_args.run_on_remote:
            dataset_name = os.path.basename(root)
            data_cache_url = "/cache/%s" % dataset_name
            os.makedirs(data_cache_url, exist_ok=True)
            if mox.file.exists(root):
                mox.file.copy_parallel(root, data_cache_url)
            else:
                raise ValueError("%s not exists!" % root)
            self.env = lmdb.open(data_cache_url, max_readers=32, readonly=True)
        else:
            self.env = lmdb.open(root, max_readers=32, readonly=True)

        assert self.env is not None, "cannot create lmdb from %s" % root
        self.txn = self.env.begin()

        self.voc_type = voc_type
        self.transform = transform
        self.max_len = max_len
        # Never expose more samples than the database reports.
        self.nSamples = int(self.txn.get(b"num-samples"))
        self.nSamples = min(self.nSamples, num_samples)

        assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS', 'DIGITS']
        self.EOS = 'EOS'
        self.PADDING = 'PADDING'
        self.UNKNOWN = 'UNKNOWN'
        self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN)
        self.char2id = dict(zip(self.voc, range(len(self.voc))))
        self.id2char = dict(zip(range(len(self.voc)), self.voc))

        self.rec_num_classes = len(self.voc)
        self.lowercase = (voc_type == 'LOWERCASE')

    def __len__(self):
        """Return the number of samples exposed by the dataset."""
        return self.nSamples

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: 0-based sample index, ``0 <= index < len(self)``.

        Returns:
            A tuple ``(img, label, label_len)``: the RGB image (after
            ``transform`` when one is set), an integer array of length
            ``max_len`` holding the character ids followed by the EOS id and
            padded with the PADDING id, and the number of ids including EOS.
            If the image cannot be read, the next sample (wrapping around at
            the end) is returned instead.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``.
            AssertionError: If the label plus EOS does not fit ``max_len``.
        """
        # Valid indices are 0 .. len-1; IndexError also lets plain iteration
        # over the dataset terminate.
        if not 0 <= index < len(self):
            raise IndexError('index range error')
        # Database keys are numbered from 1.
        index += 1
        img_key = b'image-%09d' % index
        imgbuf = self.txn.get(img_key)

        # Image.open needs a file-like object, so wrap the raw bytes.
        buf = six.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        try:
            img = Image.open(buf).convert('RGB')
        except IOError:
            print('Corrupted image for %d' % index)
            # ``index`` is already the 0-based position of the next sample;
            # the modulo wraps around instead of running past the end.
            return self[index % len(self)]

        # Recognition label.
        label_key = b'label-%09d' % index
        word = self.txn.get(label_key).decode()
        if self.lowercase:
            word = word.lower()

        # Start from an array filled with the padding token. The builtin int
        # is what the removed np.int alias stood for.
        label = np.full((self.max_len,), self.char2id[self.PADDING], dtype=int)
        label_list = []
        for char in word:
            if char in self.char2id:
                label_list.append(self.char2id[char])
            else:
                # Map characters outside the vocabulary to the unknown token.
                print('{0} is out of vocabulary.'.format(char))
                label_list.append(self.char2id[self.UNKNOWN])
        # Terminate the sequence with the stop token.
        label_list = label_list + [self.char2id[self.EOS]]
        assert len(label_list) <= self.max_len
        label[:len(label_list)] = np.array(label_list)

        # Label length, stop token included.
        label_len = len(label_list)

        if self.transform is not None:
            img = self.transform(img)
        return img, label, label_len

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

以上是 pytorch制作自己的LMDB数据操作示例 的全部内容, 来源链接: utcz.com/z/353773.html

回到顶部