Browse Source

Update GTSRB dataloader to use ROI information.

master
wogong 5 years ago
parent
commit
6a20372fc5
  1. 35
      datasets/gtsrb.py
  2. 117
      datasets/gtsrb_prepare.py

35
datasets/gtsrb.py

@ -1,4 +1,4 @@
"""Dataset setting and data loader for GTSRB. Raw format and not use roi info.
"""Dataset setting and data loader for GTSRB. Pickle format and use roi info.
""" """
import os import os
@ -7,6 +7,30 @@ from torchvision import datasets, transforms
import torch.utils.data as data import torch.utils.data as data
from torch.utils.data.sampler import SubsetRandomSampler from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np import numpy as np
import pickle
from PIL import Image
class GTSRB(data.Dataset):
    """GTSRB dataset backed by a preprocessed pickle file.

    The pickle is expected to hold a dict with two entries:
    ``'images'`` (an indexable collection of image arrays, already
    ROI-cropped and resized by ``gtsrb_prepare.py``) and ``'labels'``
    (a matching collection of class labels, possibly as strings).
    """

    def __init__(self, filepath, transform=None):
        """Load the pickled images/labels and remember the transform.

        :param filepath: path to the pickle file produced by gtsrb_prepare
        :param transform: optional torchvision-style transform applied
            to each PIL image in __getitem__
        """
        with open(filepath, 'rb') as f:
            self.data = pickle.load(f)
        # kept for backward compatibility with code that inspects it
        self.keys = ['images', 'labels']
        self.images = self.data['images']
        self.labels = self.data['labels']
        self.transform = transform
        self.n_data = len(self.labels)

    def __getitem__(self, index):
        """Return the (transformed image, int label) pair at ``index``."""
        image, label = self.images[index], self.labels[index]
        # convert the raw array to PIL so torchvision transforms apply
        image = Image.fromarray(np.uint8(image))
        if self.transform is not None:
            image = self.transform(image)
        # labels may be stored as strings in the pickle
        label = int(label)
        return image, label

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.n_data
def get_gtsrb(dataset_root, batch_size, train): def get_gtsrb(dataset_root, batch_size, train):
"""Get GTSRB datasets loader.""" """Get GTSRB datasets loader."""
@ -22,8 +46,7 @@ def get_gtsrb(dataset_root, batch_size, train):
]) ])
# datasets and data_loader # datasets and data_loader
gtsrb_dataset = datasets.ImageFolder(
os.path.join(dataset_root, 'Final_Training', 'Images'), transform=pre_process)
gtsrb_dataset = GTSRB(os.path.join(dataset_root, 'gtsrb_train.p'), transform=pre_process)
dataset_size = len(gtsrb_dataset) dataset_size = len(gtsrb_dataset)
indices = list(range(dataset_size)) indices = list(range(dataset_size))
@ -37,10 +60,8 @@ def get_gtsrb(dataset_root, batch_size, train):
valid_sampler = SubsetRandomSampler(val_indices) valid_sampler = SubsetRandomSampler(val_indices)
if train: if train:
gtsrb_dataloader_train = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size,
sampler=train_sampler)
gtsrb_dataloader_train = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, sampler=train_sampler)
return gtsrb_dataloader_train return gtsrb_dataloader_train
else: else:
gtsrb_dataloader_test = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size,
sampler=valid_sampler)
gtsrb_dataloader_test = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, sampler=valid_sampler)
return gtsrb_dataloader_test return gtsrb_dataloader_test

117
datasets/gtsrb_prepare.py

@ -0,0 +1,117 @@
"""modified from https://github.com/haeusser/learning_by_association/blob/master/semisup/tools/gtsrb.py, thanks @haeusser"""
from __future__ import division
from __future__ import print_function
import csv
import pickle
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
DATADIR = '/home/wogong/datasets/gtsrb'
NUM_LABELS = 43
IMAGE_SHAPE = [40, 40, 3]
def get_data(name):
    """Utility for convenient data loading.

    :param name: 'train' or 'unlabeled' (both map to the training pickle),
        or 'test'
    :return: (images, labels) tuple as returned by read_gtsrb_pickle
    :raises ValueError: if ``name`` is not a recognised partition
        (the previous behavior was to silently return None)
    """
    if name in ['train', 'unlabeled']:
        return read_gtsrb_pickle(DATADIR + '/gtsrb_train.p')
    elif name == 'test':
        return read_gtsrb_pickle(DATADIR + '/gtsrb_test.p')
    raise ValueError('Unknown data partition: %r' % (name,))
