Browse Source

Save the final model with KeyboardInterrupt

master
Fazil Altinel 4 years ago
parent
commit
09908685e0
  1. 184
      core/train.py

184
core/train.py

@ -139,103 +139,105 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
}, {
"params": model.discriminator.parameters()
}]
optimizer = optim.SGD(parameter_list, lr=0.01, momentum=0.9)
optimizer = optim.SGD(parameter_list, lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()
####################
# 2. train network #
####################
global_step = 0
bestAcc = 0.0
for epoch in range(params.num_epochs):
# set train state for Dropout and BN layers
model.train()
# zip source and target data pair
len_dataloader = min(len(src_data_loader), len(tgt_data_loader))
data_zip = enumerate(zip(src_data_loader, tgt_data_loader))
for step, ((images_src, class_src), (images_tgt, _)) in data_zip:
p = float(step + epoch * len_dataloader) / \
params.num_epochs / len_dataloader
alpha = 2. / (1. + np.exp(-10 * p)) - 1
if params.lr_adjust_flag == 'simple':
lr = adjust_learning_rate(optimizer, p)
else:
lr = adjust_learning_rate_office(optimizer, p)
if not logger == None:
logger.add_scalar('lr', lr, global_step)
# prepare domain label
size_src = len(images_src)
size_tgt = len(images_tgt)
label_src = torch.zeros(size_src).long().to(device) # source 0
label_tgt = torch.ones(size_tgt).long().to(device) # target 1
# make images variable
class_src = class_src.to(device)
images_src = images_src.to(device)
images_tgt = images_tgt.to(device)
# zero gradients for optimizer
optimizer.zero_grad()
# train on source domain
src_class_output, src_domain_output = model(input_data=images_src, alpha=alpha)
src_loss_class = criterion(src_class_output, class_src)
src_loss_domain = criterion(src_domain_output, label_src)
# train on target domain
_, tgt_domain_output = model(input_data=images_tgt, alpha=alpha)
tgt_loss_domain = criterion(tgt_domain_output, label_tgt)
loss = src_loss_class + src_loss_domain + tgt_loss_domain
if params.src_only_flag:
loss = src_loss_class
# optimize dann
loss.backward()
optimizer.step()
global_step += 1
# print step info
if not logger == None:
logger.add_scalar('src_loss_class', src_loss_class.item(), global_step)
logger.add_scalar('src_loss_domain', src_loss_domain.item(), global_step)
logger.add_scalar('tgt_loss_domain', tgt_loss_domain.item(), global_step)
logger.add_scalar('loss', loss.item(), global_step)
if ((step + 1) % params.log_step == 0):
print(
"Epoch [{:4d}/{}] Step [{:2d}/{}]: src_loss_class={:.6f}, src_loss_domain={:.6f}, tgt_loss_domain={:.6f}, loss={:.6f}"
.format(epoch + 1, params.num_epochs, step + 1, len_dataloader, src_loss_class.data.item(),
src_loss_domain.data.item(), tgt_loss_domain.data.item(), loss.data.item()))
# eval model
if ((epoch + 1) % params.eval_step == 0):
tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, loggi, flag='target')
src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, loggi, flag='source')
loggi.info('\n')
if tgt_acc > bestAcc:
bestAcc = tgt_acc
bestAccS = src_acc
save_model(model, params.model_root,
params.src_dataset + '-' + params.tgt_dataset + "-dann-best.pt")
if not logger == None:
logger.add_scalar('src_test_loss', src_test_loss, global_step)
logger.add_scalar('src_acc', src_acc, global_step)
logger.add_scalar('src_acc_domain', src_acc_domain, global_step)
logger.add_scalar('tgt_test_loss', tgt_test_loss, global_step)
logger.add_scalar('tgt_acc', tgt_acc, global_step)
logger.add_scalar('tgt_acc_domain', tgt_acc_domain, global_step)
# save final model
save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")
loggi.info('\n============ Summary ============= \n')
loggi.info('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS))
loggi.info('Accuracy of the %s dataset: %f' % (params.tgt_dataset, bestAcc))
try:
global_step = 0
bestAcc = 0.0
for epoch in range(params.num_epochs):
# set train state for Dropout and BN layers
model.train()
# zip source and target data pair
len_dataloader = min(len(src_data_loader), len(tgt_data_loader))
data_zip = enumerate(zip(src_data_loader, tgt_data_loader))
for step, ((images_src, class_src), (images_tgt, _)) in data_zip:
p = float(step + epoch * len_dataloader) / \
params.num_epochs / len_dataloader
alpha = 2. / (1. + np.exp(-10 * p)) - 1
if params.lr_adjust_flag == 'simple':
lr = adjust_learning_rate(optimizer, p)
else:
lr = adjust_learning_rate_office(optimizer, p)
if not logger == None:
logger.add_scalar('lr', lr, global_step)
# prepare domain label
size_src = len(images_src)
size_tgt = len(images_tgt)
label_src = torch.zeros(size_src).long().to(device) # source 0
label_tgt = torch.ones(size_tgt).long().to(device) # target 1
# make images variable
class_src = class_src.to(device)
images_src = images_src.to(device)
images_tgt = images_tgt.to(device)
# zero gradients for optimizer
optimizer.zero_grad()
# train on source domain
src_class_output, src_domain_output = model(input_data=images_src, alpha=alpha)
src_loss_class = criterion(src_class_output, class_src)
src_loss_domain = criterion(src_domain_output, label_src)
# train on target domain
_, tgt_domain_output = model(input_data=images_tgt, alpha=alpha)
tgt_loss_domain = criterion(tgt_domain_output, label_tgt)
loss = src_loss_class + src_loss_domain + tgt_loss_domain
if params.src_only_flag:
loss = src_loss_class
# optimize dann
loss.backward()
optimizer.step()
global_step += 1
# print step info
if not logger == None:
logger.add_scalar('src_loss_class', src_loss_class.item(), global_step)
logger.add_scalar('src_loss_domain', src_loss_domain.item(), global_step)
logger.add_scalar('tgt_loss_domain', tgt_loss_domain.item(), global_step)
logger.add_scalar('loss', loss.item(), global_step)
if ((step + 1) % params.log_step == 0):
print(
"Epoch [{:4d}/{}] Step [{:2d}/{}]: src_loss_class={:.6f}, src_loss_domain={:.6f}, tgt_loss_domain={:.6f}, loss={:.6f}"
.format(epoch + 1, params.num_epochs, step + 1, len_dataloader, src_loss_class.data.item(),
src_loss_domain.data.item(), tgt_loss_domain.data.item(), loss.data.item()))
# eval model
if ((epoch + 1) % params.eval_step == 0):
tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, loggi, flag='target')
src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, loggi, flag='source')
loggi.info('\n')
if tgt_acc > bestAcc:
bestAcc = tgt_acc
bestAccS = src_acc
save_model(model, params.model_root,
params.src_dataset + '-' + params.tgt_dataset + "-dann-best.pt")
if not logger == None:
logger.add_scalar('src_test_loss', src_test_loss, global_step)
logger.add_scalar('src_acc', src_acc, global_step)
logger.add_scalar('src_acc_domain', src_acc_domain, global_step)
logger.add_scalar('tgt_test_loss', tgt_test_loss, global_step)
logger.add_scalar('tgt_acc', tgt_acc, global_step)
logger.add_scalar('tgt_acc_domain', tgt_acc_domain, global_step)
except KeyboardInterrupt as ke:
loggi.info('Saving the final weights before quitting')
# save final model
save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")
loggi.info('\n============ Summary ============= \n')
loggi.info('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS))
loggi.info('Accuracy of the %s dataset: %f' % (params.tgt_dataset, bestAcc))
return model

Loading…
Cancel
Save