|
|
@ -139,13 +139,14 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ |
|
|
|
}, { |
|
|
|
"params": model.discriminator.parameters() |
|
|
|
}] |
|
|
|
optimizer = optim.SGD(parameter_list, lr=0.01, momentum=0.9) |
|
|
|
optimizer = optim.SGD(parameter_list, lr=0.001, momentum=0.9) |
|
|
|
|
|
|
|
criterion = nn.CrossEntropyLoss() |
|
|
|
|
|
|
|
#################### |
|
|
|
# 2. train network # |
|
|
|
#################### |
|
|
|
try: |
|
|
|
global_step = 0 |
|
|
|
bestAcc = 0.0 |
|
|
|
for epoch in range(params.num_epochs): |
|
|
@ -230,7 +231,8 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ |
|
|
|
logger.add_scalar('tgt_test_loss', tgt_test_loss, global_step) |
|
|
|
logger.add_scalar('tgt_acc', tgt_acc, global_step) |
|
|
|
logger.add_scalar('tgt_acc_domain', tgt_acc_domain, global_step) |
|
|
|
|
|
|
|
except KeyboardInterrupt as ke: |
|
|
|
loggi.info('Saving the final weights before quitting') |
|
|
|
# save final model |
|
|
|
save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt") |
|
|
|
loggi.info('\n============ Summary ============= \n') |
|
|
|