@@ -1,17 +1,19 @@
 """Train dann."""
 
-import numpy as np
-
 import torch
 import torch.nn as nn
 import torch.optim as optim
 
-from utils import make_variable, save_model
-from core.test import eval, eval_src
+import numpy as np
+from core.test import eval
+from utils.utils import save_model
 
 import torch.backends.cudnn as cudnn
 cudnn.benchmark = True
 
 
-def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
+def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device):
     """Train dann."""
 
     ####################
     # 1. setup network #
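
Note on the signature change above: make_variable wrapped tensors in the
pre-0.4 torch.autograd.Variable API, which PyTorch 0.4 merged into Tensor;
passing an explicit torch.device is the post-0.4 idiom. A minimal caller-side
sketch, assuming the usual CUDA-detection pattern (the detection line itself
is not part of this diff):

    import torch

    # Pick the GPU when available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Call with the updated signature introduced in this diff.
    model = train_dann(model, params, src_data_loader, tgt_data_loader,
                       tgt_data_loader_eval, device)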
|
|
@@ -24,20 +26,23 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
         optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
     else:
         print("training office task")
-        parameter_list = [
-            {"params": model.features.parameters(), "lr": 0.001},
-            {"params": model.fc.parameters(), "lr": 0.001},
-            {"params": model.bottleneck.parameters()},
-            {"params": model.classifier.parameters()},
-            {"params": model.discriminator.parameters()}
-        ]
+        parameter_list = [{
+            "params": model.features.parameters(),
+            "lr": 0.001
+        }, {
+            "params": model.fc.parameters(),
+            "lr": 0.001
+        }, {
+            "params": model.bottleneck.parameters()
+        }, {
+            "params": model.classifier.parameters()
+        }, {
+            "params": model.discriminator.parameters()
+        }]
         optimizer = optim.SGD(parameter_list, lr=0.01, momentum=0.9)
 
     criterion = nn.CrossEntropyLoss()
 
     for p in model.parameters():
         p.requires_grad = True
 
     ####################
     # 2. train network #
     ####################
|
|
@@ -50,9 +55,9 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
         data_zip = enumerate(zip(src_data_loader, tgt_data_loader))
         for step, ((images_src, class_src), (images_tgt, _)) in data_zip:
 
-            p = float(step + epoch * len_dataloader) / params.num_epochs / len_dataloader
+            p = float(step + epoch * len_dataloader) / \
+                params.num_epochs / len_dataloader
             alpha = 2. / (1. + np.exp(-10 * p)) - 1
-            alpha = 2*alpha
 
             if params.src_dataset == 'mnist' or params.tgt_dataset == 'mnist':
                 adjust_learning_rate(optimizer, p)
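
For reference, the quantity computed above is the gradient-reversal
coefficient from the DANN paper (Ganin & Lempitsky, 2015), where p is the
training progress running from 0 to 1; dropping the extra alpha = 2*alpha
line restores the paper's schedule:

    \alpha(p) = \frac{2}{1 + e^{-10 p}} - 1,
    \qquad
    p = \frac{\mathrm{step} + \mathrm{epoch} \cdot \mathrm{len\_dataloader}}
             {\mathrm{num\_epochs} \cdot \mathrm{len\_dataloader}}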
|
|
@@ -62,13 +67,13 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
             # prepare domain label
             size_src = len(images_src)
             size_tgt = len(images_tgt)
-            label_src = make_variable(torch.zeros(size_src).long())  # source 0
-            label_tgt = make_variable(torch.ones(size_tgt).long())  # target 1
+            label_src = torch.zeros(size_src).long().to(device)  # source 0
+            label_tgt = torch.ones(size_tgt).long().to(device)  # target 1
 
             # make images variable
-            class_src = make_variable(class_src)
-            images_src = make_variable(images_src)
-            images_tgt = make_variable(images_tgt)
+            class_src = class_src.to(device)
+            images_src = images_src.to(device)
+            images_tgt = images_tgt.to(device)
 
             # zero gradients for optimizer
             optimizer.zero_grad()
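
The domain labels above (0 = source, 1 = target) feed the discriminator's
cross-entropy loss. A minimal sketch of the forward/backward step that falls
between this hunk and the next, assuming the model returns (class logits,
domain logits) and takes alpha for its gradient-reversal layer; the loss
names match the variables printed in the next hunk, but the forward signature
is an assumption, not taken from this diff:

    # Assumed forward API: model(images, alpha) -> (class_output, domain_output)
    src_class_out, src_domain_out = model(images_src, alpha)
    _, tgt_domain_out = model(images_tgt, alpha)

    src_loss_class = criterion(src_class_out, class_src)    # source label loss
    src_loss_domain = criterion(src_domain_out, label_src)  # discriminator: source = 0
    tgt_loss_domain = criterion(tgt_domain_out, label_tgt)  # discriminator: target = 1
    loss = src_loss_class + src_loss_domain + tgt_loss_domain

    loss.backward()
    optimizer.step()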
|
|
@@ -90,46 +95,44 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
 
             # print step info
             if ((step + 1) % params.log_step == 0):
-                print("Epoch [{:4d}/{}] Step [{:2d}/{}]: src_loss_class={:.6f}, src_loss_domain={:.6f}, tgt_loss_domain={:.6f}, loss={:.6f}"
-                      .format(epoch + 1,
-                              params.num_epochs,
-                              step + 1,
-                              len_dataloader,
-                              src_loss_class.data[0],
-                              src_loss_domain.data[0],
-                              tgt_loss_domain.data[0],
-                              loss.data[0]))
+                print(
+                    "Epoch [{:4d}/{}] Step [{:2d}/{}]: src_loss_class={:.6f}, src_loss_domain={:.6f}, tgt_loss_domain={:.6f}, loss={:.6f}"
+                    .format(epoch + 1, params.num_epochs, step + 1, len_dataloader, src_loss_class.data.item(),
+                            src_loss_domain.data.item(), tgt_loss_domain.data.item(), loss.data.item()))
 
-        # eval model on test set
+        # eval model
         if ((epoch + 1) % params.eval_step == 0):
             print("eval on target domain")
-            eval(model, tgt_data_loader)
+            eval(model, tgt_data_loader, device, flag='target')
             print("eval on source domain")
-            eval_src(model, src_data_loader)
+            eval(model, src_data_loader, device, flag='source')
 
         # save model parameters
         if ((epoch + 1) % params.save_step == 0):
-            save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-{}.pt".format(epoch + 1))
+            save_model(model, params.model_root,
+                       params.src_dataset + '-' + params.tgt_dataset + "-dann-{}.pt".format(epoch + 1))
 
     # save final model
     save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")
 
     return model
 
 
 def adjust_learning_rate(optimizer, p):
     lr_0 = 0.01
     alpha = 10
     beta = 0.75
-    lr = lr_0 / (1 + alpha*p) ** beta
+    lr = lr_0 / (1 + alpha * p)**beta
     for param_group in optimizer.param_groups:
         param_group['lr'] = lr
 
 
 def adjust_learning_rate_office(optimizer, p):
     lr_0 = 0.001
     alpha = 10
     beta = 0.75
-    lr = lr_0 / (1 + alpha*p) ** beta
+    lr = lr_0 / (1 + alpha * p)**beta
     for param_group in optimizer.param_groups[:2]:
         param_group['lr'] = lr
     for param_group in optimizer.param_groups[2:]:
-        param_group['lr'] = 10*lr
+        param_group['lr'] = 10 * lr
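
Both helpers above implement the annealed learning rate from the DANN paper,
with progress p running from 0 to 1; the office variant additionally gives
the newly initialized parameter groups (bottleneck, classifier, and
discriminator, i.e. indices 2 and up of parameter_list) ten times the base
rate:

    \eta_p = \frac{\eta_0}{(1 + \alpha p)^{\beta}},
    \qquad \alpha = 10, \quad \beta = 0.75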
|
|
|