wogong · 6 years ago · 11 changed files with 190 additions and 185 deletions
Office dataset loader:

@@ -1,30 +1,25 @@
 """Dataset setting and data loader for Office."""

+import os
 import torch
 from torchvision import datasets, transforms
 import torch.utils.data as data
-import os


 def get_office(dataset_root, batch_size, category):
     """Get Office datasets loader."""
     # image pre-processing
-    pre_process = transforms.Compose([transforms.Resize(227),
-                                      transforms.ToTensor(),
-                                      transforms.Normalize(
-                                          mean=(0.485, 0.456, 0.406),
-                                          std=(0.229, 0.224, 0.225)
-                                      )])
+    pre_process = transforms.Compose([
+        transforms.Resize(227),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
+    ])

     # datasets and data_loader
     office_dataset = datasets.ImageFolder(
-        os.path.join(dataset_root, 'office', category, 'images'),
-        transform=pre_process)
+        os.path.join(dataset_root, 'office', category, 'images'), transform=pre_process)

     office_dataloader = torch.utils.data.DataLoader(
-        dataset=office_dataset,
-        batch_size=batch_size,
-        shuffle=True,
-        num_workers=4)
+        dataset=office_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

     return office_dataloader
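For orientation, a minimal usage sketch of the refactored loader. The module path, the 'amazon' category name, and the one-batch iteration are illustrative assumptions, not part of the commit:

# Hypothetical usage of get_office; the import path is assumed, adjust to the repo layout.
import os

from datasets.office import get_office  # assumed location of the loader

dataset_root = os.path.expanduser(os.path.join('~', 'Datasets'))

# Expects class folders under <dataset_root>/office/<category>/images/
loader = get_office(dataset_root, batch_size=32, category='amazon')

images, labels = next(iter(loader))
print(images.shape, labels.shape)  # image tensors normalized with ImageNet mean/std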
Experiment script using MNISTmodel (import changes):

@@ -1,8 +1,10 @@
 import os
+import sys

+sys.path.append('../')
 from models.model import MNISTmodel
 from core.dann import train_dann
-from utils import get_data_loader, init_model, init_random_seed
+from utils.utils import get_data_loader, init_model, init_random_seed


 class Config(object):
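The import change above (repeated in the SVHN script below) points the helpers at a utils/utils.py module, with sys.path.append('../') making the package importable when the script runs from a subdirectory. As a rough guide, a sketch of what the imported init_random_seed helper typically does in this kind of training repo; the helper name matches the import, but the body is an assumption rather than the project's actual code:

# Hypothetical sketch of the seeding helper imported from utils.utils.
import random

import numpy as np
import torch


def init_random_seed(manual_seed=None):
    """Seed Python, NumPy and PyTorch RNGs; pick a random seed if none is given."""
    seed = random.randint(1, 10000) if manual_seed is None else manual_seed
    print("use random seed: {}".format(seed))
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)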
New Office DANN experiment script (new file):

@@ -0,0 +1,81 @@
+import os
+import sys
+
+sys.path.append('../')
+from core.dann import train_dann
+from core.test import eval
+from models.model import AlexModel
+
+from utils.utils import get_data_loader, init_model, init_random_seed
+
+
+class Config(object):
+    # params for path
+    dataset_root = os.path.expanduser(os.path.join('~', 'Datasets'))
+    model_root = os.path.expanduser(
+        os.path.join('~', 'Models', 'pytorch-DANN'))
+
+    # params for datasets and data loader
+    batch_size = 32
+
+    # params for source dataset
+    src_dataset = "amazon31"
+    src_model_trained = True
+    src_classifier_restore = os.path.join(
+        model_root, src_dataset + '-source-classifier-final.pt')
+
+    # params for target dataset
+    tgt_dataset = "webcam10"
+    tgt_model_trained = True
+    dann_restore = os.path.join(
+        model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')
+
+    # params for pretrain
+    num_epochs_src = 100
+    log_step_src = 5
+    save_step_src = 50
+    eval_step_src = 20
+
+    # params for training dann
+
+    # for office
+    num_epochs = 1000
+    log_step = 10  # iters
+    save_step = 500
+    eval_step = 5  # epochs
+
+    manual_seed = 8888
+    alpha = 0
+
+    # params for optimizing models
+    lr = 2e-4
+
+
+params = Config()
+
+# init random seed
+init_random_seed(params.manual_seed)
+
+# load dataset
+src_data_loader = get_data_loader(
+    params.src_dataset, params.dataset_root, params.batch_size)
+tgt_data_loader = get_data_loader(
+    params.tgt_dataset, params.dataset_root, params.batch_size)
+
+# load dann model
+dann = init_model(net=AlexModel(), restore=None)
+
+# train dann model
+print("Start training dann model.")
+
+if not (dann.restored and params.dann_restore):
+    dann = train_dann(dann, params, src_data_loader,
+                      tgt_data_loader, tgt_data_loader)
+
+# eval dann model
+print("Evaluating dann for source domain")
+eval(dann, src_data_loader)
+print("Evaluating dann for target domain")
+eval(dann, tgt_data_loader)
+
+print('done')
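The new script requests loaders by name ("amazon31", "webcam10"), so get_data_loader presumably maps these names onto get_office categories. A plausible dispatch sketch under that assumption; the actual name-to-category mapping in the repo may differ:

# Hypothetical dispatch inside utils/utils.py; the dataset-name-to-category
# mapping is an assumption based on the strings used in the Config above.
from datasets.office import get_office  # assumed module path


def get_data_loader(name, dataset_root, batch_size):
    if name == "amazon31":
        # full 31-class Amazon domain of Office-31 as the source
        return get_office(dataset_root, batch_size, 'amazon')
    elif name == "webcam10":
        # reduced 10-class Webcam split as the target
        return get_office(dataset_root, batch_size, 'webcam10')
    else:
        raise ValueError("Unknown dataset name: {}".format(name))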
Experiment script using SVHNmodel (import changes):

@@ -1,8 +1,10 @@
 import os
+import sys

+sys.path.append('../')
 from models.model import SVHNmodel
 from core.dann import train_dann
-from utils import get_data_loader, init_model, init_random_seed
+from utils.utils import get_data_loader, init_model, init_random_seed


 class Config(object):