wogong
5 years ago
1 changed files with 45 additions and 0 deletions
@ -0,0 +1,45 @@ |
|||||
|
"""Dataset setting and data loader for GTSRB.""" |
||||
|
|
||||
|
import os |
||||
|
import torch |
||||
|
from torchvision import datasets, transforms |
||||
|
import torch.utils.data as data |
||||
|
from torch.utils.data.sampler import SubsetRandomSampler |
||||
|
import numpy as np |
||||
|
|
||||
|
def get_gtsrb(dataset_root, batch_size, train): |
||||
|
"""Get GTSRB datasets loader.""" |
||||
|
shuffle_dataset = True |
||||
|
random_seed = 42 |
||||
|
train_size = 31367 |
||||
|
|
||||
|
# image pre-processing |
||||
|
pre_process = transforms.Compose([ |
||||
|
transforms.Resize((40, 40)), |
||||
|
transforms.ToTensor(), |
||||
|
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) |
||||
|
]) |
||||
|
|
||||
|
# datasets and data_loader |
||||
|
gtsrb_dataset = datasets.ImageFolder( |
||||
|
os.path.join(dataset_root, 'Final_Training', 'Images'), transform=pre_process) |
||||
|
|
||||
|
dataset_size = len(gtsrb_dataset) |
||||
|
indices = list(range(dataset_size)) |
||||
|
if shuffle_dataset: |
||||
|
np.random.seed(random_seed) |
||||
|
np.random.shuffle(indices) |
||||
|
train_indices, val_indices = indices[:train_size], indices[train_size:] |
||||
|
|
||||
|
# Creating PT data samplers and loaders: |
||||
|
train_sampler = SubsetRandomSampler(train_indices) |
||||
|
valid_sampler = SubsetRandomSampler(val_indices) |
||||
|
|
||||
|
if train: |
||||
|
gtsrb_dataloader_train = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, |
||||
|
sampler=train_sampler) |
||||
|
return gtsrb_dataloader_train |
||||
|
else: |
||||
|
gtsrb_dataloader_test = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, |
||||
|
sampler=valid_sampler) |
||||
|
return gtsrb_dataloader_test |
Loading…
Reference in new issue