wogong
7 years ago
15 changed files with 262 additions and 259 deletions
@ -1,49 +0,0 @@ |
|||
from models.model import SVHNmodel, Classifier |
|||
|
|||
from core.dann import train_dann |
|||
from core.test import eval, eval_src |
|||
from core.pretrain import train_src |
|||
|
|||
import params |
|||
from utils import get_data_loader, init_model, init_random_seed |
|||
|
|||
# init random seed |
|||
init_random_seed(params.manual_seed) |
|||
|
|||
# load dataset |
|||
src_data_loader = get_data_loader(params.src_dataset) |
|||
src_data_loader_eval = get_data_loader(params.src_dataset, train=False) |
|||
tgt_data_loader = get_data_loader(params.tgt_dataset) |
|||
tgt_data_loader_eval = get_data_loader(params.tgt_dataset, train=False) |
|||
|
|||
# load source classifier |
|||
src_classifier = init_model(net=Classifier(), restore=params.src_classifier_restore) |
|||
|
|||
# train source model |
|||
print("=== Training classifier for source domain ===") |
|||
|
|||
if not (src_classifier.restored and params.src_model_trained): |
|||
src_classifier = train_src(src_classifier, src_data_loader) |
|||
|
|||
# eval source model on both source and target domain |
|||
print("=== Evaluating source classifier for source domain ===") |
|||
eval_src(src_classifier, src_data_loader_eval) |
|||
print("=== Evaluating source classifier for target domain ===") |
|||
eval_src(src_classifier, tgt_data_loader_eval) |
|||
|
|||
# load dann model |
|||
dann = init_model(net=SVHNmodel(), restore=params.dann_restore) |
|||
|
|||
# train dann model |
|||
print("=== Training dann model ===") |
|||
|
|||
if not (dann.restored and params.dann_restore): |
|||
dann = train_dann(dann, src_data_loader, tgt_data_loader, tgt_data_loader_eval) |
|||
w |
|||
# eval dann model |
|||
print("=== Evaluating dann for source domain ===") |
|||
eval(dann, src_data_loader_eval) |
|||
print("=== Evaluating dann for target domain ===") |
|||
eval(dann, tgt_data_loader_eval) |
|||
|
|||
print('done') |
@ -0,0 +1,76 @@ |
|||
import os |
|||
|
|||
from models.model import MNISTmodel |
|||
from core.dann import train_dann |
|||
from utils import get_data_loader, init_model, init_random_seed |
|||
|
|||
|
|||
class Config(object): |
|||
# params for path |
|||
dataset_root = os.path.expanduser(os.path.join('~', 'Datasets')) |
|||
model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN')) |
|||
|
|||
# params for datasets and data loader |
|||
batch_size = 128 |
|||
|
|||
# params for source dataset |
|||
src_dataset = "mnist" |
|||
src_model_trained = True |
|||
src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt') |
|||
class_num_src = 31 |
|||
|
|||
# params for target dataset |
|||
tgt_dataset = "mnistm" |
|||
tgt_model_trained = True |
|||
dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt') |
|||
|
|||
# params for pretrain |
|||
num_epochs_src = 100 |
|||
log_step_src = 10 |
|||
save_step_src = 50 |
|||
eval_step_src = 20 |
|||
|
|||
# params for training dann |
|||
|
|||
## for digit |
|||
num_epochs = 100 |
|||
log_step = 20 |
|||
save_step = 50 |
|||
eval_step = 5 |
|||
|
|||
## for office |
|||
# num_epochs = 1000 |
|||
# log_step = 10 # iters |
|||
# save_step = 500 |
|||
# eval_step = 5 # epochs |
|||
|
|||
manual_seed = 8888 |
|||
alpha = 0 |
|||
|
|||
# params for optimizing models |
|||
lr = 2e-4 |
|||
|
|||
params = Config() |
|||
|
|||
# init random seed |
|||
init_random_seed(params.manual_seed) |
|||
|
|||
# load dataset |
|||
src_data_loader = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=True) |
|||
src_data_loader_eval = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=False) |
|||
tgt_data_loader = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size, train=True) |
|||
tgt_data_loader_eval = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size, train=False) |
|||
|
|||
# load dann model |
|||
dann = init_model(net=MNISTmodel(), restore=None) |
|||
|
|||
# train dann model |
|||
print("Training dann model") |
|||
if not (dann.restored and params.dann_restore): |
|||
dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval) |
|||
|
|||
# eval dann model |
|||
print("Evaluating dann for source domain {}".format(params.src_dataset)) |
|||
eval(dann, src_data_loader_eval) |
|||
print("Evaluating dann for target domain {}".format(params.tgt_dataset)) |
|||
eval(dann, tgt_data_loader_eval) |
@ -1,50 +0,0 @@ |
|||
"""Params for DANN.""" |
|||
|
|||
import os |
|||
|
|||
# params for path |
|||
dataset_root = os.path.expanduser(os.path.join('~', 'Datasets')) |
|||
model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN')) |
|||
|
|||
# params for datasets and data loader |
|||
|
|||
batch_size = 64 |
|||
|
|||
office_image_size = 227 |
|||
|
|||
# params for source dataset |
|||
src_dataset = "amazon31" |
|||
src_model_trained = True |
|||
src_classifier_restore = os.path.join(model_root,src_dataset + '-source-classifier-final.pt') |
|||
class_num_src = 31 |
|||
|
|||
# params for target dataset |
|||
tgt_dataset = "webcam31" |
|||
tgt_model_trained = True |
|||
dann_restore = os.path.join(model_root , src_dataset + '-' + tgt_dataset + '-dann-final.pt') |
|||
|
|||
# params for pretrain |
|||
num_epochs_src = 100 |
|||
log_step_src = 10 |
|||
save_step_src = 20 |
|||
eval_step_src = 20 |
|||
|
|||
# params for training dann |
|||
|
|||
## for digit |
|||
# num_epochs = 400 |
|||
# log_step = 100 |
|||
# save_step = 20 |
|||
# eval_step = 20 |
|||
|
|||
## for office |
|||
num_epochs = 1000 |
|||
log_step = 10 # iters |
|||
save_step = 500 |
|||
eval_step = 5 # epochs |
|||
|
|||
manual_seed = 8888 |
|||
alpha = 0 |
|||
|
|||
# params for optimizing models |
|||
lr = 2e-4 |
@ -0,0 +1,75 @@ |
|||
import os |
|||
|
|||
from models.model import SVHNmodel |
|||
from core.dann import train_dann |
|||
from utils import get_data_loader, init_model, init_random_seed |
|||
|
|||
|
|||
class Config(object): |
|||
# params for path |
|||
dataset_root = os.path.expanduser(os.path.join('~', 'Datasets')) |
|||
model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN')) |
|||
|
|||
# params for datasets and data loader |
|||
batch_size = 128 |
|||
|
|||
# params for source dataset |
|||
src_dataset = "svhn" |
|||
src_model_trained = True |
|||
src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt') |
|||
|
|||
# params for target dataset |
|||
tgt_dataset = "mnist" |
|||
tgt_model_trained = True |
|||
dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt') |
|||
|
|||
# params for pretrain |
|||
num_epochs_src = 100 |
|||
log_step_src = 10 |
|||
save_step_src = 50 |
|||
eval_step_src = 20 |
|||
|
|||
# params for training dann |
|||
|
|||
## for digit |
|||
num_epochs = 200 |
|||
log_step = 20 |
|||
save_step = 50 |
|||
eval_step = 5 |
|||
|
|||
## for office |
|||
# num_epochs = 1000 |
|||
# log_step = 10 # iters |
|||
# save_step = 500 |
|||
# eval_step = 5 # epochs |
|||
|
|||
manual_seed = 8888 |
|||
alpha = 0 |
|||
|
|||
# params for optimizing models |
|||
lr = 2e-4 |
|||
|
|||
params = Config() |
|||
|
|||
# init random seed |
|||
init_random_seed(params.manual_seed) |
|||
|
|||
# load dataset |
|||
src_data_loader = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=True) |
|||
src_data_loader_eval = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=False) |
|||
tgt_data_loader = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size, train=True) |
|||
tgt_data_loader_eval = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size, train=False) |
|||
|
|||
# load dann model |
|||
dann = init_model(net=SVHNmodel(), restore=None) |
|||
|
|||
# train dann model |
|||
print("Training dann model") |
|||
if not (dann.restored and params.dann_restore): |
|||
dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval) |
|||
|
|||
# eval dann model |
|||
print("Evaluating dann for source domain {}".format(params.src_dataset)) |
|||
eval(dann, src_data_loader_eval) |
|||
print("Evaluating dann for target domain {}".format(params.tgt_dataset)) |
|||
eval(dann, tgt_data_loader_eval) |
Loading…
Reference in new issue