diff --git a/core/train.py b/core/train.py
index 8da0c46..9ba407a 100644
--- a/core/train.py
+++ b/core/train.py
@@ -10,6 +10,105 @@ from utils.utils import save_model
 import torch.backends.cudnn as cudnn
 cudnn.benchmark = True
+def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger):
+    """Train the classifier on source-domain data only (no adversarial adaptation)."""
+    ####################
+    # 1. setup network #
+    ####################
+
+    # setup criterion and optimizer
+
+    if not params.finetune_flag:
+        print("training non-office task")
+        optimizer = optim.SGD(model.parameters(), lr=params.lr, momentum=params.momentum, weight_decay=params.weight_decay)
+    else:
+        print("training office task")
+        parameter_list = [{
+            "params": model.features.parameters(),
+            "lr": 0.001
+        }, {
+            "params": model.fc.parameters(),
+            "lr": 0.001
+        }, {
+            "params": model.bottleneck.parameters()
+        }, {
+            "params": model.classifier.parameters()
+        }, {
+            "params": model.discriminator.parameters()
+        }]
+        optimizer = optim.SGD(parameter_list, lr=0.01, momentum=0.9)
+
+    criterion = nn.CrossEntropyLoss()
+
+    ####################
+    # 2. train network #
+    ####################
+    global_step = 0
+    for epoch in range(params.num_epochs):
+        # set train state for Dropout and BN layers
+        model.train()
+        # zip source and target data pair
+        len_dataloader = min(len(src_data_loader), len(tgt_data_loader))
+        data_zip = enumerate(zip(src_data_loader, tgt_data_loader))
+        for step, ((images_src, class_src), (images_tgt, _)) in data_zip:
+
+            p = float(step + epoch * len_dataloader) / \
+                params.num_epochs / len_dataloader
+            alpha = 2. / (1. + np.exp(-10 * p)) - 1
+
+            if params.lr_adjust_flag == 'simple':
+                lr = adjust_learning_rate(optimizer, p)
+            else:
+                lr = adjust_learning_rate_office(optimizer, p)
+            logger.add_scalar('lr', lr, global_step)
+
+            # batch sizes (domain labels are not needed for source-only training)
+            size_src = len(images_src)
+            size_tgt = len(images_tgt)
+
+            # make images variable
+            class_src = class_src.to(device)
+            images_src = images_src.to(device)
+
+            # zero gradients for optimizer
+            model.zero_grad()
+
+            # train on source domain
+            src_class_output, src_domain_output = model(input_data=images_src, alpha=alpha)
+            src_loss_class = criterion(src_class_output, class_src)
+
+            loss = src_loss_class
+
+            # optimize on the source classification loss only
+            loss.backward()
+            optimizer.step()
+
+            global_step += 1
+
+            # print step info
+            logger.add_scalar('loss', loss.item(), global_step)
+
+            if ((step + 1) % params.log_step == 0):
+                print(
+                    "Epoch [{:4d}/{}] Step [{:2d}/{}]: loss={:.6f}".format(epoch + 1, params.num_epochs, step + 1, len_dataloader, loss.data.item()))
+
+        # eval model
+        if ((epoch + 1) % params.eval_step == 0):
+            src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source')
+            tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, flag='target')
+            logger.add_scalar('src_test_loss', src_test_loss, global_step)
+            logger.add_scalar('src_acc', src_acc, global_step)
+
+
+        # save model parameters
+        if ((epoch + 1) % params.save_step == 0):
+            save_model(model, params.model_root,
+                       params.src_dataset + '-' + params.tgt_dataset + "-dann-{}.pt".format(epoch + 1))
+
+    # save final model
+    save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")
+
+    return model
 
 
 def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger):
     """Train dann."""
@@ -110,7 +209,7 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger):
 
         # eval model
         if ((epoch + 1) % params.eval_step == 0):
-            tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader, device, flag='target')
+            tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, flag='target')
             src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source')
            logger.add_scalar('src_test_loss', src_test_loss, global_step)
             logger.add_scalar('src_acc', src_acc, global_step)
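
For reference only (not part of the patch): a minimal sketch of how the new train_src entry point might be invoked, assuming core/train.py keeps the repo-level helpers it already references (adjust_learning_rate, test, save_model), that the model follows the DANN-style forward signature model(input_data=..., alpha=...) returning (class_output, domain_output), and that params exposes the fields read above. The DummyDANN class, fake_loader helper, and the params values below are hypothetical placeholders; in the repo the real model, params object, and dataset loaders would be used instead.

# Sketch under the assumptions stated above; stand-in data and model only.
import torch
import torch.nn as nn
from types import SimpleNamespace
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

from core.train import train_src


class DummyDANN(nn.Module):
    """Stub with the forward signature train_src expects: returns (class_output, domain_output)."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.ReLU())
        self.class_head = nn.Linear(64, num_classes)
        self.domain_head = nn.Linear(64, 2)

    def forward(self, input_data, alpha=0.0):
        feat = self.backbone(input_data)
        return self.class_head(feat), self.domain_head(feat)


def fake_loader(n=256):
    # random 1x28x28 "images" with 10-way labels, batched like the MNIST loaders
    x = torch.randn(n, 1, 28, 28)
    y = torch.randint(0, 10, (n,))
    return DataLoader(TensorDataset(x, y), batch_size=64, shuffle=True)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger = SummaryWriter("runs/source-only-demo")

# only the fields train_src actually reads; values are illustrative
params = SimpleNamespace(
    finetune_flag=False, lr=0.01, momentum=0.9, weight_decay=1e-6,
    num_epochs=5, lr_adjust_flag='simple',
    log_step=2, eval_step=1, save_step=5,
    model_root="checkpoints", src_dataset="mnist", tgt_dataset="mnistm")

model = DummyDANN().to(device)
model = train_src(model, params,
                  fake_loader(), fake_loader(), fake_loader(),
                  device, logger)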