diff --git a/core/dann.py b/core/dann.py
index 80b6239..c9462f5 100644
--- a/core/dann.py
+++ b/core/dann.py
@@ -7,37 +7,37 @@
 import torch.optim as optim
 import params
 from utils import make_variable, save_model
 import numpy as np
-from core.test import eval
+from core.test import eval, eval_src
 import torch.backends.cudnn as cudnn
 cudnn.benchmark = True
 
-def train_dann(dann, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
+def train_dann(model, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
     """Train dann."""
     ####################
     # 1. setup network #
     ####################
 
-    # set train state for Dropout and BN layers
-    dann.train()
-    # setup criterion and optimizer
-    optimizer = optim.Adam(dann.parameters(), lr=params.lr)
+    parameter_list = [
+        {"params": model.features.parameters(), "lr": 1e-5},
+        {"params": model.classifier.parameters(), "lr": 1e-4},
+        {"params": model.discriminator.parameters(), "lr": 1e-4}
+    ]
+    optimizer = optim.Adam(parameter_list)
 
-    criterion = nn.NLLLoss()
+    criterion = nn.CrossEntropyLoss()
 
-    for p in dann.parameters():
+    for p in model.parameters():
         p.requires_grad = True
 
     ####################
     # 2. train network #
     ####################
-    # prepare domain label
-    label_src = make_variable(torch.zeros(params.batch_size).long()) # source 0
-    label_tgt = make_variable(torch.ones(params.batch_size).long()) # target 1
 
     for epoch in range(params.num_epochs):
+        # set train state for Dropout and BN layers
+        model.train()
         # zip source and target data pair
         len_dataloader = min(len(src_data_loader), len(tgt_data_loader))
         data_zip = enumerate(zip(src_data_loader, tgt_data_loader))
@@ -46,6 +46,12 @@ def train_dann(model, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
             p = float(step + epoch * len_dataloader) / params.num_epochs / len_dataloader
             alpha = 2. / (1. + np.exp(-10 * p)) - 1
 
+            # prepare domain label
+            size_src = len(images_src)
+            size_tgt = len(images_tgt)
+            label_src = make_variable(torch.zeros(size_src).long()) # source 0
+            label_tgt = make_variable(torch.ones(size_tgt).long()) # target 1
+
             # make images variable
             class_src = make_variable(class_src)
             images_src = make_variable(images_src)
@@ -55,12 +61,12 @@ def train_dann(model, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
             optimizer.zero_grad()
 
             # train on source domain
-            src_class_output, src_domain_output = dann(input_data=images_src, alpha=alpha)
+            src_class_output, src_domain_output = model(input_data=images_src, alpha=alpha)
             src_loss_class = criterion(src_class_output, class_src)
             src_loss_domain = criterion(src_domain_output, label_src)
 
             # train on target domain
-            _, tgt_domain_output = dann(input_data=images_tgt, alpha=alpha)
+            _, tgt_domain_output = model(input_data=images_tgt, alpha=alpha)
             tgt_loss_domain = criterion(tgt_domain_output, label_tgt)
 
             loss = src_loss_class + src_loss_domain + tgt_loss_domain
@@ -71,7 +77,7 @@ def train_dann(model, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
 
             # print step info
             if ((step + 1) % params.log_step == 0):
-                print("Epoch [{}/{}] Step [{}/{}]: src_loss_class={}, src_loss_domain={}, tgt_loss_domain={}, loss={}"
+                print("Epoch [{:4d}/{}] Step [{:2d}/{}]: src_loss_class={:.6f}, src_loss_domain={:.6f}, tgt_loss_domain={:.6f}, loss={:.6f}"
                       .format(epoch + 1,
                               params.num_epochs,
                               step + 1,
@@ -83,14 +89,16 @@ def train_dann(model, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
 
         # eval model on test set
         if ((epoch + 1) % params.eval_step == 0):
-            eval(dann, tgt_data_loader_eval)
-            dann.train()
+            print("eval on target domain")
+            eval(model, tgt_data_loader_eval)
+            print("eval on source domain")
+            eval_src(model, src_data_loader)
 
         # save model parameters
         if ((epoch + 1) % params.save_step == 0):
-            save_model(dann, params.src_dataset + '-' + params.tgt_dataset + "-dann-{}.pt".format(epoch + 1))
+            save_model(model, params.src_dataset + '-' + params.tgt_dataset + "-dann-{}.pt".format(epoch + 1))
 
     # save final model
-    save_model(dann, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")
+    save_model(model, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")
 
-    return dann
\ No newline at end of file
+    return model
diff --git a/datasets/office.py b/datasets/office.py
new file mode 100644
index 0000000..98c5998
--- /dev/null
+++ b/datasets/office.py
@@ -0,0 +1,30 @@
+"""Dataset setting and data loader for Office."""
+
+import torch
+from torchvision import datasets, transforms
+import torch.utils.data as data
+import os
+import params
+
+
+def get_office(train, category):
+    """Get Office datasets loader."""
+    # image pre-processing
+    pre_process = transforms.Compose([transforms.Resize(params.office_image_size),
+                                      transforms.ToTensor(),
+                                      transforms.Normalize(
+                                          mean=params.imagenet_dataset_mean,
+                                          std=params.imagenet_dataset_std)])
+
+    # datasets and data_loader
+    office_dataset = datasets.ImageFolder(
+        os.path.join(params.dataset_root, 'office', category, 'images'),
+        transform=pre_process)
+
+    office_dataloader = torch.utils.data.DataLoader(
+        dataset=office_dataset,
+        batch_size=params.batch_size,
+        shuffle=True,
+        num_workers=8)
+
+    return office_dataloader
\ No newline at end of file
diff --git a/datasets/officecaltech.py b/datasets/officecaltech.py
new file mode 100644
index 0000000..70c7d5c
--- /dev/null
+++ b/datasets/officecaltech.py
@@ -0,0 +1,30 @@
+"""Dataset setting and data loader for Office_Caltech_10."""
+
+import torch
+from torchvision import datasets, transforms
+import torch.utils.data as data
+import os
+import params
+
+
+def get_officecaltech(train, category):
+    """Get Office_Caltech_10 datasets loader."""
+    # image pre-processing
+    pre_process = transforms.Compose([transforms.Resize(params.office_image_size),
+                                      transforms.ToTensor(),
+                                      transforms.Normalize(
+                                          mean=params.imagenet_dataset_mean,
+                                          std=params.imagenet_dataset_std)])
+
+    # datasets and data_loader
+    officecaltech_dataset = datasets.ImageFolder(
+        os.path.join(params.dataset_root, 'office_caltech_10', category),
+        transform=pre_process)
+
+    officecaltech_dataloader = torch.utils.data.DataLoader(
+        dataset=officecaltech_dataset,
+        batch_size=params.batch_size,
+        shuffle=True,
+        num_workers=8)
+
+    return officecaltech_dataloader
\ No newline at end of file
diff --git a/main_office.py b/main_office.py
new file mode 100644
index 0000000..c3b5d42
--- /dev/null
+++ b/main_office.py
@@ -0,0 +1,32 @@
+from models.model import SVHNmodel, Classifier
+
+from core.dann import train_dann
+from core.test import eval
+from models.model import AlexModel
+
+import params
+from utils import get_data_loader, init_model, init_random_seed
+
+# init random seed
+init_random_seed(params.manual_seed)
+
+# load dataset
+src_data_loader = get_data_loader(params.src_dataset)
+tgt_data_loader = get_data_loader(params.tgt_dataset)
+
+# load dann model
+dann = init_model(net=AlexModel(), restore=params.dann_restore)
+
+# train dann model
+print("Start training dann model.")
+
+if not (dann.restored and params.dann_restore):
+    dann = train_dann(dann, src_data_loader, tgt_data_loader, tgt_data_loader)
+
+# eval dann model
+print("Evaluating dann for source domain")
+eval(dann, src_data_loader)
+print("Evaluating dann for target domain")
+eval(dann, tgt_data_loader)
+
+print('done')
\ No newline at end of file
diff --git a/models/model.py b/models/model.py
index 86394f7..1b6c865 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,13 +1,54 @@
 """DANN model."""
 
 import torch.nn as nn
-from functions import ReverseLayerF
+from .functions import ReverseLayerF
+from torchvision import models
+import params
 
 
-class CNNModel(nn.Module):
+class Classifier(nn.Module):
+    """ SVHN architecture without discriminator"""
 
     def __init__(self):
-        super(CNNModel, self).__init__()
+        super(Classifier, self).__init__()
+        self.restored = False
+
+        self.feature = nn.Sequential()
+        self.feature.add_module('f_conv1', nn.Conv2d(1, 64, kernel_size=5))
+        self.feature.add_module('f_bn1', nn.BatchNorm2d(64))
+        self.feature.add_module('f_pool1', nn.MaxPool2d(2))
+        self.feature.add_module('f_relu1', nn.ReLU(True))
+        self.feature.add_module('f_conv2', nn.Conv2d(64, 50, kernel_size=5))
+        self.feature.add_module('f_bn2', nn.BatchNorm2d(50))
+        self.feature.add_module('f_drop1', nn.Dropout2d())
+        self.feature.add_module('f_pool2', nn.MaxPool2d(2))
+        self.feature.add_module('f_relu2', nn.ReLU(True))
+
+        self.class_classifier = nn.Sequential()
+        self.class_classifier.add_module('c_fc1', nn.Linear(50 * 4 * 4, 100))
+        self.class_classifier.add_module('c_bn1', nn.BatchNorm2d(100))
+        self.class_classifier.add_module('c_relu1', nn.ReLU(True))
+        self.class_classifier.add_module('c_drop1', nn.Dropout2d())
+        self.class_classifier.add_module('c_fc2', nn.Linear(100, 100))
+        self.class_classifier.add_module('c_bn2', nn.BatchNorm2d(100))
+        self.class_classifier.add_module('c_relu2', nn.ReLU(True))
+        self.class_classifier.add_module('c_fc3', nn.Linear(100, 10))
+        self.class_classifier.add_module('c_softmax', nn.LogSoftmax(dim=1))
+
+    def forward(self, input_data):
+        input_data = input_data.expand(input_data.data.shape[0], 1, 28, 28)
+        feature = self.feature(input_data)
+        feature = feature.view(-1, 50 * 4 * 4)
+        class_output = self.class_classifier(feature)
+
+        return class_output
+
+
+class SVHNmodel(nn.Module):
+    """ SVHN architecture"""
+
+    def __init__(self):
+        super(SVHNmodel, self).__init__()
         self.restored = False
 
         self.feature = nn.Sequential()
@@ -48,3 +89,58 @@ class CNNModel(nn.Module):
         domain_output = self.domain_classifier(reverse_feature)
 
         return class_output, domain_output
+
+
+class AlexModel(nn.Module):
+    """ AlexNet pretrained on imagenet for Office dataset"""
+
+    def __init__(self):
+        super(AlexModel, self).__init__()
+        self.restored = False
+        model_alexnet = models.alexnet(pretrained=True)
+
+        self.features = model_alexnet.features
+
+        # self.classifier = nn.Sequential()
+        # for i in range(5):
+        #     self.classifier.add_module(
+        #         "classifier" + str(i), model_alexnet.classifier[i])
+        # self.__in_features = model_alexnet.classifier[4].in_features
+        # self.classifier.add_module('classifier5', nn.Dropout())
+        # self.classifier.add_module('classifier6', nn.Linear(self.__in_features, 256))
+        # self.classifier.add_module('classifier7', nn.BatchNorm2d(256))
+        # self.classifier.add_module('classifier8', nn.ReLU())
+        # self.classifier.add_module('classifier9', nn.Dropout(0.5))
+        # self.classifier.add_module('classifier10', nn.Linear(256, params.class_num_src))
+        self.classifier = nn.Sequential(
+            nn.Dropout(0.5),
+            nn.Linear(256 * 6 * 6, 4096),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.5),
+            nn.Linear(4096, 256),
+            nn.ReLU(inplace=True),
+            nn.Linear(256, params.class_num_src),
+        )
+
+        self.discriminator = nn.Sequential(
+            nn.Linear(256 * 6 * 6, 1024),
+            nn.ReLU(),
+            nn.Dropout(0.5),
+            nn.Linear(1024, 1024),
+            nn.ReLU(),
+            nn.Dropout(0.5),
+            nn.Linear(1024, 2),
+        )
+
+    def forward(self, input_data, alpha):
+        input_data = input_data.expand(input_data.data.shape[0], 3, 227, 227)
+        feature = self.features(input_data)
+        feature = feature.view(-1, 256 * 6 * 6)
+
+        reverse_feature = ReverseLayerF.apply(feature, alpha)
+
+        class_output = self.classifier(feature)
+
+        domain_output = self.discriminator(reverse_feature)
+
+        return class_output, domain_output
diff --git a/params.py b/params.py
index 8868212..2dd1011 100644
--- a/params.py
+++ b/params.py
@@ -11,16 +11,22 @@
 dataset_mean_value = 0.5
 dataset_std_value = 0.5
 dataset_mean = (dataset_mean_value, dataset_mean_value, dataset_mean_value)
 dataset_std = (dataset_std_value, dataset_std_value, dataset_std_value)
-batch_size = 128
-image_size = 28
+
+imagenet_dataset_mean = (0.485, 0.456, 0.406)
+imagenet_dataset_std = (0.229, 0.224, 0.225)
+
+batch_size = 64
+digit_image_size = 28
+office_image_size = 227
 
 # params for source dataset
-src_dataset = "SVHN"
+src_dataset = "amazon31"
 src_model_trained = True
 src_classifier_restore = os.path.join(model_root,src_dataset + '-source-classifier-final.pt')
+class_num_src = 31
 
 # params for target dataset
-tgt_dataset = "MNIST"
+tgt_dataset = "webcam31"
 tgt_model_trained = True
 dann_restore = os.path.join(model_root , src_dataset + '-' + tgt_dataset + '-dann-final.pt')
@@ -31,10 +37,18 @@
 save_step_src = 20
 eval_step_src = 20
 
 # params for training dann
-num_epochs = 400
-log_step = 50
-save_step = 50
-eval_step = 20
+
+## for digit
+# num_epochs = 400
+# log_step = 100
+# save_step = 20
+# eval_step = 20
+
+## for office
+num_epochs = 1000
+log_step = 10 # iters
+save_step = 500
+eval_step = 5 # epochs
 
 manual_seed = 8888
 alpha = 0