A PyTorch implementation of the paper "Unsupervised Domain Adaptation by Backpropagation"

"""Dataset setting and data loader for GTSRB."""
import os
import torch
from torchvision import datasets, transforms
import torch.utils.data as data
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
def get_gtsrb(dataset_root, batch_size, train):
"""Get GTSRB datasets loader."""
shuffle_dataset = True
random_seed = 42
train_size = 31367
# image pre-processing
pre_process = transforms.Compose([
transforms.Resize((48, 48)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
# datasets and data_loader
gtsrb_dataset = datasets.ImageFolder(
os.path.join(dataset_root, 'Final_Training', 'Images'), transform=pre_process)
dataset_size = len(gtsrb_dataset)
indices = list(range(dataset_size))
if shuffle_dataset:
np.random.seed(random_seed)
np.random.shuffle(indices)
train_indices, val_indices = indices[:train_size], indices[train_size:]
# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
if train:
gtsrb_dataloader_train = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size,
sampler=train_sampler)
return gtsrb_dataloader_train
else:
gtsrb_dataloader_test = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size,
sampler=valid_sampler)
return gtsrb_dataloader_test
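
For reference, a minimal usage sketch. The dataset root 'data/gtsrb' and the batch size are illustrative assumptions; the path just needs to contain the extracted 'Final_Training/Images' directory of the GTSRB archive.

# Minimal usage sketch (hypothetical path and batch size): fetch one batch
# from each split to check shapes after the 48x48 resize.
train_loader = get_gtsrb('data/gtsrb', batch_size=128, train=True)
val_loader = get_gtsrb('data/gtsrb', batch_size=128, train=False)

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([128, 3, 48, 48])
print(labels.shape)  # torch.Size([128])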