|
|
@ -1,8 +1,9 @@ |
|
|
|
import os |
|
|
|
import sys |
|
|
|
|
|
|
|
import torch |
|
|
|
sys.path.append('../') |
|
|
|
from models.model import MNISTmodel |
|
|
|
from models.model import MNISTmodel, MNISTmodel_plain |
|
|
|
from core.dann import train_dann |
|
|
|
from utils.utils import get_data_loader, init_model, init_random_seed |
|
|
|
|
|
|
@ -13,7 +14,7 @@ class Config(object): |
|
|
|
model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN')) |
|
|
|
|
|
|
|
# params for datasets and data loader |
|
|
|
batch_size = 128 |
|
|
|
batch_size = 64 |
|
|
|
|
|
|
|
# params for source dataset |
|
|
|
src_dataset = "mnist" |
|
|
@ -33,6 +34,7 @@ class Config(object): |
|
|
|
eval_step_src = 20 |
|
|
|
|
|
|
|
# params for training dann |
|
|
|
gpu_id = '0' |
|
|
|
|
|
|
|
## for digit |
|
|
|
num_epochs = 100 |
|
|
@ -58,6 +60,9 @@ params = Config() |
|
|
|
# init random seed |
|
|
|
init_random_seed(params.manual_seed) |
|
|
|
|
|
|
|
# init device |
|
|
|
device = torch.device("cuda:" + params.gpu_id if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
|
|
# load dataset |
|
|
|
src_data_loader = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=True) |
|
|
|
src_data_loader_eval = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=False) |
|
|
@ -65,9 +70,9 @@ tgt_data_loader = get_data_loader(params.tgt_dataset, params.dataset_root, param |
|
|
|
tgt_data_loader_eval = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size, train=False) |
|
|
|
|
|
|
|
# load dann model |
|
|
|
dann = init_model(net=MNISTmodel(), restore=None) |
|
|
|
dann = init_model(net=MNISTmodel_plain(), restore=None) |
|
|
|
|
|
|
|
# train dann model |
|
|
|
print("Training dann model") |
|
|
|
if not (dann.restored and params.dann_restore): |
|
|
|
dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval) |
|
|
|
dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device) |