|
@ -217,6 +217,7 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ |
|
|
if ((epoch + 1) % params.eval_step == 0): |
|
|
if ((epoch + 1) % params.eval_step == 0): |
|
|
tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, loggi, flag='target') |
|
|
tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, loggi, flag='target') |
|
|
src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, loggi, flag='source') |
|
|
src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, loggi, flag='source') |
|
|
|
|
|
loggi.info('\n') |
|
|
if tgt_acc > bestAcc: |
|
|
if tgt_acc > bestAcc: |
|
|
bestAcc = tgt_acc |
|
|
bestAcc = tgt_acc |
|
|
bestAccS = src_acc |
|
|
bestAccS = src_acc |
|
@ -232,7 +233,7 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ |
|
|
|
|
|
|
|
|
# save final model |
|
|
# save final model |
|
|
save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt") |
|
|
save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt") |
|
|
loggi.info('============ Summary ============= \n') |
|
|
|
|
|
|
|
|
loggi.info('\n============ Summary ============= \n') |
|
|
loggi.info('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS)) |
|
|
loggi.info('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS)) |
|
|
loggi.info('Accuracy of the %s dataset: %f' % (params.tgtc_dataset, bestAcc)) |
|
|
loggi.info('Accuracy of the %s dataset: %f' % (params.tgtc_dataset, bestAcc)) |
|
|
|
|
|
|
|
|