|
|
@ -18,12 +18,20 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ |
|
|
|
#################### |
|
|
|
|
|
|
|
# setup criterion and optimizer |
|
|
|
# parameter_list = [ |
|
|
|
# # {"params": model.feature.parameters(), "lr": 1e-5}, |
|
|
|
# # {"params": model.classifier.parameters(), "lr": 1e-4}, |
|
|
|
# # {"params": model.discriminator.parameters(), "lr": 1e-4} |
|
|
|
# # ] |
|
|
|
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) |
|
|
|
|
|
|
|
if params.src_dataset == 'mnist' or params.tgt_dataset == 'mnist': |
|
|
|
print("training mnist task") |
|
|
|
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) |
|
|
|
else: |
|
|
|
print("training office task") |
|
|
|
parameter_list = [ |
|
|
|
{"params": model.features.parameters(), "lr": 1e-5}, |
|
|
|
{"params": model.fc.parameters(), "lr": 1e-5}, |
|
|
|
{"params": model.bottleneck.parameters(), "lr": 1e-4}, |
|
|
|
{"params": model.classifier.parameters(), "lr": 1e-4}, |
|
|
|
{"params": model.discriminator.parameters(), "lr": 1e-4} |
|
|
|
] |
|
|
|
optimizer = optim.SGD(parameter_list) |
|
|
|
|
|
|
|
criterion = nn.CrossEntropyLoss() |
|
|
|
|
|
|
@ -44,7 +52,9 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ |
|
|
|
|
|
|
|
p = float(step + epoch * len_dataloader) / params.num_epochs / len_dataloader |
|
|
|
alpha = 2. / (1. + np.exp(-10 * p)) - 1 |
|
|
|
adjust_learning_rate(optimizer, p) |
|
|
|
if params.src_dataset == 'mnist' or params.tgt_dataset == 'mnist': |
|
|
|
print("training mnist task") |
|
|
|
adjust_learning_rate(optimizer, p) |
|
|
|
|
|
|
|
# prepare domain label |
|
|
|
size_src = len(images_src) |
|
|
@ -90,7 +100,7 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ |
|
|
|
# eval model on test set |
|
|
|
if ((epoch + 1) % params.eval_step == 0): |
|
|
|
print("eval on target domain") |
|
|
|
eval(model, tgt_data_loader_eval) |
|
|
|
eval(model, tgt_data_loader) |
|
|
|
print("eval on source domain") |
|
|
|
eval_src(model, src_data_loader) |
|
|
|
|
|
|
|