From aa7fcfc59dfce4cee81037e23e043d18f8c43d3a Mon Sep 17 00:00:00 2001
From: wogong
Date: Wed, 30 Oct 2019 20:56:34 +0800
Subject: [PATCH] add legacy GTSRB dataloader for comparison.

---
 datasets/gtsrb_legacy.py | 45 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 45 insertions(+)
 create mode 100644 datasets/gtsrb_legacy.py

diff --git a/datasets/gtsrb_legacy.py b/datasets/gtsrb_legacy.py
new file mode 100644
index 0000000..92fa9ba
--- /dev/null
+++ b/datasets/gtsrb_legacy.py
@@ -0,0 +1,45 @@
+"""Dataset setting and data loader for GTSRB."""
+
+import os
+import torch
+from torchvision import datasets, transforms
+import torch.utils.data as data
+from torch.utils.data.sampler import SubsetRandomSampler
+import numpy as np
+
+def get_gtsrb(dataset_root, batch_size, train):
+    """Get GTSRB datasets loader."""
+    shuffle_dataset = True
+    random_seed = 42
+    train_size = 31367
+
+    # image pre-processing
+    pre_process = transforms.Compose([
+        transforms.Resize((40, 40)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+    ])
+
+    # datasets and data_loader
+    gtsrb_dataset = datasets.ImageFolder(
+        os.path.join(dataset_root, 'Final_Training', 'Images'), transform=pre_process)
+
+    dataset_size = len(gtsrb_dataset)
+    indices = list(range(dataset_size))
+    if shuffle_dataset:
+        np.random.seed(random_seed)
+        np.random.shuffle(indices)
+    train_indices, val_indices = indices[:train_size], indices[train_size:]
+
+    # Creating PT data samplers and loaders:
+    train_sampler = SubsetRandomSampler(train_indices)
+    valid_sampler = SubsetRandomSampler(val_indices)
+
+    if train:
+        gtsrb_dataloader_train = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size,
+                                                             sampler=train_sampler)
+        return gtsrb_dataloader_train
+    else:
+        gtsrb_dataloader_test = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size,
+                                                            sampler=valid_sampler)
+        return gtsrb_dataloader_test
\ No newline at end of file
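
For context, a minimal usage sketch of the loader added above. This is not part of the patch; the dataset root './data/gtsrb' and batch size 128 are placeholder assumptions, and GTSRB must already be extracted so that 'Final_Training/Images' exists under that root.

# Usage sketch (assumption: GTSRB extracted under ./data/gtsrb/Final_Training/Images).
from datasets.gtsrb_legacy import get_gtsrb

train_loader = get_gtsrb(dataset_root='./data/gtsrb', batch_size=128, train=True)
eval_loader = get_gtsrb(dataset_root='./data/gtsrb', batch_size=128, train=False)

# The two loaders draw disjoint index subsets of Final_Training/Images:
# the first 31367 shuffled indices for training, the remainder for evaluation.
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([128, 3, 40, 40]) after the 40x40 resize
print(labels.shape)  # torch.Size([128])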