A PyTorch implementation of the paper "Unsupervised Domain Adaptation by Backpropagation".
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 

40 lines
1.4 KiB

"""Dataset setting and data loader for SVHN."""
import torch
from torchvision import datasets, transforms
import os
import params
def get_svhn(train):
    """Get the SVHN data loader.

    Each image is converted to single-channel grayscale, resized to
    ``params.digit_image_size``, converted to a tensor and normalized with
    mean 0.5 / std 0.5 (mapping ToTensor's [0, 1] range to [-1, 1]).

    Args:
        train: If truthy, load the ``'train'`` split; otherwise load the
            ``'test'`` split. The dataset is downloaded under
            ``<params.dataset_root>/svhn`` if it is not already present.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected split using
        ``params.batch_size``, with shuffling enabled and the final
        incomplete batch dropped.
    """
    # Grayscale() defaults to num_output_channels=1, so Normalize must get
    # one-element mean/std. Three-element stats cannot be broadcast in place
    # onto a 1-channel tensor and raise a RuntimeError in newer torchvision.
    pre_process = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize(params.digit_image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ])

    # Both splits share every argument except the split name.
    split = 'train' if train else 'test'
    svhn_dataset = datasets.SVHN(
        root=os.path.join(params.dataset_root, 'svhn'),
        split=split,
        transform=pre_process,
        download=True,
    )

    # NOTE(review): shuffle/drop_last are applied to the test split as well;
    # drop_last=True means evaluation skips the last partial batch of test
    # images — confirm this is intended by the evaluation code.
    svhn_data_loader = torch.utils.data.DataLoader(
        dataset=svhn_dataset,
        batch_size=params.batch_size,
        shuffle=True,
        drop_last=True,
    )
    return svhn_data_loader