From 6a20372fc5df98e3c2bdadacd5d1f102c6311f3e Mon Sep 17 00:00:00 2001
From: wogong
Date: Wed, 30 Oct 2019 09:52:01 +0800
Subject: [PATCH] update GTSRB dataloader, using ROI information.

---
 datasets/gtsrb.py         |  35 +++++++++---
 datasets/gtsrb_prepare.py | 117 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 145 insertions(+), 7 deletions(-)
 create mode 100644 datasets/gtsrb_prepare.py

diff --git a/datasets/gtsrb.py b/datasets/gtsrb.py
index 36c8ac0..b5003c4 100644
--- a/datasets/gtsrb.py
+++ b/datasets/gtsrb.py
@@ -1,4 +1,4 @@
-"""Dataset setting and data loader for GTSRB. Raw format and not use roi info.
+"""Dataset setting and data loader for GTSRB. Pickle format and use roi info.
 """
 
 import os
@@ -7,6 +7,30 @@ from torchvision import datasets, transforms
 import torch.utils.data as data
 from torch.utils.data.sampler import SubsetRandomSampler
 import numpy as np
+import pickle
+from PIL import Image
+
+class GTSRB(data.Dataset):  # dataset over a pickled {'images', 'labels'} dict
+    def __init__(self, filepath, transform=None):
+        with open(filepath,'rb') as f:
+            self.data = pickle.load(f)
+        self.keys = ['images', 'labels']
+        self.images = self.data[self.keys[0]]
+        self.labels = self.data[self.keys[1]]
+        self.transform = transform
+        self.n_data = len(self.labels)
+
+    def __getitem__(self, index):
+        image, label = self.images[index], self.labels[index]
+        image = Image.fromarray(np.uint8(image))
+        if self.transform is not None:
+            image = self.transform(image)
+        label = int(label)
+        return image, label
+
+    def __len__(self):
+        return self.n_data
+
 
 def get_gtsrb(dataset_root, batch_size, train):
     """Get GTSRB datasets loader."""
@@ -22,8 +46,7 @@ def get_gtsrb(dataset_root, batch_size, train):
     ])
 
     # datasets and data_loader
-    gtsrb_dataset = datasets.ImageFolder(
-        os.path.join(dataset_root, 'Final_Training', 'Images'), transform=pre_process)
+    gtsrb_dataset = GTSRB(os.path.join(dataset_root, 'gtsrb_train.p'), transform=pre_process)
 
     dataset_size = len(gtsrb_dataset)
     indices = list(range(dataset_size))
@@ -37,10 +60,8 @@ def get_gtsrb(dataset_root, batch_size, train):
     valid_sampler = SubsetRandomSampler(val_indices)
 
     if train:
-        gtsrb_dataloader_train = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size,
-                                                             sampler=train_sampler)
+        gtsrb_dataloader_train = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, sampler=train_sampler)
         return gtsrb_dataloader_train
     else:
-        gtsrb_dataloader_test = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size,
-                                                            sampler=valid_sampler)
+        gtsrb_dataloader_test = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, sampler=valid_sampler)
         return gtsrb_dataloader_test
\ No newline at end of file
diff --git a/datasets/gtsrb_prepare.py b/datasets/gtsrb_prepare.py
new file mode 100644
index 0000000..2059c8f
--- /dev/null
+++ b/datasets/gtsrb_prepare.py
@@ -0,0 +1,117 @@
+"""modified from https://github.com/haeusser/learning_by_association/blob/master/semisup/tools/gtsrb.py, thanks @haeusser"""
+
+from __future__ import division
+from __future__ import print_function
+
+import csv
+import pickle
+
+import matplotlib.pyplot as plt
+from PIL import Image
+import numpy as np
+
+DATADIR = '/home/wogong/datasets/gtsrb'
+
+NUM_LABELS = 43
+IMAGE_SHAPE = [40, 40, 3]
+
+
+def get_data(name):
+    """Load the 'train'/'unlabeled' or 'test' pickle from DATADIR (None for any other name)."""
+    if name in ['train', 'unlabeled']:
+        return read_gtsrb_pickle(DATADIR + '/gtsrb_train.p')
+    elif name == 'test':
+        return read_gtsrb_pickle(DATADIR + '/gtsrb_test.p')
+
+
+def read_gtsrb_pickle(filename):
+    """
+    Extract images and labels from a pickle file.
+    :param filename: path to a pickle holding a dict with 'images' and 'labels'
+    :return: (images, labels) as np.arrays; labels are converted to int if needed
+    """
+    with open(filename, mode='rb') as f:
+        data = pickle.load(f)
+    if not type(data['labels'][0]) == int:
+        labels = [int(x) for x in data['labels']]
+    else:
+        labels = data['labels']
+    return np.array(data['images']), np.array(labels)
+
+
+def preprocess_gtsrb(images, roi_boxes, resize_to):
+    """
+    Crop each image to its ROI box, then resize it with bilinear interpolation.
+    :param images: np.array of images
+    :param roi_boxes: np.array of region-of-interest boxes of the form
+    (left, upper, right, lower)
+    :param resize_to: target size passed to PIL ``Image.resize``
+    :return: np.array of the cropped and resized images
+    """
+    preprocessed_images = []
+    for idx, img in enumerate(images):
+        pil_img = Image.fromarray(img)
+        cropped_pil_img = pil_img.crop(roi_boxes[idx])
+        resized_pil_img = cropped_pil_img.resize(resize_to, Image.BILINEAR)
+        preprocessed_images.append(np.asarray(resized_pil_img))
+
+    return np.asarray(preprocessed_images)
+
+
+def load_and_append_image_class(prefix, gtFile, images, labels, roi_boxes):  # appends in place, then closes gtFile
+    gtReader = csv.reader(gtFile, delimiter=';')  # csv parser for annotations file
+    next(gtReader)  # skip header
+    # loop over all images in current annotations file
+    for row in gtReader:
+        images.append(
+            plt.imread(prefix + row[0]))  # the 1st column is the filename
+        roi_boxes.append(
+            (float(row[3]), float(row[4]), float(row[5]), float(row[6])))
+        labels.append(row[7])  # the 8th column is the label
+    gtFile.close()
+
+
+def preprocess_and_convert_gtsrb_to_pickle(rootpath, pickle_filename, type='train'):
+    """
+    Reads traffic sign data for German Traffic Sign Recognition Benchmark.
+    When loading the test dataset, make sure to have downloaded the EXTENDED
+    annotations including the class ids.
+    :param rootpath: path to the traffic sign data,
+    for example './GTSRB/Training'
+    :return: None; the images and labels are pickled to ``pickle_filename``
+    """
+    images = []  # images
+    labels = []  # corresponding labels
+    roi_boxes = []  # box coordinates for ROI (left, upper, right, lower)
+
+    if type == 'train':
+        # loop over all 43 classes (ids 0..42)
+        for c in range(0, NUM_LABELS):
+            prefix = rootpath + '/' + format(c, '05d') + '/'  # subdir for class
+            gtFile = open(
+                prefix + 'GT-' + format(c, '05d') + '.csv')  # annotations file
+
+            load_and_append_image_class(prefix, gtFile, images, labels,
+                                        roi_boxes)
+    elif type == 'test':
+        prefix = rootpath + '/'
+        gtFile = open(prefix + 'GT-final_test' + '.csv')  # annotations file
+        load_and_append_image_class(prefix, gtFile, images, labels, roi_boxes)
+    else:
+        raise ValueError(
+            'The data partition type you have provided is not valid.')
+
+    images = np.asarray(images)
+    labels = np.asarray(labels)
+    roi_boxes = np.asarray(roi_boxes)
+
+    preprocessed_images = preprocess_gtsrb(images, roi_boxes, resize_to=IMAGE_SHAPE[:-1])
+
+    with open(pickle_filename, "wb") as f:  # close the output file deterministically
+        pickle.dump({'images': preprocessed_images, 'labels': labels}, f)
+
+
+if __name__ == '__main__':
+    rootpath = DATADIR + '/Final_Training/Images'
+    pickle_filename = '/home/wogong/datasets/gtsrb/gtsrb_train.p'
+    preprocess_and_convert_gtsrb_to_pickle(rootpath, pickle_filename, type='train')
\ No newline at end of file