
Fixes for initial training

Branch: master
Author: Fazil Altinel, 4 years ago
Commit: eb194bff1c
  1. core/train.py (55 lines changed)
  2. datasets/mnist.py (7 lines changed)
  3. experiments/__init__.py (0 lines changed)
  4. experiments/mnist_mnistm.py (16 lines changed)

core/train.py (55 lines changed)

@@ -10,7 +10,7 @@ from utils.utils import save_model
 import torch.backends.cudnn as cudnn
 cudnn.benchmark = True

-def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger):
+def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger=None):
     """Train dann."""
     ####################
     # 1. setup network #
@@ -60,7 +60,8 @@ def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_e
                 lr = adjust_learning_rate(optimizer, p)
             else:
                 lr = adjust_learning_rate_office(optimizer, p)
-            logger.add_scalar('lr', lr, global_step)
+            if logger is not None:
+                logger.add_scalar('lr', lr, global_step)

             # prepare domain label
             size_src = len(images_src)
@@ -86,7 +87,8 @@ def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_e
             global_step += 1

             # print step info
-            logger.add_scalar('loss', loss.item(), global_step)
+            if logger is not None:
+                logger.add_scalar('loss', loss.item(), global_step)

             if ((step + 1) % params.log_step == 0):
                 print(
@@ -96,8 +98,9 @@ def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_e
         if ((epoch + 1) % params.eval_step == 0):
             src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source')
             tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, flag='target')
-            logger.add_scalar('src_test_loss', src_test_loss, global_step)
-            logger.add_scalar('src_acc', src_acc, global_step)
+            if logger is not None:
+                logger.add_scalar('src_test_loss', src_test_loss, global_step)
+                logger.add_scalar('src_acc', src_acc, global_step)

         # save model parameters
@@ -110,7 +113,7 @@ def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_e
     return model


-def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger):
+def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger=None):
     """Train dann."""
     ####################
     # 1. setup network #
@@ -144,6 +147,7 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
     # 2. train network #
     ####################
     global_step = 0
+    bestAcc = 0.0
     for epoch in range(params.num_epochs):
         # set train state for Dropout and BN layers
         model.train()
@@ -160,7 +164,8 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
                 lr = adjust_learning_rate(optimizer, p)
             else:
                 lr = adjust_learning_rate_office(optimizer, p)
-            logger.add_scalar('lr', lr, global_step)
+            if logger is not None:
+                logger.add_scalar('lr', lr, global_step)

             # prepare domain label
             size_src = len(images_src)
@@ -196,10 +201,11 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
             global_step += 1

             # print step info
-            logger.add_scalar('src_loss_class', src_loss_class.item(), global_step)
-            logger.add_scalar('src_loss_domain', src_loss_domain.item(), global_step)
-            logger.add_scalar('tgt_loss_domain', tgt_loss_domain.item(), global_step)
-            logger.add_scalar('loss', loss.item(), global_step)
+            if logger is not None:
+                logger.add_scalar('src_loss_class', src_loss_class.item(), global_step)
+                logger.add_scalar('src_loss_domain', src_loss_domain.item(), global_step)
+                logger.add_scalar('tgt_loss_domain', tgt_loss_domain.item(), global_step)
+                logger.add_scalar('loss', loss.item(), global_step)
             if ((step + 1) % params.log_step == 0):
                 print(
@@ -211,21 +217,24 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
         if ((epoch + 1) % params.eval_step == 0):
             tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, flag='target')
             src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source')
-            logger.add_scalar('src_test_loss', src_test_loss, global_step)
-            logger.add_scalar('src_acc', src_acc, global_step)
-            logger.add_scalar('src_acc_domain', src_acc_domain, global_step)
-            logger.add_scalar('tgt_test_loss', tgt_test_loss, global_step)
-            logger.add_scalar('tgt_acc', tgt_acc, global_step)
-            logger.add_scalar('tgt_acc_domain', tgt_acc_domain, global_step)

         # save model parameters
         if ((epoch + 1) % params.save_step == 0):
             save_model(model, params.model_root,
                        params.src_dataset + '-' + params.tgt_dataset + "-dann-{}.pt".format(epoch + 1))

+        if tgt_acc > bestAcc:
+            bestAcc = tgt_acc
+            bestAccS = src_acc
+            save_model(model, params.model_root,
+                       params.src_dataset + '-' + params.tgt_dataset + "-dann-best.pt")
+        if logger is not None:
+            logger.add_scalar('src_test_loss', src_test_loss, global_step)
+            logger.add_scalar('src_acc', src_acc, global_step)
+            logger.add_scalar('src_acc_domain', src_acc_domain, global_step)
+            logger.add_scalar('tgt_test_loss', tgt_test_loss, global_step)
+            logger.add_scalar('tgt_acc', tgt_acc, global_step)
+            logger.add_scalar('tgt_acc_domain', tgt_acc_domain, global_step)

     # save final model
     save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")

     print('============ Summary ============= \n')
+    print('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS))
+    print('Accuracy of the %s dataset: %f' % (params.tgt_dataset, bestAcc))

     return model
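
Taken together, the core/train.py changes do two things: every TensorBoard-style logging call is guarded so the trainers also run with logger=None, and the best target accuracy seen at evaluation time is checkpointed separately from the periodic and final saves. A minimal, self-contained sketch of that combined pattern (the Adam optimizer, cross-entropy loss, eval_fn callback, and torch.save call below are illustrative stand-ins for the repo's optimizer setup, test(), and save_model helpers, not its actual code):

import torch
import torch.nn as nn

def train_with_optional_logging(model, loader, eval_fn, num_epochs,
                                ckpt_path="best.pt", logger=None):
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
    criterion = nn.CrossEntropyLoss()
    best_acc = 0.0
    global_step = 0
    for epoch in range(num_epochs):
        model.train()
        for images, labels in loader:
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            global_step += 1
            if logger is not None:       # every logging call is guarded
                logger.add_scalar('loss', loss.item(), global_step)
        acc = eval_fn(model)             # evaluate once per epoch
        if acc > best_acc:               # keep the best checkpoint separately
            best_acc = acc
            torch.save(model.state_dict(), ckpt_path)
        if logger is not None:
            logger.add_scalar('acc', acc, global_step)
    return best_acc

A torch.utils.tensorboard.SummaryWriter can be passed as logger; "if logger is not None:" is the idiomatic form of the commit's "not logger == None" check and behaves identically.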

datasets/mnist.py (7 lines changed)

@@ -8,11 +8,11 @@ import os
 def get_mnist(dataset_root, batch_size, train):
     """Get MNIST datasets loader."""
     # image pre-processing
-    pre_process = transforms.Compose([transforms.Resize(32),  # different img size settings for mnist(28) and svhn(32).
+    pre_process = transforms.Compose([transforms.Resize(28),  # different img size settings for mnist(28) and svhn(32).
                                       transforms.ToTensor(),
                                       transforms.Normalize(
-                                          mean=(0.5, 0.5, 0.5),
-                                          std=(0.5, 0.5, 0.5)
+                                          mean=(0.5,),
+                                          std=(0.5,)
                                       )])

     # datasets and data loader
@@ -21,7 +21,6 @@ def get_mnist(dataset_root, batch_size, train):
                                    transform=pre_process,
                                    download=False)
-
     mnist_data_loader = torch.utils.data.DataLoader(
         dataset=mnist_dataset,
         batch_size=batch_size,
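
The normalization change matters because MNIST tensors have a single channel: Normalize expects one mean and std entry per channel, so the one-channel form needs one-element tuples (a bare (0.5) is just a float, not a tuple), and Resize(28) keeps MNIST at its native size. A quick sanity check of the resulting pipeline (the root path, download=True, and loader options here are illustrative, not the repo's settings):

import torch
from torchvision import datasets, transforms

pre_process = transforms.Compose([
    transforms.Resize(28),                          # MNIST's native size; SVHN setups use 32
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),  # one channel, one-element tuples
])

mnist = datasets.MNIST(root='./data', train=True, download=True,
                       transform=pre_process)
loader = torch.utils.data.DataLoader(mnist, batch_size=64, shuffle=True)

images, labels = next(iter(loader))
print(images.shape)  # torch.Size([64, 1, 28, 28]), values scaled to [-1, 1]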

experiments/__init__.py (0 lines changed)
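The new experiments/__init__.py is empty; presumably it marks experiments as a Python package so the experiment scripts resolve their imports when launched from the repository root.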

experiments/mnist_mnistm.py (16 lines changed)

@@ -1,8 +1,7 @@
 import os
 import sys
-sys.path.append(os.path.abspath('.'))
-import torch
+sys.path.append('../')
 from models.model import MNISTmodel, MNISTmodel_plain
 from core.train import train_dann
 from utils.utils import get_data_loader, init_model, init_random_seed
@@ -10,8 +9,13 @@ from utils.utils import get_data_loader, init_model, init_random_seed

 class Config(object):
     # params for path
-    dataset_root = os.path.expanduser(os.path.join('~', 'Datasets'))
-    model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN'))
+    currentDir = os.path.dirname(os.path.realpath(__file__))
+    dataset_root = os.environ["DATASETDIR"]
+    model_root = os.path.join(currentDir, 'checkpoints')
+    finetune_flag = False
+    lr_adjust_flag = 'simple'
+    src_only_flag = False

     # params for datasets and data loader
     batch_size = 64
@@ -53,6 +57,8 @@ class Config(object):
     # params for optimizing models
     lr = 2e-4
+    momentum = 0.0
+    weight_decay = 0.0


 params = Config()
@@ -70,7 +76,7 @@ tgt_data_loader = get_data_loader(params.tgt_dataset, params.dataset_root, param
 tgt_data_loader_eval = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size, train=False)

 # load dann model
-dann = init_model(net=MNISTmodel_plain(), restore=None)
+dann = init_model(net=MNISTmodel(), restore=None)

 # train dann model
 print("Training dann model")
