Browse Source

Logging improvements

master
Fazil Altinel 4 years ago
parent
commit
95d78532ce
  1. 2
      core/test.py
  2. 3
      core/train.py
  3. 203
      dann.ipynb

2
core/test.py

@ -39,6 +39,6 @@ def test(model, data_loader, device, loggi, flag):
acc = acc_ / n_total acc = acc_ / n_total
acc_domain = acc_domain_ / n_total acc_domain = acc_domain_ / n_total
loggi.info("Avg Loss = {:.6f}, Avg Accuracy = {:.2%}, {}/{}, Avg Domain Accuracy = {:2%}".format(loss, acc, acc_, n_total, acc_domain))
loggi.info("{}: Avg Loss = {:.6f}, Avg Accuracy = {:.2%}, {}/{}, Avg Domain Accuracy = {:2%}".format(flag, loss, acc, acc_, n_total, acc_domain))
return loss, acc, acc_domain return loss, acc, acc_domain

3
core/train.py

@ -217,6 +217,7 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
if ((epoch + 1) % params.eval_step == 0): if ((epoch + 1) % params.eval_step == 0):
tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, loggi, flag='target') tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, loggi, flag='target')
src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, loggi, flag='source') src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, loggi, flag='source')
loggi.info('\n')
if tgt_acc > bestAcc: if tgt_acc > bestAcc:
bestAcc = tgt_acc bestAcc = tgt_acc
bestAccS = src_acc bestAccS = src_acc
@ -232,7 +233,7 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
# save final model # save final model
save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt") save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")
loggi.info('============ Summary ============= \n')
loggi.info('\n============ Summary ============= \n')
loggi.info('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS)) loggi.info('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS))
loggi.info('Accuracy of the %s dataset: %f' % (params.tgtc_dataset, bestAcc)) loggi.info('Accuracy of the %s dataset: %f' % (params.tgtc_dataset, bestAcc))

203
dann.ipynb

File diff suppressed because one or more lines are too long
Loading…
Cancel
Save