A PyTorch implementation of the paper Unsupervised Domain Adaptation by Backpropagation (DANN).
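
The training code in this file forwards each batch through the model together with an `alpha` coefficient; in the DANN architecture this coefficient is normally consumed by a gradient reversal layer between the feature extractor and the domain discriminator. The model definition is not part of this file, so the snippet below is only a minimal sketch of such a layer written against the modern PyTorch `autograd.Function` API; the class name and usage are illustrative, not taken from this repo.

import torch
from torch.autograd import Function

class GradReverse(Function):
    """Identity in the forward pass; scales gradients by -alpha in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse (and scale) the gradient flowing back into the feature extractor;
        # the second return value corresponds to the non-tensor alpha argument.
        return grad_output.neg() * ctx.alpha, None

# Illustrative use inside a model's forward pass:
#   reversed_features = GradReverse.apply(features, alpha)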

"""Train dann."""
import torch
import torch.nn as nn
import torch.optim as optim
from utils import make_variable, save_model
import numpy as np
from core.test import eval, eval_src
7 years ago
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
    """Train dann."""
    ####################
    # 1. setup network #
    ####################
    # setup criterion and optimizer
    if params.src_dataset == 'mnist' or params.tgt_dataset == 'mnist':
        print("training mnist task")
        optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    else:
        print("training office task")
        parameter_list = [
            {"params": model.features.parameters(), "lr": 0.001},
            {"params": model.fc.parameters(), "lr": 0.001},
            {"params": model.bottleneck.parameters()},
            {"params": model.classifier.parameters()},
            {"params": model.discriminator.parameters()},
        ]
        optimizer = optim.SGD(parameter_list, lr=0.01, momentum=0.9)

    criterion = nn.CrossEntropyLoss()

    for p in model.parameters():
        p.requires_grad = True

    ####################
    # 2. train network #
    ####################
    for epoch in range(params.num_epochs):
        # set train state for Dropout and BN layers
        model.train()

        # zip source and target data pairs
        len_dataloader = min(len(src_data_loader), len(tgt_data_loader))
        data_zip = enumerate(zip(src_data_loader, tgt_data_loader))
        for step, ((images_src, class_src), (images_tgt, _)) in data_zip:
            # training progress p in [0, 1] and the domain-loss weight alpha,
            # following the DANN schedule (doubled here)
            p = float(step + epoch * len_dataloader) / params.num_epochs / len_dataloader
            alpha = 2. / (1. + np.exp(-10 * p)) - 1
            alpha = 2 * alpha

            if params.src_dataset == 'mnist' or params.tgt_dataset == 'mnist':
                adjust_learning_rate(optimizer, p)
            else:
                adjust_learning_rate_office(optimizer, p)

            # prepare domain label
            size_src = len(images_src)
            size_tgt = len(images_tgt)
            label_src = make_variable(torch.zeros(size_src).long())  # source 0
            label_tgt = make_variable(torch.ones(size_tgt).long())   # target 1

            # make images variable
            class_src = make_variable(class_src)
            images_src = make_variable(images_src)
            images_tgt = make_variable(images_tgt)

            # zero gradients for optimizer
            optimizer.zero_grad()

            # train on source domain
            src_class_output, src_domain_output = model(input_data=images_src, alpha=alpha)
            src_loss_class = criterion(src_class_output, class_src)
            src_loss_domain = criterion(src_domain_output, label_src)

            # train on target domain
            _, tgt_domain_output = model(input_data=images_tgt, alpha=alpha)
            tgt_loss_domain = criterion(tgt_domain_output, label_tgt)

            loss = src_loss_class + src_loss_domain + tgt_loss_domain

            # optimize dann
            loss.backward()
            optimizer.step()

            # print step info
            if (step + 1) % params.log_step == 0:
                print("Epoch [{:4d}/{}] Step [{:2d}/{}]: "
                      "src_loss_class={:.6f}, src_loss_domain={:.6f}, "
                      "tgt_loss_domain={:.6f}, loss={:.6f}"
                      .format(epoch + 1,
                              params.num_epochs,
                              step + 1,
                              len_dataloader,
                              src_loss_class.item(),
                              src_loss_domain.item(),
                              tgt_loss_domain.item(),
                              loss.item()))

        # eval model on test set
        if (epoch + 1) % params.eval_step == 0:
            print("eval on target domain")
            eval(model, tgt_data_loader)
            print("eval on source domain")
            eval_src(model, src_data_loader)

        # save model parameters
        if (epoch + 1) % params.save_step == 0:
            save_model(model, params.model_root,
                       params.src_dataset + '-' + params.tgt_dataset + "-dann-{}.pt".format(epoch + 1))

    # save final model
    save_model(model, params.model_root,
               params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")

    return model


def adjust_learning_rate(optimizer, p):
    """Anneal the learning rate as lr_0 / (1 + alpha * p) ** beta."""
    lr_0 = 0.01
    alpha = 10
    beta = 0.75
    lr = lr_0 / (1 + alpha * p) ** beta
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def adjust_learning_rate_office(optimizer, p):
    """Anneal the learning rate for the office task.

    The first two parameter groups (features, fc) follow the base schedule;
    the remaining groups (bottleneck, classifier, discriminator) use 10x lr.
    """
    lr_0 = 0.001
    alpha = 10
    beta = 0.75
    lr = lr_0 / (1 + alpha * p) ** beta
    for param_group in optimizer.param_groups[:2]:
        param_group['lr'] = lr
    for param_group in optimizer.param_groups[2:]:
        param_group['lr'] = 10 * lr
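
A hypothetical way to invoke `train_dann` from a driver script; the `params` attributes shown are exactly the ones read in the training loop above, but their values and the construction of the model and data loaders are assumptions, not part of this file.

from types import SimpleNamespace

# Illustrative configuration object; real values belong to the repo's params module.
params = SimpleNamespace(
    src_dataset='mnist', tgt_dataset='mnistm',   # dataset names are illustrative
    num_epochs=100, log_step=20, eval_step=1, save_step=10,
    model_root='snapshots',
)

# The model and the three data loaders are assumed to be built elsewhere:
# model = train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval)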