A PyTorch implementation of the paper *Unsupervised Domain Adaptation by Backpropagation*.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

40 lines
1.3 KiB

7 years ago
"""Dataset setting and data loader for SVHN."""
import torch
from torchvision import datasets, transforms
import os
import params
def get_svhn(train, *, shuffle=True, drop_last=True):
    """Build a DataLoader over the SVHN dataset.

    Each image is converted to grayscale, resized to
    ``params.digit_image_size``, converted to a tensor and normalized with
    ``params.dataset_mean`` / ``params.dataset_std``. Dataset files live under
    ``<params.dataset_root>/svhn`` and are downloaded if not already present.

    Args:
        train: If truthy, load the ``'train'`` split; otherwise the ``'test'``
            split.
        shuffle: Whether the loader reshuffles the samples every epoch.
            Defaults to ``True`` for both splits (the historical behavior).
        drop_last: Whether to drop the final incomplete batch. Defaults to
            ``True`` (the historical behavior), which keeps every batch at
            exactly ``params.batch_size`` samples. Pass ``False`` when
            evaluating so that no test samples are skipped.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected SVHN split with
        batch size ``params.batch_size``.
    """
    # Grayscale() defaults to a single output channel.
    # NOTE(review): assumes params.dataset_mean / params.dataset_std hold
    # single-channel statistics -- confirm in params.
    pre_process = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize(params.digit_image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std),
    ])

    # The train and test datasets differ only in the split name, so build the
    # dataset once instead of duplicating the constructor call per branch.
    split = 'train' if train else 'test'
    svhn_dataset = datasets.SVHN(
        root=os.path.join(params.dataset_root, 'svhn'),
        split=split,
        transform=pre_process,
        download=True,
    )

    return torch.utils.data.DataLoader(
        dataset=svhn_dataset,
        batch_size=params.batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
    )