Browse Source

New parameters for office dataset training

master
Fazil Altinel 4 years ago
parent
commit
2b583df2d7
  1. 10
      core/train.py
  2. 8
      experiments/office.py

10
core/train.py

@ -144,10 +144,10 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
"lr": 0.0001 "lr": 0.0001
}, { }, {
"params": model.classifier.parameters(), "params": model.classifier.parameters(),
"lr": 0.0001
"lr": 0.001
}, { }, {
"params": model.discriminator.parameters(), "params": model.discriminator.parameters(),
"lr": 0.0001
"lr": 0.001
}] }]
optimizer = optim.SGD(parameter_list, momentum=0.9, weight_decay=1e-4) optimizer = optim.SGD(parameter_list, momentum=0.9, weight_decay=1e-4)
@ -248,6 +248,12 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
loggi.info('\n============ Summary ============= \n') loggi.info('\n============ Summary ============= \n')
loggi.info('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS)) loggi.info('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS))
loggi.info('Accuracy of the %s dataset: %f' % (params.tgt_dataset, bestAcc)) loggi.info('Accuracy of the %s dataset: %f' % (params.tgt_dataset, bestAcc))
loggi.info('Saving the final weights')
# save final model
save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")
loggi.info('\n============ Summary ============= \n')
loggi.info('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS))
loggi.info('Accuracy of the %s dataset: %f' % (params.tgt_dataset, bestAcc))
return model return model

8
experiments/office.py

@ -28,12 +28,16 @@ class Config(object):
batch_size = 32 batch_size = 32
# params for source dataset # params for source dataset
src_dataset = "amazon31"
# src_dataset = "amazon31"
# src_dataset = "dslr31"
src_dataset = "webcam31"
src_model_trained = True src_model_trained = True
src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt') src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt')
# params for target dataset # params for target dataset
tgt_dataset = "webcam31"
# tgt_dataset = "webcam31"
# tgt_dataset = "dslr31"
tgt_dataset = "amazon31"
tgt_model_trained = True tgt_model_trained = True
dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt') dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')

Loading…
Cancel
Save