From 65a3a30abe238cb854fb4330cbe759097f611e6f Mon Sep 17 00:00:00 2001 From: wogong Date: Thu, 29 Aug 2019 21:28:45 +0800 Subject: [PATCH] fix variables name --- core/dann.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/dann.py b/core/dann.py index 8734446..6b56d58 100644 --- a/core/dann.py +++ b/core/dann.py @@ -108,9 +108,9 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ # eval model if ((epoch + 1) % params.eval_step == 0): print("eval on target domain") - src_test_loss, src_acc, src_acc_domain = test(model, tgt_data_loader, device, flag='target') + tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader, device, flag='target') print("eval on source domain") - tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, src_data_loader, device, flag='source') + src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source') logger.add_scalar('src_test_loss', src_test_loss, global_step) logger.add_scalar('src_acc', src_acc, global_step) logger.add_scalar('src_acc_domain', src_acc_domain, global_step)