def read_gtsrb_pickle(filename):
    """
    Extract images and labels from a pickle file.

    :param filename: path to a pickle holding {'images': ..., 'labels': ...}
    :return: (np.array of images, np.array of int labels)
    """
    with open(filename, mode='rb') as f:
        data = pickle.load(f)
    # Labels may be stored as strings; int() is a no-op on ints, so convert
    # unconditionally. This also avoids the original IndexError on an empty
    # label list from peeking at labels[0], and the non-idiomatic
    # `type(x) == int` check.
    labels = [int(x) for x in data['labels']]
    return np.array(data['images']), np.array(labels)
def preprocess_gtsrb(images, roi_boxes, resize_to):
    """
    Crop each image to its region-of-interest box and resize it with
    bilinear interpolation.

    :param images: np.array of images
    :param roi_boxes: np.array of region-of-interest boxes of the form
                      (left, upper, right, lower), one per image
    :param resize_to: target size passed to PIL's resize
    :return: np.array of the cropped and resized images
    """
    return np.asarray([
        np.asarray(
            Image.fromarray(img)
            .crop(roi_boxes[idx])
            .resize(resize_to, Image.BILINEAR))
        for idx, img in enumerate(images)
    ])
def load_and_append_image_class(prefix, gtFile, images, labels, roi_boxes):
    """Parse one GTSRB annotation CSV and append its contents in place.

    Appends one entry per annotation row to each of the three output
    lists (``images``, ``labels``, ``roi_boxes``), then closes ``gtFile``.

    :param prefix: directory prefix joined with each row's image filename
    :param gtFile: open file object for the semicolon-delimited CSV
    :param images: output list receiving loaded image arrays
    :param labels: output list receiving the label column (as strings)
    :param roi_boxes: output list receiving (left, upper, right, lower) floats
    """
    reader = csv.reader(gtFile, delimiter=';')
    next(reader)  # the first row is the column header
    for row in reader:
        # column 0: image filename; columns 3-6: ROI box; column 7: label
        images.append(plt.imread(prefix + row[0]))
        roi_boxes.append(tuple(float(coord) for coord in row[3:7]))
        labels.append(row[7])
    gtFile.close()
def preprocess_and_convert_gtsrb_to_pickle(rootpath, pickle_filename, type='train'):
    """
    Read traffic sign data for the German Traffic Sign Recognition
    Benchmark, crop each image to its ROI box, resize, and dump the result
    as {'images': ..., 'labels': ...} to a pickle file.

    When loading the test dataset, make sure to have downloaded the EXTENDED
    annotations including the class ids.

    :param rootpath: path to the traffic sign data,
        for example './GTSRB/Training'
    :param pickle_filename: destination path for the pickle file
    :param type: data partition, 'train' or 'test'. NOTE: the name shadows
        the builtin ``type`` but is kept for backward compatibility with
        callers passing it as a keyword argument.
    :raises ValueError: if ``type`` is neither 'train' nor 'test'
    """
    images = []     # loaded image arrays
    labels = []     # corresponding labels
    roi_boxes = []  # box coordinates for ROI (left, upper, right, lower)

    if type == 'train':
        # loop over all NUM_LABELS (43) classes, one subdir per class
        for c in range(0, NUM_LABELS):
            prefix = rootpath + '/' + format(c, '05d') + '/'  # subdir for class
            gtFile = open(prefix + 'GT-' + format(c, '05d') + '.csv')  # annotations file
            # load_and_append_image_class closes gtFile when done
            load_and_append_image_class(prefix, gtFile, images, labels,
                                        roi_boxes)
    elif type == 'test':
        prefix = rootpath + '/'
        gtFile = open(prefix + 'GT-final_test' + '.csv')  # annotations file
        load_and_append_image_class(prefix, gtFile, images, labels, roi_boxes)
    else:
        raise ValueError(
            'The data partition type you have provided is not valid.')

    images = np.asarray(images)
    labels = np.asarray(labels)
    roi_boxes = np.asarray(roi_boxes)
    # crop to ROI and resize to the spatial part of IMAGE_SHAPE
    preprocessed_images = preprocess_gtsrb(images, roi_boxes, resize_to=IMAGE_SHAPE[:-1])
    # use a context manager so the output file handle is always closed
    # (the original pickle.dump(..., open(...)) leaked the handle)
    with open(pickle_filename, 'wb') as f:
        pickle.dump({'images': preprocessed_images, 'labels': labels}, f)
if __name__ == '__main__':
    # Build the training pickle: ROI-cropped, resized GTSRB training images.
    train_images_root = DATADIR + '/Final_Training/Images'
    train_pickle_path = '/home/wogong/datasets/gtsrb/gtsrb_train.p'
    preprocess_and_convert_gtsrb_to_pickle(train_images_root, train_pickle_path, type='train')
Loading…
Cancel
Save