wogong
6 years ago
11 changed files with 190 additions and 185 deletions
@ -1,30 +1,25 @@ |
|||
"""Dataset setting and data loader for Office.""" |
|||
|
|||
import os |
|||
import torch |
|||
from torchvision import datasets, transforms |
|||
import torch.utils.data as data |
|||
import os |
|||
|
|||
|
|||
def get_office(dataset_root, batch_size, category): |
|||
"""Get Office datasets loader.""" |
|||
# image pre-processing |
|||
pre_process = transforms.Compose([transforms.Resize(227), |
|||
transforms.ToTensor(), |
|||
transforms.Normalize( |
|||
mean=(0.485, 0.456, 0.406), |
|||
std=(0.229, 0.224, 0.225) |
|||
)]) |
|||
pre_process = transforms.Compose([ |
|||
transforms.Resize(227), |
|||
transforms.ToTensor(), |
|||
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) |
|||
]) |
|||
|
|||
# datasets and data_loader |
|||
office_dataset = datasets.ImageFolder( |
|||
os.path.join(dataset_root, 'office', category, 'images'), |
|||
transform=pre_process) |
|||
os.path.join(dataset_root, 'office', category, 'images'), transform=pre_process) |
|||
|
|||
office_dataloader = torch.utils.data.DataLoader( |
|||
dataset=office_dataset, |
|||
batch_size=batch_size, |
|||
shuffle=True, |
|||
num_workers=4) |
|||
dataset=office_dataset, batch_size=batch_size, shuffle=True, num_workers=0) |
|||
|
|||
return office_dataloader |
@ -1,8 +1,10 @@ |
|||
import os |
|||
import sys |
|||
|
|||
sys.path.append('../') |
|||
from models.model import MNISTmodel |
|||
from core.dann import train_dann |
|||
from utils import get_data_loader, init_model, init_random_seed |
|||
from utils.utils import get_data_loader, init_model, init_random_seed |
|||
|
|||
|
|||
class Config(object): |
@ -0,0 +1,81 @@ |
|||
import os |
|||
import sys |
|||
|
|||
sys.path.append('../') |
|||
from core.dann import train_dann |
|||
from core.test import eval |
|||
from models.model import AlexModel |
|||
|
|||
from utils.utils import get_data_loader, init_model, init_random_seed |
|||
|
|||
|
|||
class Config(object): |
|||
# params for path |
|||
dataset_root = os.path.expanduser(os.path.join('~', 'Datasets')) |
|||
model_root = os.path.expanduser( |
|||
os.path.join('~', 'Models', 'pytorch-DANN')) |
|||
|
|||
# params for datasets and data loader |
|||
batch_size = 32 |
|||
|
|||
# params for source dataset |
|||
src_dataset = "amazon31" |
|||
src_model_trained = True |
|||
src_classifier_restore = os.path.join( |
|||
model_root, src_dataset + '-source-classifier-final.pt') |
|||
|
|||
# params for target dataset |
|||
tgt_dataset = "webcam10" |
|||
tgt_model_trained = True |
|||
dann_restore = os.path.join( |
|||
model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt') |
|||
|
|||
# params for pretrain |
|||
num_epochs_src = 100 |
|||
log_step_src = 5 |
|||
save_step_src = 50 |
|||
eval_step_src = 20 |
|||
|
|||
# params for training dann |
|||
|
|||
# for office |
|||
num_epochs = 1000 |
|||
log_step = 10 # iters |
|||
save_step = 500 |
|||
eval_step = 5 # epochs |
|||
|
|||
manual_seed = 8888 |
|||
alpha = 0 |
|||
|
|||
# params for optimizing models |
|||
lr = 2e-4 |
|||
|
|||
|
|||
params = Config() |
|||
|
|||
# init random seed |
|||
init_random_seed(params.manual_seed) |
|||
|
|||
# load dataset |
|||
src_data_loader = get_data_loader( |
|||
params.src_dataset, params.dataset_root, params.batch_size) |
|||
tgt_data_loader = get_data_loader( |
|||
params.tgt_dataset, params.dataset_root, params.batch_size) |
|||
|
|||
# load dann model |
|||
dann = init_model(net=AlexModel(), restore=None) |
|||
|
|||
# train dann model |
|||
print("Start training dann model.") |
|||
|
|||
if not (dann.restored and params.dann_restore): |
|||
dann = train_dann(dann, params, src_data_loader, |
|||
tgt_data_loader, tgt_data_loader) |
|||
|
|||
# eval dann model |
|||
print("Evaluating dann for source domain") |
|||
eval(dann, src_data_loader) |
|||
print("Evaluating dann for target domain") |
|||
eval(dann, tgt_data_loader) |
|||
|
|||
print('done') |
@ -1,8 +1,10 @@ |
|||
import os |
|||
import sys |
|||
|
|||
sys.path.append('../') |
|||
from models.model import SVHNmodel |
|||
from core.dann import train_dann |
|||
from utils import get_data_loader, init_model, init_random_seed |
|||
from utils.utils import get_data_loader, init_model, init_random_seed |
|||
|
|||
|
|||
class Config(object): |
Loading…
Reference in new issue