|
@ -144,10 +144,10 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ |
|
|
"lr": 0.0001 |
|
|
"lr": 0.0001 |
|
|
}, { |
|
|
}, { |
|
|
"params": model.classifier.parameters(), |
|
|
"params": model.classifier.parameters(), |
|
|
"lr": 0.0001 |
|
|
|
|
|
|
|
|
"lr": 0.001 |
|
|
}, { |
|
|
}, { |
|
|
"params": model.discriminator.parameters(), |
|
|
"params": model.discriminator.parameters(), |
|
|
"lr": 0.0001 |
|
|
|
|
|
|
|
|
"lr": 0.001 |
|
|
}] |
|
|
}] |
|
|
optimizer = optim.SGD(parameter_list, momentum=0.9, weight_decay=1e-4) |
|
|
optimizer = optim.SGD(parameter_list, momentum=0.9, weight_decay=1e-4) |
|
|
|
|
|
|
|
@ -248,6 +248,12 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ |
|
|
loggi.info('\n============ Summary ============= \n') |
|
|
loggi.info('\n============ Summary ============= \n') |
|
|
loggi.info('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS)) |
|
|
loggi.info('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS)) |
|
|
loggi.info('Accuracy of the %s dataset: %f' % (params.tgt_dataset, bestAcc)) |
|
|
loggi.info('Accuracy of the %s dataset: %f' % (params.tgt_dataset, bestAcc)) |
|
|
|
|
|
loggi.info('Saving the final weights') |
|
|
|
|
|
# save final model |
|
|
|
|
|
save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt") |
|
|
|
|
|
loggi.info('\n============ Summary ============= \n') |
|
|
|
|
|
loggi.info('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS)) |
|
|
|
|
|
loggi.info('Accuracy of the %s dataset: %f' % (params.tgt_dataset, bestAcc)) |
|
|
|
|
|
|
|
|
return model |
|
|
return model |
|
|
|
|
|
|
|
|