A PyTorch implementation of the paper "Unsupervised Domain Adaptation by Backpropagation"
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

67 lines
2.2 KiB

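"""Dataset setting and data loader for GTSRB, read from a pickle file.

Only the 'images' and 'labels' entries of the pickle are used here; any
ROI-based cropping is presumably applied when the pickle is built (it is not
visible in this module).
"""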
import os
import torch
from torchvision import transforms
import torch.utils.data as data
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
import pickle
from PIL import Image
class GTSRB(data.Dataset):
    """GTSRB samples read from a pickled mapping with 'images' and 'labels'.

    Args:
        filepath: Path to a pickle file whose top-level object supports
            ``obj['images']`` and ``obj['labels']``.
        transform: Optional callable applied to each ``PIL.Image`` before it
            is returned from ``__getitem__``.
    """

    def __init__(self, filepath, transform=None):
        # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
        with open(filepath, 'rb') as handle:
            self.data = pickle.load(handle)
        self.keys = ['images', 'labels']
        image_key, label_key = self.keys
        self.images = self.data[image_key]
        self.labels = self.data[label_key]
        self.transform = transform
        # The dataset length is taken from the labels only; images and labels
        # are assumed (not checked) to have matching lengths.
        self.n_data = len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``.

        The raw image is cast to ``uint8`` and wrapped as a ``PIL.Image``;
        the transform (if any) is applied to it, and the label is cast to
        ``int``.
        """
        raw_image = self.images[index]
        raw_label = self.labels[index]
        sample = Image.fromarray(np.uint8(raw_image))
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, int(raw_label)

    def __len__(self):
        """Number of samples, i.e. the length of ``labels``."""
        return self.n_data
def get_gtsrb(dataset_root, batch_size, train, *, train_size=31367, random_seed=42):
    """Get a GTSRB DataLoader for either the training or the held-out split.

    Both splits are carved out of the same ``gtsrb_train.p`` pickle. The
    (shuffled) index list is cut at ``train_size``: the first part backs the
    training loader and the remainder backs the test/validation loader.

    Args:
        dataset_root: Directory containing ``gtsrb_train.p``.
        batch_size: Batch size of the returned DataLoader.
        train: If truthy, return the loader over the training split;
            otherwise return the loader over the held-out split.
        train_size: Number of indices assigned to the training split
            (default 31367, the value previously hard-coded). If it exceeds
            the dataset size, the held-out split is empty.
        random_seed: Seed for the index permutation. Because the seed is
            fixed, separate calls with ``train=True`` and ``train=False``
            see the same permutation and therefore get disjoint splits.

    Returns:
        A ``torch.utils.data.DataLoader`` that draws the chosen split in
        random order through a ``SubsetRandomSampler``.
    """
    shuffle_dataset = True
    # image pre-processing: resize to 40x40, convert to a tensor, then
    # normalize each of the 3 channels with mean 0.5 / std 0.5.
    pre_process = transforms.Compose([
        transforms.Resize((40, 40)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    # datasets and data_loader
    gtsrb_dataset = GTSRB(os.path.join(dataset_root, 'gtsrb_train.p'), transform=pre_process)
    indices = list(range(len(gtsrb_dataset)))
    if shuffle_dataset:
        # Use a private, seeded generator. The previous np.random.seed() call
        # reseeded from OS entropy, so the train and test loaders (built by
        # separate calls) got different permutations and overlapping samples.
        # A local RandomState also leaves NumPy's global RNG state untouched.
        np.random.RandomState(random_seed).shuffle(indices)
    train_indices, val_indices = indices[:train_size], indices[train_size:]

    # Only the sampler for the requested split is needed.
    sampler = SubsetRandomSampler(train_indices if train else val_indices)
    return torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, sampler=sampler)