You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
96 lines
3.2 KiB
96 lines
3.2 KiB
7 years ago
|
"""Train dann."""
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.optim as optim
|
||
|
|
||
|
import params
|
||
|
from utils import make_variable, save_model
|
||
|
import numpy as np
|
||
|
from core.test import eval
|
||
|
|
||
|
import torch.backends.cudnn as cudnn
|
||
|
# Turn on cuDNN benchmark mode as an import-time side effect of this module.
# NOTE(review): per the PyTorch docs this lets cuDNN auto-tune convolution
# algorithms, which pays off when input shapes stay constant — confirm that
# holds for the data loaders used with this trainer.
cudnn.benchmark = True
|
||
|
|
||
|
def _loss_value(loss):
    """Return a scalar loss as a plain Python number for logging.

    Uses ``.item()`` when the object provides it and falls back to the
    legacy ``.data[0]`` indexing otherwise, so logging works both on
    0-dim tensors (where ``.data[0]`` raises ``IndexError``) and on
    older Variable-style 1-element results.
    """
    if hasattr(loss, "item"):
        return loss.item()
    return loss.data[0]


def train_dann(dann, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
    """Train ``dann`` on paired source/target batches and return it.

    Each step sums three ``nn.NLLLoss`` terms: the class loss on the source
    batch, the domain loss on the source batch (domain label 0) and the
    domain loss on the target batch (domain label 1). The summed loss is
    optimised with Adam at learning rate ``params.lr``.

    Args:
        dann: model called as ``dann(input_data=..., alpha=...)`` and
            returning ``(class_output, domain_output)``. Because the
            criterion is ``nn.NLLLoss``, both outputs are presumably
            log-probabilities — TODO confirm against the model definition.
        src_data_loader: iterable of ``(images, class_labels)`` source
            batches; must support ``len()``.
        tgt_data_loader: iterable of ``(images, _)`` target batches; the
            second element is ignored; must support ``len()``.
        tgt_data_loader_eval: loader handed to ``eval`` every
            ``params.eval_step`` epochs.

    Returns:
        The same ``dann`` object after training.

    Side effects:
        Prints step losses every ``params.log_step`` steps, calls
        ``save_model`` every ``params.save_step`` epochs and once more at
        the end.
    """
    ####################
    # 1. setup network #
    ####################

    # set train state for Dropout and BN layers
    dann.train()

    # setup criterion and optimizer
    optimizer = optim.Adam(dann.parameters(), lr=params.lr)
    criterion = nn.NLLLoss()

    for param in dann.parameters():
        param.requires_grad = True

    ####################
    # 2. train network #
    ####################

    # Domain labels are cached per batch size instead of being fixed to
    # params.batch_size: the last batch of a loader may be smaller, and a
    # fixed-size label tensor would then make NLLLoss fail on a size mismatch.
    src_label_cache = {}
    tgt_label_cache = {}

    def domain_labels(cache, fill, batch_len):
        """Return (and memoise) ``batch_len`` domain labels built by ``fill``."""
        if batch_len not in cache:
            cache[batch_len] = make_variable(fill(batch_len).long())
        return cache[batch_len]

    # zip() stops at the shorter loader, so that is the effective epoch length
    len_dataloader = min(len(src_data_loader), len(tgt_data_loader))

    for epoch in range(params.num_epochs):
        # zip source and target data pair
        data_zip = enumerate(zip(src_data_loader, tgt_data_loader))
        for step, ((images_src, class_src), (images_tgt, _)) in data_zip:

            # fraction of total training completed so far, in [0, 1)
            progress = float(step + epoch * len_dataloader) / params.num_epochs / len_dataloader
            # ramps smoothly from 0 towards 1 as training progresses; passed
            # to the model (presumably the gradient-reversal weight — verify
            # against the model's forward)
            alpha = 2. / (1. + np.exp(-10 * progress)) - 1

            # make images variable
            class_src = make_variable(class_src)
            images_src = make_variable(images_src)
            images_tgt = make_variable(images_tgt)

            # assumes dimension 0 of the image tensors is the batch dimension
            label_src = domain_labels(src_label_cache, torch.zeros, images_src.size(0))  # source 0
            label_tgt = domain_labels(tgt_label_cache, torch.ones, images_tgt.size(0))  # target 1

            # zero gradients for optimizer
            optimizer.zero_grad()

            # train on source domain
            src_class_output, src_domain_output = dann(input_data=images_src, alpha=alpha)
            src_loss_class = criterion(src_class_output, class_src)
            src_loss_domain = criterion(src_domain_output, label_src)

            # train on target domain (class output unused: target class
            # labels are not read from the target loader)
            _, tgt_domain_output = dann(input_data=images_tgt, alpha=alpha)
            tgt_loss_domain = criterion(tgt_domain_output, label_tgt)

            loss = src_loss_class + src_loss_domain + tgt_loss_domain

            # optimize dann
            loss.backward()
            optimizer.step()

            # print step info
            if (step + 1) % params.log_step == 0:
                print("Epoch [{}/{}] Step [{}/{}]: src_loss_class={}, src_loss_domain={}, tgt_loss_domain={}, loss={}"
                      .format(epoch + 1,
                              params.num_epochs,
                              step + 1,
                              len_dataloader,
                              _loss_value(src_loss_class),
                              _loss_value(src_loss_domain),
                              _loss_value(tgt_loss_domain),
                              _loss_value(loss)))

        # eval model on test set
        if (epoch + 1) % params.eval_step == 0:
            eval(dann, tgt_data_loader_eval)
            # restore train state after evaluation
            dann.train()

        # save model parameters
        if (epoch + 1) % params.save_step == 0:
            save_model(dann, params.src_dataset + '-' + params.tgt_dataset + "-dann-{}.pt".format(epoch + 1))

    # save final model
    save_model(dann, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")

    return dann
|