Browse Source

fix variables name

master
wogong 5 years ago
parent
commit
65a3a30abe
  1. 4
      core/dann.py

4
core/dann.py

@ -108,9 +108,9 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
# eval model # eval model
if ((epoch + 1) % params.eval_step == 0): if ((epoch + 1) % params.eval_step == 0):
print("eval on target domain") print("eval on target domain")
src_test_loss, src_acc, src_acc_domain = test(model, tgt_data_loader, device, flag='target')
tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader, device, flag='target')
print("eval on source domain") print("eval on source domain")
tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, src_data_loader, device, flag='source')
src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source')
logger.add_scalar('src_test_loss', src_test_loss, global_step) logger.add_scalar('src_test_loss', src_test_loss, global_step)
logger.add_scalar('src_acc', src_acc, global_step) logger.add_scalar('src_acc', src_acc, global_step)
logger.add_scalar('src_acc_domain', src_acc_domain, global_step) logger.add_scalar('src_acc_domain', src_acc_domain, global_step)

Loading…
Cancel
Save