wogong
7 years ago
6 changed files with 241 additions and 31 deletions
@ -0,0 +1,30 @@ |
|||
"""Dataset setting and data loader for Office.""" |
|||
|
|||
import torch |
|||
from torchvision import datasets, transforms |
|||
import torch.utils.data as data |
|||
import os |
|||
import params |
|||
|
|||
|
|||
def get_office(train, category): |
|||
"""Get Office datasets loader.""" |
|||
# image pre-processing |
|||
pre_process = transforms.Compose([transforms.Resize(params.office_image_size), |
|||
transforms.ToTensor(), |
|||
transforms.Normalize( |
|||
mean=params.imagenet_dataset_mean, |
|||
std=params.imagenet_dataset_mean)]) |
|||
|
|||
# datasets and data_loader |
|||
office_dataset = datasets.ImageFolder( |
|||
os.path.join(params.dataset_root, 'office', category, 'images'), |
|||
transform=pre_process) |
|||
|
|||
office_dataloader = torch.utils.data.DataLoader( |
|||
dataset=office_dataset, |
|||
batch_size=params.batch_size, |
|||
shuffle=True, |
|||
num_workers=8) |
|||
|
|||
return office_dataloader |
@ -0,0 +1,30 @@ |
|||
"""Dataset setting and data loader for Office_Caltech_10.""" |
|||
|
|||
import torch |
|||
from torchvision import datasets, transforms |
|||
import torch.utils.data as data |
|||
import os |
|||
import params |
|||
|
|||
|
|||
def get_officecaltech(train, category): |
|||
"""Get Office_Caltech_10 datasets loader.""" |
|||
# image pre-processing |
|||
pre_process = transforms.Compose([transforms.Resize(params.office_image_size), |
|||
transforms.ToTensor(), |
|||
transforms.Normalize( |
|||
mean=params.imagenet_dataset_mean, |
|||
std=params.imagenet_dataset_mean)]) |
|||
|
|||
# datasets and data_loader |
|||
officecaltech_dataset = datasets.ImageFolder( |
|||
os.path.join(params.dataset_root, 'office_caltech_10', category), |
|||
transform=pre_process) |
|||
|
|||
officecaltech_dataloader = torch.utils.data.DataLoader( |
|||
dataset=officecaltech_dataset, |
|||
batch_size=params.batch_size, |
|||
shuffle=True, |
|||
num_workers=8) |
|||
|
|||
return officecaltech_dataloader |
@ -0,0 +1,32 @@ |
|||
from models.model import SVHNmodel, Classifier |
|||
|
|||
from core.dann import train_dann |
|||
from core.test import eval |
|||
from models.model import AlexModel |
|||
|
|||
import params |
|||
from utils import get_data_loader, init_model, init_random_seed |
|||
|
|||
# init random seed |
|||
init_random_seed(params.manual_seed) |
|||
|
|||
# load dataset |
|||
src_data_loader = get_data_loader(params.src_dataset) |
|||
tgt_data_loader = get_data_loader(params.tgt_dataset) |
|||
|
|||
# load dann model |
|||
dann = init_model(net=AlexModel(), restore=params.dann_restore) |
|||
|
|||
# train dann model |
|||
print("Start training dann model.") |
|||
|
|||
if not (dann.restored and params.dann_restore): |
|||
dann = train_dann(dann, src_data_loader, tgt_data_loader, tgt_data_loader) |
|||
|
|||
# eval dann model |
|||
print("Evaluating dann for source domain") |
|||
eval(dann, src_data_loader) |
|||
print("Evaluating dann for target domain") |
|||
eval(dann, tgt_data_loader) |
|||
|
|||
print('done') |
Loading…
Reference in new issue