|
@ -10,7 +10,7 @@ from utils.utils import save_model |
|
|
import torch.backends.cudnn as cudnn |
|
|
import torch.backends.cudnn as cudnn |
|
|
cudnn.benchmark = True |
|
|
cudnn.benchmark = True |
|
|
|
|
|
|
|
|
def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger): |
|
|
|
|
|
|
|
|
def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger=None): |
|
|
"""Train dann.""" |
|
|
"""Train dann.""" |
|
|
#################### |
|
|
#################### |
|
|
# 1. setup network # |
|
|
# 1. setup network # |
|
@ -60,6 +60,7 @@ def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_e |
|
|
lr = adjust_learning_rate(optimizer, p) |
|
|
lr = adjust_learning_rate(optimizer, p) |
|
|
else: |
|
|
else: |
|
|
lr = adjust_learning_rate_office(optimizer, p) |
|
|
lr = adjust_learning_rate_office(optimizer, p) |
|
|
|
|
|
if not logger == None: |
|
|
logger.add_scalar('lr', lr, global_step) |
|
|
logger.add_scalar('lr', lr, global_step) |
|
|
|
|
|
|
|
|
# prepare domain label |
|
|
# prepare domain label |
|
@ -86,6 +87,7 @@ def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_e |
|
|
global_step += 1 |
|
|
global_step += 1 |
|
|
|
|
|
|
|
|
# print step info |
|
|
# print step info |
|
|
|
|
|
if not logger == None: |
|
|
logger.add_scalar('loss', loss.item(), global_step) |
|
|
logger.add_scalar('loss', loss.item(), global_step) |
|
|
|
|
|
|
|
|
if ((step + 1) % params.log_step == 0): |
|
|
if ((step + 1) % params.log_step == 0): |
|
@ -96,6 +98,7 @@ def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_e |
|
|
if ((epoch + 1) % params.eval_step == 0): |
|
|
if ((epoch + 1) % params.eval_step == 0): |
|
|
src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source') |
|
|
src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source') |
|
|
tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, flag='target') |
|
|
tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, flag='target') |
|
|
|
|
|
if not logger == None: |
|
|
logger.add_scalar('src_test_loss', src_test_loss, global_step) |
|
|
logger.add_scalar('src_test_loss', src_test_loss, global_step) |
|
|
logger.add_scalar('src_acc', src_acc, global_step) |
|
|
logger.add_scalar('src_acc', src_acc, global_step) |
|
|
|
|
|
|
|
@ -110,7 +113,7 @@ def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_e |
|
|
|
|
|
|
|
|
return model |
|
|
return model |
|
|
|
|
|
|
|
|
def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger): |
|
|
|
|
|
|
|
|
def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger=None): |
|
|
"""Train dann.""" |
|
|
"""Train dann.""" |
|
|
#################### |
|
|
#################### |
|
|
# 1. setup network # |
|
|
# 1. setup network # |
|
@ -144,6 +147,7 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ |
|
|
# 2. train network # |
|
|
# 2. train network # |
|
|
#################### |
|
|
#################### |
|
|
global_step = 0 |
|
|
global_step = 0 |
|
|
|
|
|
bestAcc = 0.0 |
|
|
for epoch in range(params.num_epochs): |
|
|
for epoch in range(params.num_epochs): |
|
|
# set train state for Dropout and BN layers |
|
|
# set train state for Dropout and BN layers |
|
|
model.train() |
|
|
model.train() |
|
@ -160,6 +164,7 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ |
|
|
lr = adjust_learning_rate(optimizer, p) |
|
|
lr = adjust_learning_rate(optimizer, p) |
|
|
else: |
|
|
else: |
|
|
lr = adjust_learning_rate_office(optimizer, p) |
|
|
lr = adjust_learning_rate_office(optimizer, p) |
|
|
|
|
|
if not logger == None: |
|
|
logger.add_scalar('lr', lr, global_step) |
|
|
logger.add_scalar('lr', lr, global_step) |
|
|
|
|
|
|
|
|
# prepare domain label |
|
|
# prepare domain label |
|
@ -196,6 +201,7 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ |
|
|
global_step += 1 |
|
|
global_step += 1 |
|
|
|
|
|
|
|
|
# print step info |
|
|
# print step info |
|
|
|
|
|
if not logger == None: |
|
|
logger.add_scalar('src_loss_class', src_loss_class.item(), global_step) |
|
|
logger.add_scalar('src_loss_class', src_loss_class.item(), global_step) |
|
|
logger.add_scalar('src_loss_domain', src_loss_domain.item(), global_step) |
|
|
logger.add_scalar('src_loss_domain', src_loss_domain.item(), global_step) |
|
|
logger.add_scalar('tgt_loss_domain', tgt_loss_domain.item(), global_step) |
|
|
logger.add_scalar('tgt_loss_domain', tgt_loss_domain.item(), global_step) |
|
@ -211,6 +217,12 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ |
|
|
if ((epoch + 1) % params.eval_step == 0): |
|
|
if ((epoch + 1) % params.eval_step == 0): |
|
|
tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, flag='target') |
|
|
tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, flag='target') |
|
|
src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source') |
|
|
src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source') |
|
|
|
|
|
if tgt_acc > bestAcc: |
|
|
|
|
|
bestAcc = tgt_acc |
|
|
|
|
|
bestAccS = src_acc |
|
|
|
|
|
save_model(model, params.model_root, |
|
|
|
|
|
params.src_dataset + '-' + params.tgt_dataset + "-dann-best.pt") |
|
|
|
|
|
if not logger == None: |
|
|
logger.add_scalar('src_test_loss', src_test_loss, global_step) |
|
|
logger.add_scalar('src_test_loss', src_test_loss, global_step) |
|
|
logger.add_scalar('src_acc', src_acc, global_step) |
|
|
logger.add_scalar('src_acc', src_acc, global_step) |
|
|
logger.add_scalar('src_acc_domain', src_acc_domain, global_step) |
|
|
logger.add_scalar('src_acc_domain', src_acc_domain, global_step) |
|
@ -218,14 +230,11 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ |
|
|
logger.add_scalar('tgt_acc', tgt_acc, global_step) |
|
|
logger.add_scalar('tgt_acc', tgt_acc, global_step) |
|
|
logger.add_scalar('tgt_acc_domain', tgt_acc_domain, global_step) |
|
|
logger.add_scalar('tgt_acc_domain', tgt_acc_domain, global_step) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# save model parameters |
|
|
|
|
|
if ((epoch + 1) % params.save_step == 0): |
|
|
|
|
|
save_model(model, params.model_root, |
|
|
|
|
|
params.src_dataset + '-' + params.tgt_dataset + "-dann-{}.pt".format(epoch + 1)) |
|
|
|
|
|
|
|
|
|
|
|
# save final model |
|
|
# save final model |
|
|
save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt") |
|
|
save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt") |
|
|
|
|
|
print('============ Summary ============= \n') |
|
|
|
|
|
print('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS)) |
|
|
|
|
|
print('Accuracy of the %s dataset: %f' % (params.tgtc_dataset, bestAcc)) |
|
|
|
|
|
|
|
|
return model |
|
|
return model |
|
|
|
|
|
|
|
|