|
@ -108,9 +108,9 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ |
|
|
# eval model |
|
|
# eval model |
|
|
if ((epoch + 1) % params.eval_step == 0): |
|
|
if ((epoch + 1) % params.eval_step == 0): |
|
|
print("eval on target domain") |
|
|
print("eval on target domain") |
|
|
src_test_loss, src_acc, src_acc_domain = test(model, tgt_data_loader, device, flag='target') |
|
|
|
|
|
|
|
|
tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader, device, flag='target') |
|
|
print("eval on source domain") |
|
|
print("eval on source domain") |
|
|
tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, src_data_loader, device, flag='source') |
|
|
|
|
|
|
|
|
src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source') |
|
|
logger.add_scalar('src_test_loss', src_test_loss, global_step) |
|
|
logger.add_scalar('src_test_loss', src_test_loss, global_step) |
|
|
logger.add_scalar('src_acc', src_acc, global_step) |
|
|
logger.add_scalar('src_acc', src_acc, global_step) |
|
|
logger.add_scalar('src_acc_domain', src_acc_domain, global_step) |
|
|
logger.add_scalar('src_acc_domain', src_acc_domain, global_step) |
|
|