wogong
7 years ago
6 changed files with 241 additions and 31 deletions
@ -0,0 +1,30 @@ |
|||||
|
"""Dataset setting and data loader for Office.""" |
||||
|
|
||||
|
import torch |
||||
|
from torchvision import datasets, transforms |
||||
|
import torch.utils.data as data |
||||
|
import os |
||||
|
import params |
||||
|
|
||||
|
|
||||
|
def get_office(train, category):
    """Get Office datasets loader.

    Args:
        train: currently unused; kept for interface symmetry with the other
            dataset loaders (train/test splits are not separated here).
        category: Office domain name (e.g. 'amazon', 'dslr', 'webcam') —
            selects the sub-directory under <dataset_root>/office.

    Returns:
        A torch.utils.data.DataLoader over the ImageFolder dataset.
    """
    # BUG FIX: the original passed std=params.imagenet_dataset_mean
    # (copy-paste error — mean used where std belongs). Prefer the
    # dedicated std parameter when params defines one; fall back to the
    # old behavior otherwise so existing configs keep working.
    imagenet_std = getattr(params, 'imagenet_dataset_std',
                           params.imagenet_dataset_mean)

    # image pre-processing: resize to the network input size, convert to
    # tensor, normalize with ImageNet statistics (AlexNet-style backbone).
    pre_process = transforms.Compose([
        transforms.Resize(params.office_image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=params.imagenet_dataset_mean,
                             std=imagenet_std),
    ])

    # datasets and data_loader
    office_dataset = datasets.ImageFolder(
        os.path.join(params.dataset_root, 'office', category, 'images'),
        transform=pre_process)

    office_dataloader = torch.utils.data.DataLoader(
        dataset=office_dataset,
        batch_size=params.batch_size,
        shuffle=True,
        num_workers=8)

    return office_dataloader
@ -0,0 +1,30 @@ |
|||||
|
"""Dataset setting and data loader for Office_Caltech_10.""" |
||||
|
|
||||
|
import torch |
||||
|
from torchvision import datasets, transforms |
||||
|
import torch.utils.data as data |
||||
|
import os |
||||
|
import params |
||||
|
|
||||
|
|
||||
|
def get_officecaltech(train, category):
    """Get Office_Caltech_10 datasets loader.

    Args:
        train: currently unused; kept for interface symmetry with the other
            dataset loaders (train/test splits are not separated here).
        category: domain name (e.g. 'amazon', 'caltech', 'dslr', 'webcam') —
            selects the sub-directory under <dataset_root>/office_caltech_10.

    Returns:
        A torch.utils.data.DataLoader over the ImageFolder dataset.
    """
    # BUG FIX: the original passed std=params.imagenet_dataset_mean
    # (copy-paste error — mean used where std belongs). Prefer the
    # dedicated std parameter when params defines one; fall back to the
    # old behavior otherwise so existing configs keep working.
    imagenet_std = getattr(params, 'imagenet_dataset_std',
                           params.imagenet_dataset_mean)

    # image pre-processing: resize to the network input size, convert to
    # tensor, normalize with ImageNet statistics (AlexNet-style backbone).
    pre_process = transforms.Compose([
        transforms.Resize(params.office_image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=params.imagenet_dataset_mean,
                             std=imagenet_std),
    ])

    # datasets and data_loader
    officecaltech_dataset = datasets.ImageFolder(
        os.path.join(params.dataset_root, 'office_caltech_10', category),
        transform=pre_process)

    officecaltech_dataloader = torch.utils.data.DataLoader(
        dataset=officecaltech_dataset,
        batch_size=params.batch_size,
        shuffle=True,
        num_workers=8)

    return officecaltech_dataloader
@ -0,0 +1,32 @@ |
|||||
|
"""Entry point: train a DANN model and evaluate it on source/target domains."""

from models.model import SVHNmodel, Classifier  # NOTE(review): unused here — verify before removing

from core.dann import train_dann
from core.test import eval as evaluate  # alias: avoid shadowing the builtin eval
from models.model import AlexModel

import params
from utils import get_data_loader, init_model, init_random_seed


def main():
    """Train DANN (unless a checkpoint is restored) and evaluate both domains."""
    # init random seed for reproducibility
    init_random_seed(params.manual_seed)

    # load source and target dataset loaders
    src_data_loader = get_data_loader(params.src_dataset)
    tgt_data_loader = get_data_loader(params.tgt_dataset)

    # load dann model (optionally restoring weights from a checkpoint)
    dann = init_model(net=AlexModel(), restore=params.dann_restore)

    # train dann model — skipped when a restored checkpoint is available
    print("Start training dann model.")
    if not (dann.restored and params.dann_restore):
        # target loader is also used as the evaluation loader during training
        dann = train_dann(dann, src_data_loader, tgt_data_loader, tgt_data_loader)

    # eval dann model on both domains
    print("Evaluating dann for source domain")
    evaluate(dann, src_data_loader)
    print("Evaluating dann for target domain")
    evaluate(dann, tgt_data_loader)

    print('done')


# Guard the entry point so importing this module has no side effects
# (the original ran training at import time).
if __name__ == '__main__':
    main()
Loading…
Reference in new issue