wogong
7 years ago
15 changed files with 262 additions and 259 deletions
@@ -1,49 +0,0 @@
from models.model import SVHNmodel, Classifier
from core.dann import train_dann
from core.test import eval, eval_src
from core.pretrain import train_src

import params
from utils import get_data_loader, init_model, init_random_seed

# init random seed
init_random_seed(params.manual_seed)

# load dataset
src_data_loader = get_data_loader(params.src_dataset)
src_data_loader_eval = get_data_loader(params.src_dataset, train=False)
tgt_data_loader = get_data_loader(params.tgt_dataset)
tgt_data_loader_eval = get_data_loader(params.tgt_dataset, train=False)

# load source classifier
src_classifier = init_model(net=Classifier(), restore=params.src_classifier_restore)

# train source model
print("=== Training classifier for source domain ===")

if not (src_classifier.restored and params.src_model_trained):
    src_classifier = train_src(src_classifier, src_data_loader)

# eval source model on both source and target domain
print("=== Evaluating source classifier for source domain ===")
eval_src(src_classifier, src_data_loader_eval)
print("=== Evaluating source classifier for target domain ===")
eval_src(src_classifier, tgt_data_loader_eval)

# load dann model
dann = init_model(net=SVHNmodel(), restore=params.dann_restore)

# train dann model
print("=== Training dann model ===")

if not (dann.restored and params.dann_restore):
    dann = train_dann(dann, src_data_loader, tgt_data_loader, tgt_data_loader_eval)

# eval dann model
print("=== Evaluating dann for source domain ===")
eval(dann, src_data_loader_eval)
print("=== Evaluating dann for target domain ===")
eval(dann, tgt_data_loader_eval)

print('done')
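Both the removed script above and the new per-experiment scripts below seed every RNG through init_random_seed(params.manual_seed) from utils. As a rough sketch of what such a helper typically does (assumed behavior, not necessarily this repository's exact implementation):

import random

import numpy as np
import torch


def init_random_seed(manual_seed=None):
    # Illustrative sketch: pick a seed if none is given, then seed all RNGs in play.
    seed = random.randint(1, 10000) if manual_seed is None else manual_seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    return seed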
@@ -0,0 +1,76 @@
import os

from models.model import MNISTmodel
from core.dann import train_dann
from core.test import eval
from utils import get_data_loader, init_model, init_random_seed


class Config(object):
    # params for path
    dataset_root = os.path.expanduser(os.path.join('~', 'Datasets'))
    model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN'))

    # params for datasets and data loader
    batch_size = 128

    # params for source dataset
    src_dataset = "mnist"
    src_model_trained = True
    src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt')
    class_num_src = 31

    # params for target dataset
    tgt_dataset = "mnistm"
    tgt_model_trained = True
    dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')

    # params for pretrain
    num_epochs_src = 100
    log_step_src = 10
    save_step_src = 50
    eval_step_src = 20

    # params for training dann

    ## for digit
    num_epochs = 100
    log_step = 20
    save_step = 50
    eval_step = 5

    ## for office
    # num_epochs = 1000
    # log_step = 10  # iters
    # save_step = 500
    # eval_step = 5  # epochs

    manual_seed = 8888
    alpha = 0

    # params for optimizing models
    lr = 2e-4


params = Config()

# init random seed
init_random_seed(params.manual_seed)

# load dataset
src_data_loader = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=True)
src_data_loader_eval = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=False)
tgt_data_loader = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size, train=True)
tgt_data_loader_eval = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size, train=False)

# load dann model
dann = init_model(net=MNISTmodel(), restore=None)

# train dann model
print("Training dann model")
if not (dann.restored and params.dann_restore):
    dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval)

# eval dann model
print("Evaluating dann for source domain {}".format(params.src_dataset))
eval(dann, src_data_loader_eval)
print("Evaluating dann for target domain {}".format(params.tgt_dataset))
eval(dann, tgt_data_loader_eval)
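The refactored scripts pass params.dataset_root and params.batch_size into get_data_loader explicitly. Purely as an illustration of a loader factory matching those call sites (the transform constants and the single dataset branch are assumptions, not the project's actual utils code):

from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def get_data_loader(name, dataset_root, batch_size, train=True):
    # Illustrative only: the real helper would also cover "mnistm", "svhn", etc.
    if name == "mnist":
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),  # placeholder normalization constants
        ])
        dataset = datasets.MNIST(root=dataset_root, train=train,
                                 transform=transform, download=True)
    else:
        raise NotImplementedError("only the mnist branch is sketched here")
    return DataLoader(dataset, batch_size=batch_size, shuffle=train, drop_last=train)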
@@ -1,50 +0,0 @@
"""Params for DANN."""

import os

# params for path
dataset_root = os.path.expanduser(os.path.join('~', 'Datasets'))
model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN'))

# params for datasets and data loader
batch_size = 64

office_image_size = 227

# params for source dataset
src_dataset = "amazon31"
src_model_trained = True
src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt')
class_num_src = 31

# params for target dataset
tgt_dataset = "webcam31"
tgt_model_trained = True
dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')

# params for pretrain
num_epochs_src = 100
log_step_src = 10
save_step_src = 20
eval_step_src = 20

# params for training dann

## for digit
# num_epochs = 400
# log_step = 100
# save_step = 20
# eval_step = 20

## for office
num_epochs = 1000
log_step = 10  # iters
save_step = 500
eval_step = 5  # epochs

manual_seed = 8888
alpha = 0

# params for optimizing models
lr = 2e-4
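The alpha entry in these configurations is presumably the coefficient of DANN's gradient reversal layer (the original paper ramps it from 0 toward 1 during training rather than fixing it). As a generic sketch of that mechanism only, independent of this repository's models:

from torch.autograd import Function


class GradReverse(Function):
    # Gradient reversal: identity in the forward pass, negated and scaled gradient backward.

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse the gradient flowing back into the feature extractor.
        return grad_output.neg() * ctx.alpha, None


def grad_reverse(x, alpha=1.0):
    return GradReverse.apply(x, alpha)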
@@ -0,0 +1,75 @@
import os

from models.model import SVHNmodel
from core.dann import train_dann
from core.test import eval
from utils import get_data_loader, init_model, init_random_seed


class Config(object):
    # params for path
    dataset_root = os.path.expanduser(os.path.join('~', 'Datasets'))
    model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN'))

    # params for datasets and data loader
    batch_size = 128

    # params for source dataset
    src_dataset = "svhn"
    src_model_trained = True
    src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt')

    # params for target dataset
    tgt_dataset = "mnist"
    tgt_model_trained = True
    dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')

    # params for pretrain
    num_epochs_src = 100
    log_step_src = 10
    save_step_src = 50
    eval_step_src = 20

    # params for training dann

    ## for digit
    num_epochs = 200
    log_step = 20
    save_step = 50
    eval_step = 5

    ## for office
    # num_epochs = 1000
    # log_step = 10  # iters
    # save_step = 500
    # eval_step = 5  # epochs

    manual_seed = 8888
    alpha = 0

    # params for optimizing models
    lr = 2e-4


params = Config()

# init random seed
init_random_seed(params.manual_seed)

# load dataset
src_data_loader = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=True)
src_data_loader_eval = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=False)
tgt_data_loader = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size, train=True)
tgt_data_loader_eval = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size, train=False)

# load dann model
dann = init_model(net=SVHNmodel(), restore=None)

# train dann model
print("Training dann model")
if not (dann.restored and params.dann_restore):
    dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval)

# eval dann model
print("Evaluating dann for source domain {}".format(params.src_dataset))
eval(dann, src_data_loader_eval)
print("Evaluating dann for target domain {}".format(params.tgt_dataset))
eval(dann, tgt_data_loader_eval)
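Both new experiment scripts guard training with dann.restored, a flag set by init_model from utils. A minimal sketch of such a helper, assuming it loads a checkpoint when the restore path exists and then moves the network to GPU (illustrative, not the repository's exact code):

import os

import torch


def init_model(net, restore=None):
    # Mark the model as not restored, then load weights if a checkpoint path is given.
    net.restored = False
    if restore is not None and os.path.exists(restore):
        net.load_state_dict(torch.load(restore))
        net.restored = True
    if torch.cuda.is_available():
        net = net.cuda()
    return net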