@@ -3,7 +3,7 @@ import sys
 sys.path.append('../')
 
 from core.dann import train_dann
-from core.test import eval
+from core.test import test
 from models.model import AlexModel
 from utils.utils import get_data_loader, init_model, init_random_seed
 
@@ -12,8 +12,7 @@ from utils.utils import get_data_loader, init_model, init_random_seed
 class Config(object):
     # params for path
     dataset_root = os.path.expanduser(os.path.join('~', 'Datasets'))
-    model_root = os.path.expanduser(
-        os.path.join('~', 'Models', 'pytorch-DANN'))
+    model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN'))
 
     # params for datasets and data loader
     batch_size = 32
@@ -21,14 +20,12 @@ class Config(object):
     # params for source dataset
     src_dataset = "amazon31"
     src_model_trained = True
-    src_classifier_restore = os.path.join(
-        model_root, src_dataset + '-source-classifier-final.pt')
+    src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt')
 
     # params for target dataset
     tgt_dataset = "webcam10"
     tgt_model_trained = True
-    dann_restore = os.path.join(
-        model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')
+    dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')
 
     # params for pretrain
     num_epochs_src = 100
@@ -57,10 +54,8 @@ params = Config()
 init_random_seed(params.manual_seed)
 
 # load dataset
-src_data_loader = get_data_loader(
-    params.src_dataset, params.dataset_root, params.batch_size)
-tgt_data_loader = get_data_loader(
-    params.tgt_dataset, params.dataset_root, params.batch_size)
+src_data_loader = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size)
+tgt_data_loader = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size)
 
 # load dann model
 dann = init_model(net=AlexModel(), restore=None)
@@ -69,13 +64,6 @@ dann = init_model(net=AlexModel(), restore=None)
 print("Start training dann model.")
 
 if not (dann.restored and params.dann_restore):
-    dann = train_dann(dann, params, src_data_loader,
-                      tgt_data_loader, tgt_data_loader)
-
-# eval dann model
-print("Evaluating dann for source domain")
-eval(dann, src_data_loader)
-print("Evaluating dann for target domain")
-eval(dann, tgt_data_loader)
+    dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader)
 
 print('done')