diff --git a/core/train.py b/core/train.py index 1659a1c..9c88076 100644 --- a/core/train.py +++ b/core/train.py @@ -144,10 +144,10 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ "lr": 0.0001 }, { "params": model.classifier.parameters(), - "lr": 0.0001 + "lr": 0.001 }, { "params": model.discriminator.parameters(), - "lr": 0.0001 + "lr": 0.001 }] optimizer = optim.SGD(parameter_list, momentum=0.9, weight_decay=1e-4) @@ -248,6 +248,12 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ loggi.info('\n============ Summary ============= \n') loggi.info('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS)) loggi.info('Accuracy of the %s dataset: %f' % (params.tgt_dataset, bestAcc)) + loggi.info('Saving the final weights') + # save final model + save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt") + loggi.info('\n============ Summary ============= \n') + loggi.info('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS)) + loggi.info('Accuracy of the %s dataset: %f' % (params.tgt_dataset, bestAcc)) return model diff --git a/experiments/office.py b/experiments/office.py index ca7dd9f..6596c8d 100644 --- a/experiments/office.py +++ b/experiments/office.py @@ -28,12 +28,16 @@ class Config(object): batch_size = 32 # params for source dataset - src_dataset = "amazon31" + # src_dataset = "amazon31" + # src_dataset = "dslr31" + src_dataset = "webcam31" src_model_trained = True src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt') # params for target dataset - tgt_dataset = "webcam31" + # tgt_dataset = "webcam31" + # tgt_dataset = "dslr31" + tgt_dataset = "amazon31" tgt_model_trained = True dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')