diff --git a/core/train.py b/core/train.py
index 9ba407a..27fb4f4 100644
--- a/core/train.py
+++ b/core/train.py
@@ -10,7 +10,7 @@ from utils.utils import save_model
 import torch.backends.cudnn as cudnn
 cudnn.benchmark = True
 
-def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger):
+def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger=None):
     """Train dann."""
     ####################
     # 1. setup network #
@@ -60,7 +60,8 @@ def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_e
                 lr = adjust_learning_rate(optimizer, p)
             else:
                 lr = adjust_learning_rate_office(optimizer, p)
-            logger.add_scalar('lr', lr, global_step)
+            if logger is not None:
+                logger.add_scalar('lr', lr, global_step)
 
             # prepare domain label
             size_src = len(images_src)
@@ -86,7 +87,8 @@ def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_e
             global_step += 1
 
             # print step info
-            logger.add_scalar('loss', loss.item(), global_step)
+            if logger is not None:
+                logger.add_scalar('loss', loss.item(), global_step)
 
             if ((step + 1) % params.log_step == 0):
                 print(
@@ -96,8 +98,9 @@ def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_e
 
         if ((epoch + 1) % params.eval_step == 0):
             src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source')
             tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, flag='target')
-            logger.add_scalar('src_test_loss', src_test_loss, global_step)
-            logger.add_scalar('src_acc', src_acc, global_step)
+            if logger is not None:
+                logger.add_scalar('src_test_loss', src_test_loss, global_step)
+                logger.add_scalar('src_acc', src_acc, global_step)
 
         # save model parameters
@@ -110,7 +113,7 @@ def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_e
     return model
 
 
-def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger):
+def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger=None):
     """Train dann."""
     ####################
     # 1. setup network #
@@ -144,6 +147,7 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
     # 2. train network #
     ####################
     global_step = 0
+    bestAcc, bestAccS = 0.0, 0.0
     for epoch in range(params.num_epochs):
         # set train state for Dropout and BN layers
         model.train()
@@ -160,7 +164,8 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
                 lr = adjust_learning_rate(optimizer, p)
             else:
                 lr = adjust_learning_rate_office(optimizer, p)
-            logger.add_scalar('lr', lr, global_step)
+            if logger is not None:
+                logger.add_scalar('lr', lr, global_step)
 
             # prepare domain label
             size_src = len(images_src)
@@ -196,10 +201,11 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
             global_step += 1
 
             # print step info
-            logger.add_scalar('src_loss_class', src_loss_class.item(), global_step)
-            logger.add_scalar('src_loss_domain', src_loss_domain.item(), global_step)
-            logger.add_scalar('tgt_loss_domain', tgt_loss_domain.item(), global_step)
-            logger.add_scalar('loss', loss.item(), global_step)
+            if logger is not None:
+                logger.add_scalar('src_loss_class', src_loss_class.item(), global_step)
+                logger.add_scalar('src_loss_domain', src_loss_domain.item(), global_step)
+                logger.add_scalar('tgt_loss_domain', tgt_loss_domain.item(), global_step)
+                logger.add_scalar('loss', loss.item(), global_step)
 
             if ((step + 1) % params.log_step == 0):
                 print(
@@ -211,21 +217,24 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
 
         if ((epoch + 1) % params.eval_step == 0):
             tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, flag='target')
             src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source')
-            logger.add_scalar('src_test_loss', src_test_loss, global_step)
-            logger.add_scalar('src_acc', src_acc, global_step)
-            logger.add_scalar('src_acc_domain', src_acc_domain, global_step)
-            logger.add_scalar('tgt_test_loss', tgt_test_loss, global_step)
-            logger.add_scalar('tgt_acc', tgt_acc, global_step)
-            logger.add_scalar('tgt_acc_domain', tgt_acc_domain, global_step)
-
-
-        # save model parameters
-        if ((epoch + 1) % params.save_step == 0):
-            save_model(model, params.model_root,
-                       params.src_dataset + '-' + params.tgt_dataset + "-dann-{}.pt".format(epoch + 1))
+            if tgt_acc > bestAcc:
+                bestAcc = tgt_acc
+                bestAccS = src_acc
+                save_model(model, params.model_root,
+                           params.src_dataset + '-' + params.tgt_dataset + "-dann-best.pt")
+            if logger is not None:
+                logger.add_scalar('src_test_loss', src_test_loss, global_step)
+                logger.add_scalar('src_acc', src_acc, global_step)
+                logger.add_scalar('src_acc_domain', src_acc_domain, global_step)
+                logger.add_scalar('tgt_test_loss', tgt_test_loss, global_step)
+                logger.add_scalar('tgt_acc', tgt_acc, global_step)
+                logger.add_scalar('tgt_acc_domain', tgt_acc_domain, global_step)
 
     # save final model
     save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")
+    print('============ Summary ============= \n')
+    print('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS))
+    print('Accuracy of the %s dataset: %f' % (params.tgt_dataset, bestAcc))
 
     return model
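
Note on the optional logger: with the new logger=None default, every add_scalar
call is guarded, so both trainers run without any logging backend attached. A
minimal driver sketch, not part of the patch: it assumes the params, data
loaders, device, and dann objects built in experiments/mnist_mnistm.py, and
uses torch.utils.tensorboard, whose SummaryWriter provides the add_scalar
interface the trainers expect.

    # Hypothetical usage: pass a SummaryWriter to record the scalars above,
    # or leave logger=None (the new default) to train silently.
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(log_dir='runs/mnist-mnistm')  # assumed log directory
    dann = train_dann(dann, params, src_data_loader, tgt_data_loader,
                      tgt_data_loader_eval, device, logger=writer)
    writer.close()
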
diff --git a/datasets/mnist.py b/datasets/mnist.py
index 03e64ca..1bbc192 100644
--- a/datasets/mnist.py
+++ b/datasets/mnist.py
@@ -8,11 +8,11 @@ import os
 def get_mnist(dataset_root, batch_size, train):
     """Get MNIST datasets loader."""
     # image pre-processing
-    pre_process = transforms.Compose([transforms.Resize(32), # different img size settings for mnist(28) and svhn(32).
+    pre_process = transforms.Compose([transforms.Resize(28), # different img size settings for mnist(28) and svhn(32).
                                       transforms.ToTensor(),
                                       transforms.Normalize(
-                                          mean=(0.5, 0.5, 0.5),
-                                          std=(0.5, 0.5, 0.5)
+                                          mean=(0.5,),
+                                          std=(0.5,)
                                       )])
 
     # datasets and data loader
@@ -21,7 +21,6 @@ def get_mnist(dataset_root, batch_size, train):
                                    transform=pre_process,
                                    download=False)
 
-
     mnist_data_loader = torch.utils.data.DataLoader(
         dataset=mnist_dataset,
         batch_size=batch_size,
diff --git a/experiments/__init__.py b/experiments/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/experiments/mnist_mnistm.py b/experiments/mnist_mnistm.py
index b523264..37985c4 100644
--- a/experiments/mnist_mnistm.py
+++ b/experiments/mnist_mnistm.py
@@ -1,8 +1,7 @@
 import os
 import sys
-
+sys.path.append(os.path.abspath('.'))
 import torch
-sys.path.append('../')
 from models.model import MNISTmodel, MNISTmodel_plain
 from core.train import train_dann
 from utils.utils import get_data_loader, init_model, init_random_seed
@@ -10,8 +9,13 @@ from utils.utils import get_data_loader, init_model, init_random_seed
 
 class Config(object):
     # params for path
-    dataset_root = os.path.expanduser(os.path.join('~', 'Datasets'))
-    model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN'))
+    currentDir = os.path.dirname(os.path.realpath(__file__))
+    dataset_root = os.environ["DATASETDIR"]
+    model_root = os.path.join(currentDir, 'checkpoints')
+
+    finetune_flag = False
+    lr_adjust_flag = 'simple'
+    src_only_flag = False
 
     # params for datasets and data loader
     batch_size = 64
@@ -53,6 +57,8 @@ class Config(object):
 
     # params for optimizing models
     lr = 2e-4
+    momentum = 0.0
+    weight_decay = 0.0
 
 
 params = Config()
@@ -70,7 +76,7 @@ tgt_data_loader = get_data_loader(params.tgt_dataset, params.dataset_root, param
 tgt_data_loader_eval = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size, train=False)
 
 # load dann model
-dann = init_model(net=MNISTmodel_plain(), restore=None)
+dann = init_model(net=MNISTmodel(), restore=None)
 
 # train dann model
print("Training dann model")
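
Note on running the experiment: Config now reads the dataset location from
os.environ["DATASETDIR"] and raises KeyError when the variable is unset, and
the whole experiment (Config, loaders, model, training) executes at module top
level. A hypothetical launcher sketch, assuming the repository root as the
working directory (to match the new sys.path.append(os.path.abspath('.')))
and reusing the old ~/Datasets location as an assumed fallback:

    # Hypothetical launcher: DATASETDIR must be set before Config is evaluated.
    import os
    os.environ.setdefault('DATASETDIR', os.path.expanduser('~/Datasets'))  # assumed fallback
    import experiments.mnist_mnistm  # importing the module runs the full training experiment
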