diff --git a/core/dann.py b/core/dann.py index 8734446..6b56d58 100644 --- a/core/dann.py +++ b/core/dann.py @@ -108,9 +108,9 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ # eval model if ((epoch + 1) % params.eval_step == 0): print("eval on target domain") - src_test_loss, src_acc, src_acc_domain = test(model, tgt_data_loader, device, flag='target') + tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader, device, flag='target') print("eval on source domain") - tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, src_data_loader, device, flag='source') + src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source') logger.add_scalar('src_test_loss', src_test_loss, global_step) logger.add_scalar('src_acc', src_acc, global_step) logger.add_scalar('src_acc_domain', src_acc_domain, global_step)