Browse Source

Update GTSRB dataloader to use ROI information.

master
wogong 5 years ago
parent
commit
6a20372fc5
  1. 35
      datasets/gtsrb.py
  2. 117
      datasets/gtsrb_prepare.py

35
datasets/gtsrb.py

@ -1,4 +1,4 @@
"""Dataset setting and data loader for GTSRB. Raw format and not use roi info.
"""Dataset setting and data loader for GTSRB. Pickle format and use roi info.
"""
import os
@ -7,6 +7,30 @@ from torchvision import datasets, transforms
import torch.utils.data as data
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
import pickle
from PIL import Image
class GTSRB(data.Dataset):
    """GTSRB dataset backed by a pre-processed pickle file.

    The pickle is expected to hold a dict with keys 'images' (array-like
    of H x W x C images) and 'labels' (sequence of class ids).
    """

    def __init__(self, filepath, transform=None):
        # Load the whole pickled dataset into memory up front.
        with open(filepath, 'rb') as handle:
            payload = pickle.load(handle)
        self.data = payload
        self.keys = ['images', 'labels']
        self.images = payload[self.keys[0]]
        self.labels = payload[self.keys[1]]
        self.transform = transform
        self.n_data = len(self.labels)

    def __getitem__(self, index):
        """Return (transformed PIL image, int label) for *index*."""
        raw_image = self.images[index]
        raw_label = self.labels[index]
        pil_image = Image.fromarray(np.uint8(raw_image))
        if self.transform is not None:
            pil_image = self.transform(pil_image)
        return pil_image, int(raw_label)

    def __len__(self):
        """Number of samples in the pickle."""
        return self.n_data
def get_gtsrb(dataset_root, batch_size, train):
"""Get GTSRB datasets loader."""
@ -22,8 +46,7 @@ def get_gtsrb(dataset_root, batch_size, train):
])
# datasets and data_loader
gtsrb_dataset = datasets.ImageFolder(
os.path.join(dataset_root, 'Final_Training', 'Images'), transform=pre_process)
gtsrb_dataset = GTSRB(os.path.join(dataset_root, 'gtsrb_train.p'), transform=pre_process)
dataset_size = len(gtsrb_dataset)
indices = list(range(dataset_size))
@ -37,10 +60,8 @@ def get_gtsrb(dataset_root, batch_size, train):
valid_sampler = SubsetRandomSampler(val_indices)
if train:
gtsrb_dataloader_train = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size,
sampler=train_sampler)
gtsrb_dataloader_train = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, sampler=train_sampler)
return gtsrb_dataloader_train
else:
gtsrb_dataloader_test = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size,
sampler=valid_sampler)
gtsrb_dataloader_test = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, sampler=valid_sampler)
return gtsrb_dataloader_test

117
datasets/gtsrb_prepare.py

@ -0,0 +1,117 @@
"""modified from https://github.com/haeusser/learning_by_association/blob/master/semisup/tools/gtsrb.py, thanks @haeusser"""
from __future__ import division
from __future__ import print_function
import csv
import pickle
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
DATADIR = '/home/wogong/datasets/gtsrb'
NUM_LABELS = 43
IMAGE_SHAPE = [40, 40, 3]
def get_data(name):
    """Utility for convenient data loading.

    :param name: partition name; 'train' and 'unlabeled' both map to the
        training pickle, 'test' to the test pickle.
    :return: (images, labels) numpy arrays from the matching pickle.
    :raises ValueError: if *name* is not a known partition.
    """
    if name in ['train', 'unlabeled']:
        return read_gtsrb_pickle(DATADIR + '/gtsrb_train.p')
    elif name == 'test':
        return read_gtsrb_pickle(DATADIR + '/gtsrb_test.p')
    # Previously fell through and returned None, deferring the failure to
    # the caller; fail loudly at the source instead.
    raise ValueError('Unknown data partition: %r' % (name,))
def read_gtsrb_pickle(filename):
    """
    Extract images and labels from a pickle file.

    :param filename: path to a pickle holding {'images': ..., 'labels': ...}
    :return: tuple (images, labels) as numpy arrays, labels coerced to int.
    """
    # Close the file deterministically once the payload is loaded.
    with open(filename, mode='rb') as f:
        data = pickle.load(f)
    # Labels may have been stored as strings (csv rows) or numpy integer
    # scalars; coerce every entry to a plain int.  This also handles an
    # empty label list, which used to raise IndexError on the
    # data['labels'][0] type probe.
    labels = [int(x) for x in data['labels']]
    return np.array(data['images']), np.array(labels)
