A PyTorch implementation of the paper "Unsupervised Domain Adaptation by Backpropagation".
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

25 lines
796 B

"""Dataset setting and data loader for Office."""
6 years ago
import os
import torch
from torchvision import datasets, transforms
import torch.utils.data as data
def get_office(dataset_root, batch_size, category, *, shuffle=True,
               num_workers=0, image_size=227):
    """Get Office datasets loader.

    Builds a ``DataLoader`` over the images found under
    ``<dataset_root>/office31/<category>/images``, which is read with
    ``torchvision.datasets.ImageFolder`` (class labels come from the
    sub-directory names).

    Args:
        dataset_root: Directory that contains the ``office31`` folder.
        batch_size: Number of images per batch.
        category: Name of the domain sub-folder inside ``office31``
            (presumably ``'amazon'``, ``'webcam'`` or ``'dslr'`` -- confirm
            against the dataset on disk).
        shuffle: Reshuffle the samples every epoch. Defaults to ``True``
            (the previous hard-coded behaviour).
        num_workers: Number of loader worker processes; ``0`` loads in the
            main process. Defaults to ``0`` (the previous hard-coded value).
        image_size: Side length, in pixels, of the square images produced.
            Defaults to ``227``.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)``
        batches, with every image resized and cropped to
        ``image_size x image_size`` and normalized per channel.
    """
    # image pre-processing
    pre_process = transforms.Compose([
        # Resize(int) only rescales the *shorter* edge (aspect ratio kept),
        # so non-square images would come out with different sizes and the
        # default collate would fail to stack them into one batch tensor.
        # CenterCrop pins both edges; it is a no-op for square images.
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # The usual ImageNet channel statistics -- presumably chosen to match
        # an ImageNet-pretrained backbone; confirm against the model used.
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225)),
    ])

    # datasets and data_loader
    office_dataset = datasets.ImageFolder(
        os.path.join(dataset_root, 'office31', category, 'images'),
        transform=pre_process)

    office_dataloader = torch.utils.data.DataLoader(
        dataset=office_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers)

    return office_dataloader