
Fixes for initial training

master · Fazil Altinel committed 4 years ago · parent commit eb194bff1c

Changed files:

  1. core/train.py (25 lines changed)
  2. datasets/mnist.py (7 lines changed)
  3. experiments/__init__.py (0 lines changed, new empty file)
  4. experiments/mnist_mnistm.py (16 lines changed)

core/train.py (25 lines changed)
@@ -10,7 +10,7 @@ from utils.utils import save_model
 import torch.backends.cudnn as cudnn
 cudnn.benchmark = True

-def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger):
+def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger=None):
     """Train dann."""
     ####################
     # 1. setup network #
@@ -60,6 +60,7 @@ def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_e
                 lr = adjust_learning_rate(optimizer, p)
             else:
                 lr = adjust_learning_rate_office(optimizer, p)
-            logger.add_scalar('lr', lr, global_step)
+            if not logger == None:
+                logger.add_scalar('lr', lr, global_step)

             # prepare domain label
@@ -86,6 +87,7 @@ def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_e
             global_step += 1

             # print step info
-            logger.add_scalar('loss', loss.item(), global_step)
+            if not logger == None:
+                logger.add_scalar('loss', loss.item(), global_step)
             if ((step + 1) % params.log_step == 0):
@@ -96,6 +98,7 @@ def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_e
         if ((epoch + 1) % params.eval_step == 0):
             src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source')
             tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, flag='target')
-            logger.add_scalar('src_test_loss', src_test_loss, global_step)
-            logger.add_scalar('src_acc', src_acc, global_step)
+            if not logger == None:
+                logger.add_scalar('src_test_loss', src_test_loss, global_step)
+                logger.add_scalar('src_acc', src_acc, global_step)
@@ -110,7 +113,7 @@ def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_e
     return model

-def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger):
+def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger=None):
     """Train dann."""
     ####################
     # 1. setup network #
@@ -144,6 +147,7 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
     # 2. train network #
     ####################
     global_step = 0
+    bestAcc = 0.0
     for epoch in range(params.num_epochs):
         # set train state for Dropout and BN layers
         model.train()
@@ -160,6 +164,7 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
                 lr = adjust_learning_rate(optimizer, p)
             else:
                 lr = adjust_learning_rate_office(optimizer, p)
-            logger.add_scalar('lr', lr, global_step)
+            if not logger == None:
+                logger.add_scalar('lr', lr, global_step)

             # prepare domain label
@@ -196,6 +201,7 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
             global_step += 1

             # print step info
-            logger.add_scalar('src_loss_class', src_loss_class.item(), global_step)
-            logger.add_scalar('src_loss_domain', src_loss_domain.item(), global_step)
-            logger.add_scalar('tgt_loss_domain', tgt_loss_domain.item(), global_step)
+            if not logger == None:
+                logger.add_scalar('src_loss_class', src_loss_class.item(), global_step)
+                logger.add_scalar('src_loss_domain', src_loss_domain.item(), global_step)
+                logger.add_scalar('tgt_loss_domain', tgt_loss_domain.item(), global_step)
@@ -211,6 +217,12 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
         if ((epoch + 1) % params.eval_step == 0):
             tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, flag='target')
             src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source')
-            logger.add_scalar('src_test_loss', src_test_loss, global_step)
-            logger.add_scalar('src_acc', src_acc, global_step)
-            logger.add_scalar('src_acc_domain', src_acc_domain, global_step)
+            if tgt_acc > bestAcc:
+                bestAcc = tgt_acc
+                bestAccS = src_acc
+                save_model(model, params.model_root,
+                           params.src_dataset + '-' + params.tgt_dataset + "-dann-best.pt")
+            if not logger == None:
+                logger.add_scalar('src_test_loss', src_test_loss, global_step)
+                logger.add_scalar('src_acc', src_acc, global_step)
+                logger.add_scalar('src_acc_domain', src_acc_domain, global_step)
@@ -218,14 +230,11 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
                 logger.add_scalar('tgt_acc', tgt_acc, global_step)
                 logger.add_scalar('tgt_acc_domain', tgt_acc_domain, global_step)

-        # save model parameters
-        if ((epoch + 1) % params.save_step == 0):
-            save_model(model, params.model_root,
-                       params.src_dataset + '-' + params.tgt_dataset + "-dann-{}.pt".format(epoch + 1))
     # save final model
     save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")
+    print('============ Summary =============\n')
+    print('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS))
+    print('Accuracy of the %s dataset: %f' % (params.tgt_dataset, bestAcc))
     return model
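
Taken together, the core/train.py hunks do two things: every `logger.add_scalar` call is now guarded so both trainers accept `logger=None`, and `train_dann` tracks the best target accuracy seen at evaluation, saving a checkpoint only on improvement. A minimal runnable sketch of that control flow; the helpers and file names below are illustrative stand-ins, not the repo's `test`/`save_model`:

def train(num_epochs, eval_step=1, logger=None):
    best_acc = 0.0
    for epoch in range(num_epochs):
        loss = 1.0 / (epoch + 1)                  # stand-in for a real training step
        if logger is not None:                    # guard keeps the logger optional
            logger.add_scalar('loss', loss, epoch)
        if (epoch + 1) % eval_step == 0:
            tgt_acc = 1.0 - loss                  # stand-in for test(model, ...)
            if tgt_acc > best_acc:                # checkpoint only on improvement
                best_acc = tgt_acc
                print('would save dann-best.pt at epoch %d' % (epoch + 1))
    print('would save dann-final.pt')
    return best_acc

print('best target accuracy: %f' % train(num_epochs=3))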

datasets/mnist.py (7 lines changed)

@@ -8,11 +8,11 @@ import os

 def get_mnist(dataset_root, batch_size, train):
     """Get MNIST datasets loader."""
     # image pre-processing
-    pre_process = transforms.Compose([transforms.Resize(32),  # different img size settings for mnist(28) and svhn(32).
+    pre_process = transforms.Compose([transforms.Resize(28),  # different img size settings for mnist(28) and svhn(32).
                                       transforms.ToTensor(),
                                       transforms.Normalize(
-                                          mean=(0.5, 0.5, 0.5),
-                                          std=(0.5, 0.5, 0.5)
+                                          mean=(0.5,),
+                                          std=(0.5,)
                                       )])

     # datasets and data loader
@@ -21,7 +21,6 @@ def get_mnist(dataset_root, batch_size, train):
         transform=pre_process,
         download=False)

     mnist_data_loader = torch.utils.data.DataLoader(
         dataset=mnist_dataset,
         batch_size=batch_size,
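
The datasets/mnist.py change keeps MNIST at its native 28x28 resolution and normalizes a single channel, since `ToTensor()` yields a 1x28x28 tensor for MNIST and `Normalize` expects one mean/std entry per channel. A sketch of the resulting pipeline; the `root` path and `download=True` here are illustrative, while the repo passes `dataset_root` and `download=False`:

import torch
from torchvision import datasets, transforms

pre_process = transforms.Compose([
    transforms.Resize(28),                          # MNIST is natively 28x28
    transforms.ToTensor(),                          # float tensor in [0, 1], shape 1x28x28
    transforms.Normalize(mean=(0.5,), std=(0.5,)),  # one value per channel, maps to [-1, 1]
])

mnist = datasets.MNIST(root='./data', train=True, transform=pre_process, download=True)
loader = torch.utils.data.DataLoader(mnist, batch_size=64, shuffle=True)
images, labels = next(iter(loader))
print(images.shape)                                 # torch.Size([64, 1, 28, 28])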

experiments/__init__.py (0 lines changed, new empty file)

experiments/mnist_mnistm.py (16 lines changed)

@@ -1,8 +1,7 @@
 import os
 import sys
+sys.path.append(os.path.abspath('.'))
 import torch
-sys.path.append('../')
 from models.model import MNISTmodel, MNISTmodel_plain
 from core.train import train_dann
 from utils.utils import get_data_loader, init_model, init_random_seed
@@ -10,8 +9,13 @@ from utils.utils import get_data_loader, init_model, init_random_seed
 class Config(object):
     # params for path
-    dataset_root = os.path.expanduser(os.path.join('~', 'Datasets'))
-    model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN'))
+    currentDir = os.path.dirname(os.path.realpath(__file__))
+    dataset_root = os.environ["DATASETDIR"]
+    model_root = os.path.join(currentDir, 'checkpoints')
+    finetune_flag = False
+    lr_adjust_flag = 'simple'
+    src_only_flag = False

     # params for datasets and data loader
     batch_size = 64
@@ -53,6 +57,8 @@ class Config(object):
     # params for optimizing models
     lr = 2e-4
+    momentum = 0.0
+    weight_decay = 0.0

 params = Config()
@@ -70,7 +76,7 @@ tgt_data_loader = get_data_loader(params.tgt_dataset, params.dataset_root, param
 tgt_data_loader_eval = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size, train=False)

 # load dann model
-dann = init_model(net=MNISTmodel_plain(), restore=None)
+dann = init_model(net=MNISTmodel(), restore=None)

 # train dann model
 print("Training dann model")