def preprocess_gtsrb(images, roi_boxes, resize_to):
    """
    Crop each image to its region-of-interest box and resize it with
    bilinear interpolation.

    :param images: np.array (or sequence) of images
    :param roi_boxes: np.array of region-of-interest boxes of the form
        (left, upper, right, lower), indexed in step with *images*
    :param resize_to: target size passed to PIL resize
    :return: np.array of the cropped-and-resized images
    """
    # enumerate + explicit indexing keeps the original pairing semantics
    # (a too-short roi_boxes still raises IndexError).
    return np.asarray([
        np.asarray(
            Image.fromarray(img).crop(roi_boxes[idx]).resize(resize_to, Image.BILINEAR))
        for idx, img in enumerate(images)
    ])
def load_and_append_image_class(prefix, gtFile, images, labels, roi_boxes):
    """Parse one GTSRB annotation csv and append its rows in place.

    :param prefix: directory prefix prepended to each image filename
    :param gtFile: open file object of the ';'-separated annotations csv;
        it is closed before returning
    :param images: list extended with the loaded image arrays
    :param labels: list extended with the raw class-id column values
    :param roi_boxes: list extended with (left, upper, right, lower) floats
    """
    reader = csv.reader(gtFile, delimiter=';')
    next(reader)  # first row is the column header
    for row in reader:
        filename = row[0]               # 1st column: image filename
        left, upper, right, lower = row[3], row[4], row[5], row[6]
        class_id = row[7]               # 8th column: class label
        images.append(plt.imread(prefix + filename))
        roi_boxes.append((float(left), float(upper), float(right), float(lower)))
        labels.append(class_id)
    gtFile.close()
def preprocess_and_convert_gtsrb_to_pickle(rootpath, pickle_filename, type='train'):
    """
    Read traffic sign data for the German Traffic Sign Recognition
    Benchmark, crop it to the ROI boxes, resize, and dump to a pickle.

    When loading the test dataset, make sure to have downloaded the EXTENDED
    annotations including the class ids.

    :param rootpath: path to the traffic sign data,
        for example './GTSRB/Training'
    :param pickle_filename: destination path of the resulting pickle
    :param type: data partition, 'train' or 'test' (parameter name kept for
        backward compatibility even though it shadows the builtin)
    :raises ValueError: for any other partition name
    """
    images = []     # loaded image arrays
    labels = []     # corresponding class-id labels
    roi_boxes = []  # box coordinates for ROI (left, upper, right, lower)
    if type == 'train':
        # Loop over all NUM_LABELS (43) class subdirectories.
        for c in range(0, NUM_LABELS):
            prefix = rootpath + '/' + format(c, '05d') + '/'  # subdir for class
            gtFile = open(prefix + 'GT-' + format(c, '05d') + '.csv')  # annotations file
            # The helper closes gtFile when done.
            load_and_append_image_class(prefix, gtFile, images, labels, roi_boxes)
    elif type == 'test':
        prefix = rootpath + '/'
        gtFile = open(prefix + 'GT-final_test' + '.csv')  # annotations file
        load_and_append_image_class(prefix, gtFile, images, labels, roi_boxes)
    else:
        raise ValueError(
            'The data partition type you have provided is not valid.')
    images = np.asarray(images)
    labels = np.asarray(labels)
    roi_boxes = np.asarray(roi_boxes)
    preprocessed_images = preprocess_gtsrb(images, roi_boxes, resize_to=IMAGE_SHAPE[:-1])
    # Context manager guarantees the output file is flushed and closed;
    # the original passed a bare open(..., "wb") handle to pickle.dump and
    # leaked it.
    with open(pickle_filename, 'wb') as out:
        pickle.dump({'images': preprocessed_images, 'labels': labels}, out)
if __name__ == '__main__':
    # Build the training pickle from the raw GTSRB training images.
    train_images_root = DATADIR + '/Final_Training/Images'
    # Same destination as before, now derived from DATADIR.
    output_pickle = DATADIR + '/gtsrb_train.p'
    preprocess_and_convert_gtsrb_to_pickle(train_images_root, output_pickle, type='train')
Loading…
Cancel
Save