From ac57dc2a1a845452d2ba7a458841f9f971781bcd Mon Sep 17 00:00:00 2001
From: wogong
Date: Fri, 25 May 2018 00:28:02 +0800
Subject: [PATCH] remove params.py, add split task files, etc.

---
 core/dann.py              | 33 +++++++++------
 core/pretrain.py          |  3 +-
 core/test.py              | 62 -----------------------------
 datasets/mnist.py         | 11 +++--
 datasets/mnistm.py        | 13 +++---
 datasets/office.py        |  9 ++---
 datasets/officecaltech.py |  9 ++---
 datasets/svhn.py          | 14 +++----
 main.py                   | 49 -----------------------
 mnist_mnistm.py           | 76 +++++++++++++++++++++++++++++++++++
 models/functions.py       |  1 -
 models/model.py           | 84 +++++++++++++++++++++++----------------
 params.py                 | 50 -----------------------
 svhn_mnist.py             | 75 ++++++++++++++++++++++++++++++++++
 utils.py                  | 32 +++++++--------
 15 files changed, 262 insertions(+), 259 deletions(-)
 delete mode 100644 main.py
 create mode 100644 mnist_mnistm.py
 delete mode 100644 params.py
 create mode 100644 svhn_mnist.py

diff --git a/core/dann.py b/core/dann.py
index b96c7af..a1bf448 100644
--- a/core/dann.py
+++ b/core/dann.py
@@ -4,7 +4,6 @@
 import torch
 import torch.nn as nn
 import torch.optim as optim
-import params
 from utils import make_variable, save_model
 import numpy as np
 from core.test import eval, eval_src
@@ -12,19 +11,19 @@ from core.test import eval, eval_src
 import torch.backends.cudnn as cudnn
 cudnn.benchmark = True
 
-def train_dann(model, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
+def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
     """Train dann."""
     ####################
     # 1. setup network #
     ####################
 
     # setup criterion and optimizer
-    parameter_list = [
-        {"params": model.features.parameters(), "lr": 1e-5},
-        {"params": model.classifier.parameters(), "lr": 1e-4},
-        {"params": model.discriminator.parameters(), "lr": 1e-4}
-    ]
-    optimizer = optim.Adam(parameter_list)
+    # parameter_list = [
+    #     # {"params": model.feature.parameters(), "lr": 1e-5},
+    #     # {"params": model.classifier.parameters(), "lr": 1e-4},
+    #     # {"params": model.discriminator.parameters(), "lr": 1e-4}
+    #     # ]
+    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
 
     criterion = nn.CrossEntropyLoss()
 
@@ -45,6 +44,7 @@ def train_dann(model, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
             p = float(step + epoch * len_dataloader) / params.num_epochs / len_dataloader
             alpha = 2. / (1. + np.exp(-10 * p)) - 1
+            adjust_learning_rate(optimizer, p)
 
             # prepare domain label
             size_src = len(images_src)
@@ -96,9 +96,18 @@ def train_dann(model, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
 
         # save model parameters
        if ((epoch + 1) % params.save_step == 0):
-            save_model(model, params.src_dataset + '-' + params.tgt_dataset + "-dann-{}.pt".format(epoch + 1))
+            save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-{}.pt".format(epoch + 1))
 
     # save final model
-    save_model(model, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")
-
-    return model
\ No newline at end of file
+    save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")
+
+    return model
+
+def adjust_learning_rate(optimizer, p):
+    """Anneal the learning rate: lr = lr_0 / (1 + alpha * p) ** beta (DANN schedule)."""
+    lr_0 = 0.01
+    alpha = 10
+    beta = 0.75
+    lr = lr_0 / (1 + alpha*p) ** beta
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
\ No newline at end of file
diff --git a/core/pretrain.py b/core/pretrain.py
index 7d80fd4..c44f5bb 100644
--- a/core/pretrain.py
+++ b/core/pretrain.py
@@ -3,11 +3,10 @@
 import torch.nn as nn
 import torch.optim as optim
 
-import params
 from utils import make_variable, save_model
 from core.test import eval_src
 
-def train_src(model, data_loader):
+def train_src(model, params, data_loader):
     """Train classifier for source domain."""
     ####################
     # 1. setup network #
diff --git a/core/test.py b/core/test.py
index 345780c..90c42f3 100644
--- a/core/test.py
+++ b/core/test.py
@@ -1,69 +1,8 @@
-import os
-import torch.backends.cudnn as cudnn
 import torch.utils.data
-from torch.autograd import Variable
 import torch.nn as nn
 
-import params
 from utils import make_variable
 
-def test(dataloader, epoch):
-
-    if torch.cuda.is_available():
-        cuda = True
-        cudnn.benchmark = True
-
-    """ training """
-
-    my_net = torch.load(os.path.join(
-        params.model_root, params.src_dataset + '_' + params.tgt_dataset + '_model_epoch_' + str(epoch) + '.pth'
-    ))
-    my_net = my_net.eval()
-
-    if cuda:
-        my_net = my_net.cuda()
-
-    len_dataloader = len(dataloader)
-    data_target_iter = iter(dataloader)
-
-    i = 0
-    n_total = 0
-    n_correct = 0
-
-    while i < len_dataloader:
-
-        # test model using target data
-        data_target = data_target_iter.next()
-        t_img, t_label = data_target
-
-        batch_size = len(t_label)
-
-        input_img = torch.FloatTensor(batch_size, 3, params.digit_image_size, params.digit_image_size)
-        class_label = torch.LongTensor(batch_size)
-
-        if cuda:
-            t_img = t_img.cuda()
-            t_label = t_label.cuda()
-            input_img = input_img.cuda()
-            class_label = class_label.cuda()
-
-        input_img.resize_as_(t_img).copy_(t_img)
-        class_label.resize_as_(t_label).copy_(t_label)
-        inputv_img = Variable(input_img)
-        classv_label = Variable(class_label)
-
-        class_output, _ = my_net(input_data=inputv_img, alpha=params.alpha)
-        pred = class_output.data.max(1, keepdim=True)[1]
-        n_correct += pred.eq(classv_label.data.view_as(pred)).cpu().sum()
-        n_total += batch_size
-
-        i += 1
-
-    accu = n_correct * 1.0 / n_total
-
-    print('epoch: %d, accuracy: %f' % (epoch, accu))
-
-
 def test_from_save(model, saved_model, data_loader):
     """Evaluate classifier for source domain."""
     # set eval state for Dropout and BN layers
@@ -95,7 +34,6 @@ def test_from_save(model, saved_model, data_loader):
 
     print("Avg Loss = {}, Avg Accuracy = {:.2%}".format(loss, acc))
 
-
 def eval(model, data_loader):
     """Evaluate model for dataset."""
     # set eval state
for Dropout and BN layers diff --git a/datasets/mnist.py b/datasets/mnist.py index 667bae7..4d243c0 100644 --- a/datasets/mnist.py +++ b/datasets/mnist.py @@ -5,9 +5,7 @@ import torch from torchvision import datasets, transforms import os -import params - -def get_mnist(train): +def get_mnist(dataset_root, batch_size, train): """Get MNIST datasets loader.""" # image pre-processing pre_process = transforms.Compose([transforms.ToTensor(), @@ -17,7 +15,7 @@ def get_mnist(train): )]) # datasets and data loader - mnist_dataset = datasets.MNIST(root=os.path.join(params.dataset_root,'mnist'), + mnist_dataset = datasets.MNIST(root=os.path.join(dataset_root,'mnist'), train=train, transform=pre_process, download=False) @@ -25,8 +23,9 @@ def get_mnist(train): mnist_data_loader = torch.utils.data.DataLoader( dataset=mnist_dataset, - batch_size=params.batch_size, + batch_size=batch_size, shuffle=True, - drop_last=True) + drop_last=True, + num_workers=8) return mnist_data_loader \ No newline at end of file diff --git a/datasets/mnistm.py b/datasets/mnistm.py index 5fe1fa4..7d2aa2b 100644 --- a/datasets/mnistm.py +++ b/datasets/mnistm.py @@ -5,7 +5,6 @@ from torchvision import transforms import torch.utils.data as data from PIL import Image import os -import params class GetLoader(data.Dataset): def __init__(self, data_root, data_list, transform=None): @@ -38,7 +37,7 @@ class GetLoader(data.Dataset): def __len__(self): return self.n_data -def get_mnistm(train): +def get_mnistm(dataset_root, batch_size, train): """Get MNISTM datasets loader.""" # image pre-processing pre_process = transforms.Compose([transforms.Resize(28), @@ -50,21 +49,21 @@ def get_mnistm(train): # datasets and data_loader if train: - train_list = os.path.join(params.dataset_root, 'mnist_m','mnist_m_train_labels.txt') + train_list = os.path.join(dataset_root, 'mnist_m','mnist_m_train_labels.txt') mnistm_dataset = GetLoader( - data_root=os.path.join(params.dataset_root, 'mnist_m', 'mnist_m_train'), + data_root=os.path.join(dataset_root, 'mnist_m', 'mnist_m_train'), data_list=train_list, transform=pre_process) else: - train_list = os.path.join(params.dataset_root, 'mnist_m', 'mnist_m_test_labels.txt') + train_list = os.path.join(dataset_root, 'mnist_m', 'mnist_m_test_labels.txt') mnistm_dataset = GetLoader( - data_root=os.path.join(params.dataset_root, 'mnist_m', 'mnist_m_test'), + data_root=os.path.join(dataset_root, 'mnist_m', 'mnist_m_test'), data_list=train_list, transform=pre_process) mnistm_dataloader = torch.utils.data.DataLoader( dataset=mnistm_dataset, - batch_size=params.batch_size, + batch_size=batch_size, shuffle=True, num_workers=8) diff --git a/datasets/office.py b/datasets/office.py index e2f3289..3b33d6c 100644 --- a/datasets/office.py +++ b/datasets/office.py @@ -4,13 +4,12 @@ import torch from torchvision import datasets, transforms import torch.utils.data as data import os -import params -def get_office(train, category): +def get_office(dataset_root, batch_size, category): """Get Office datasets loader.""" # image pre-processing - pre_process = transforms.Compose([transforms.Resize(params.office_image_size), + pre_process = transforms.Compose([transforms.Resize(227), transforms.ToTensor(), transforms.Normalize( mean=(0.485, 0.456, 0.406), @@ -19,12 +18,12 @@ def get_office(train, category): # datasets and data_loader office_dataset = datasets.ImageFolder( - os.path.join(params.dataset_root, 'office', category, 'images'), + os.path.join(dataset_root, 'office', category, 'images'), transform=pre_process) office_dataloader 
= torch.utils.data.DataLoader( dataset=office_dataset, - batch_size=params.batch_size, + batch_size=batch_size, shuffle=True, num_workers=4) diff --git a/datasets/officecaltech.py b/datasets/officecaltech.py index 59eeda4..6aa4330 100644 --- a/datasets/officecaltech.py +++ b/datasets/officecaltech.py @@ -4,13 +4,12 @@ import torch from torchvision import datasets, transforms import torch.utils.data as data import os -import params -def get_officecaltech(train, category): +def get_officecaltech(dataset_root, batch_size, category): """Get Office_Caltech_10 datasets loader.""" # image pre-processing - pre_process = transforms.Compose([transforms.Resize(params.office_image_size), + pre_process = transforms.Compose([transforms.Resize(227), transforms.ToTensor(), transforms.Normalize( mean=(0.485, 0.456, 0.406), @@ -19,12 +18,12 @@ def get_officecaltech(train, category): # datasets and data_loader officecaltech_dataset = datasets.ImageFolder( - os.path.join(params.dataset_root, 'office_caltech_10', category), + os.path.join(dataset_root, 'office_caltech_10', category), transform=pre_process) officecaltech_dataloader = torch.utils.data.DataLoader( dataset=officecaltech_dataset, - batch_size=params.batch_size, + batch_size=batch_size, shuffle=True, num_workers=4) diff --git a/datasets/svhn.py b/datasets/svhn.py index d9db05f..64b6e25 100644 --- a/datasets/svhn.py +++ b/datasets/svhn.py @@ -1,18 +1,14 @@ """Dataset setting and data loader for SVHN.""" - import torch from torchvision import datasets, transforms import os -import params - -def get_svhn(train): +def get_svhn(dataset_root, batch_size, train): """Get SVHN datasets loader.""" # image pre-processing - pre_process = transforms.Compose([transforms.Grayscale(), - transforms.Resize(params.digit_image_size), + pre_process = transforms.Compose([transforms.Resize(28), transforms.ToTensor(), transforms.Normalize( mean=(0.5, 0.5, 0.5), @@ -21,19 +17,19 @@ def get_svhn(train): # datasets and data loader if train: - svhn_dataset = datasets.SVHN(root=os.path.join(params.dataset_root,'svhn'), + svhn_dataset = datasets.SVHN(root=os.path.join(dataset_root,'svhn'), split='train', transform=pre_process, download=True) else: - svhn_dataset = datasets.SVHN(root=os.path.join(params.dataset_root,'svhn'), + svhn_dataset = datasets.SVHN(root=os.path.join(dataset_root,'svhn'), split='test', transform=pre_process, download=True) svhn_data_loader = torch.utils.data.DataLoader( dataset=svhn_dataset, - batch_size=params.batch_size, + batch_size=batch_size, shuffle=True, drop_last=True) diff --git a/main.py b/main.py deleted file mode 100644 index ce2b3fa..0000000 --- a/main.py +++ /dev/null @@ -1,49 +0,0 @@ -from models.model import SVHNmodel, Classifier - -from core.dann import train_dann -from core.test import eval, eval_src -from core.pretrain import train_src - -import params -from utils import get_data_loader, init_model, init_random_seed - -# init random seed -init_random_seed(params.manual_seed) - -# load dataset -src_data_loader = get_data_loader(params.src_dataset) -src_data_loader_eval = get_data_loader(params.src_dataset, train=False) -tgt_data_loader = get_data_loader(params.tgt_dataset) -tgt_data_loader_eval = get_data_loader(params.tgt_dataset, train=False) - -# load source classifier -src_classifier = init_model(net=Classifier(), restore=params.src_classifier_restore) - -# train source model -print("=== Training classifier for source domain ===") - -if not (src_classifier.restored and params.src_model_trained): - src_classifier = 
train_src(src_classifier, src_data_loader) - -# eval source model on both source and target domain -print("=== Evaluating source classifier for source domain ===") -eval_src(src_classifier, src_data_loader_eval) -print("=== Evaluating source classifier for target domain ===") -eval_src(src_classifier, tgt_data_loader_eval) - -# load dann model -dann = init_model(net=SVHNmodel(), restore=params.dann_restore) - -# train dann model -print("=== Training dann model ===") - -if not (dann.restored and params.dann_restore): - dann = train_dann(dann, src_data_loader, tgt_data_loader, tgt_data_loader_eval) -w -# eval dann model -print("=== Evaluating dann for source domain ===") -eval(dann, src_data_loader_eval) -print("=== Evaluating dann for target domain ===") -eval(dann, tgt_data_loader_eval) - -print('done') \ No newline at end of file diff --git a/mnist_mnistm.py b/mnist_mnistm.py new file mode 100644 index 0000000..3f8140d --- /dev/null +++ b/mnist_mnistm.py @@ -0,0 +1,76 @@ +import os + +from models.model import MNISTmodel +from core.dann import train_dann +from utils import get_data_loader, init_model, init_random_seed + + +class Config(object): + # params for path + dataset_root = os.path.expanduser(os.path.join('~', 'Datasets')) + model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN')) + + # params for datasets and data loader + batch_size = 128 + + # params for source dataset + src_dataset = "mnist" + src_model_trained = True + src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt') + class_num_src = 31 + + # params for target dataset + tgt_dataset = "mnistm" + tgt_model_trained = True + dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt') + + # params for pretrain + num_epochs_src = 100 + log_step_src = 10 + save_step_src = 50 + eval_step_src = 20 + + # params for training dann + + ## for digit + num_epochs = 100 + log_step = 20 + save_step = 50 + eval_step = 5 + + ## for office + # num_epochs = 1000 + # log_step = 10 # iters + # save_step = 500 + # eval_step = 5 # epochs + + manual_seed = 8888 + alpha = 0 + + # params for optimizing models + lr = 2e-4 + +params = Config() + +# init random seed +init_random_seed(params.manual_seed) + +# load dataset +src_data_loader = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=True) +src_data_loader_eval = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=False) +tgt_data_loader = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size, train=True) +tgt_data_loader_eval = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size, train=False) + +# load dann model +dann = init_model(net=MNISTmodel(), restore=None) + +# train dann model +print("Training dann model") +if not (dann.restored and params.dann_restore): + dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval) + +# eval dann model +print("Evaluating dann for source domain {}".format(params.src_dataset)) +eval(dann, src_data_loader_eval) +print("Evaluating dann for target domain {}".format(params.tgt_dataset)) +eval(dann, tgt_data_loader_eval) \ No newline at end of file diff --git a/models/functions.py b/models/functions.py index f33f2a6..db079d3 100644 --- a/models/functions.py +++ b/models/functions.py @@ -12,7 +12,6 @@ class ReverseLayerF(Function): @staticmethod def backward(ctx, grad_output): output = grad_output.neg() * ctx.alpha - #print("reverse 
gradient is {}".format(output)) return output, None diff --git a/models/model.py b/models/model.py index a32af29..24d635a 100644 --- a/models/model.py +++ b/models/model.py @@ -43,37 +43,46 @@ class Classifier(nn.Module): return class_output class MNISTmodel(nn.Module): - """ MNIST architecture""" + """ MNIST architecture + +Dropout2d, 84% ~ 73% + -Dropout2d, 50% ~ 73% + """ def __init__(self): super(MNISTmodel, self).__init__() self.restored = False self.feature = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(5, 5)), # 1 28 28, 32 24 24 + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(5, 5)), # 3 28 28, 32 24 24 + nn.BatchNorm2d(32), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=(2, 2)), # 32 12 12 nn.Conv2d(in_channels=32, out_channels=48, kernel_size=(5, 5)), # 48 8 8 + nn.BatchNorm2d(48), + nn.Dropout2d(), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=(2, 2)), # 48 4 4 ) self.classifier = nn.Sequential( nn.Linear(48*4*4, 100), + nn.BatchNorm1d(100), nn.ReLU(inplace=True), nn.Linear(100, 100), + nn.BatchNorm1d(100), nn.ReLU(inplace=True), nn.Linear(100, 10), ) self.discriminator = nn.Sequential( nn.Linear(48*4*4, 100), + nn.BatchNorm1d(100), nn.ReLU(inplace=True), nn.Linear(100, 2), ) def forward(self, input_data, alpha): - input_data = input_data.expand(input_data.data.shape[0], 1, 28, 28) + input_data = input_data.expand(input_data.data.shape[0], 3, 28, 28) feature = self.feature(input_data) feature = feature.view(-1, 48 * 4 * 4) reverse_feature = ReverseLayerF.apply(feature, alpha) @@ -83,48 +92,55 @@ class MNISTmodel(nn.Module): return class_output, domain_output class SVHNmodel(nn.Module): - """ SVHN architecture""" + """ SVHN architecture + I don't know how to implement the paper's structure + + """ def __init__(self): super(SVHNmodel, self).__init__() self.restored = False - self.feature = nn.Sequential() - self.feature.add_module('f_conv1', nn.Conv2d(1, 64, kernel_size=5)) - self.feature.add_module('f_bn1', nn.BatchNorm2d(64)) - self.feature.add_module('f_pool1', nn.MaxPool2d(2)) - self.feature.add_module('f_relu1', nn.ReLU(True)) - self.feature.add_module('f_conv2', nn.Conv2d(64, 50, kernel_size=5)) - self.feature.add_module('f_bn2', nn.BatchNorm2d(50)) - self.feature.add_module('f_drop1', nn.Dropout2d()) - self.feature.add_module('f_pool2', nn.MaxPool2d(2)) - self.feature.add_module('f_relu2', nn.ReLU(True)) + self.feature = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(5, 5), stride=(1, 1)), # 3 28 28, 64 24 24 + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=(2, 2)), # 64 12 12 + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(5, 5)), # 64 8 8 + nn.BatchNorm2d(64), + nn.Dropout2d(), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)), # 64 4 4 + nn.ReLU(inplace=True), + ) - self.class_classifier = nn.Sequential() - self.class_classifier.add_module('c_fc1', nn.Linear(50 * 4 * 4, 100)) - self.class_classifier.add_module('c_bn1', nn.BatchNorm2d(100)) - self.class_classifier.add_module('c_relu1', nn.ReLU(True)) - self.class_classifier.add_module('c_drop1', nn.Dropout2d()) - self.class_classifier.add_module('c_fc2', nn.Linear(100, 100)) - self.class_classifier.add_module('c_bn2', nn.BatchNorm2d(100)) - self.class_classifier.add_module('c_relu2', nn.ReLU(True)) - self.class_classifier.add_module('c_fc3', nn.Linear(100, 10)) - self.class_classifier.add_module('c_softmax', nn.LogSoftmax(dim=1)) + self.classifier = nn.Sequential( + nn.Linear(64*4*4, 1024), + 
nn.BatchNorm1d(1024), + nn.ReLU(inplace=True), + nn.Linear(1024, 256), + nn.BatchNorm1d(256), + nn.ReLU(inplace=True), + nn.Linear(256, 10), + ) - self.domain_classifier = nn.Sequential() - self.domain_classifier.add_module('d_fc1', nn.Linear(50 * 4 * 4, 100)) - self.domain_classifier.add_module('d_bn1', nn.BatchNorm2d(100)) - self.domain_classifier.add_module('d_relu1', nn.ReLU(True)) - self.domain_classifier.add_module('d_fc2', nn.Linear(100, 2)) - self.domain_classifier.add_module('d_softmax', nn.LogSoftmax(dim=1)) + self.discriminator = nn.Sequential( + nn.Linear(64*4*4, 1024), + nn.BatchNorm1d(1024), + nn.ReLU(inplace=True), + nn.Linear(1024, 256), + nn.BatchNorm1d(256), + nn.ReLU(inplace=True), + nn.Linear(256, 2), + ) def forward(self, input_data, alpha): - input_data = input_data.expand(input_data.data.shape[0], 1, 28, 28) + input_data = input_data.expand(input_data.data.shape[0], 3, 28, 28) feature = self.feature(input_data) - feature = feature.view(-1, 50 * 4 * 4) + feature = feature.view(-1, 64 * 4 * 4) reverse_feature = ReverseLayerF.apply(feature, alpha) - class_output = self.class_classifier(feature) - domain_output = self.domain_classifier(reverse_feature) + class_output = self.classifier(feature) + domain_output = self.discriminator(reverse_feature) return class_output, domain_output diff --git a/params.py b/params.py deleted file mode 100644 index 8872e48..0000000 --- a/params.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Params for DANN.""" - -import os - -# params for path -dataset_root = os.path.expanduser(os.path.join('~', 'Datasets')) -model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN')) - -# params for datasets and data loader - -batch_size = 64 - -office_image_size = 227 - -# params for source dataset -src_dataset = "amazon31" -src_model_trained = True -src_classifier_restore = os.path.join(model_root,src_dataset + '-source-classifier-final.pt') -class_num_src = 31 - -# params for target dataset -tgt_dataset = "webcam31" -tgt_model_trained = True -dann_restore = os.path.join(model_root , src_dataset + '-' + tgt_dataset + '-dann-final.pt') - -# params for pretrain -num_epochs_src = 100 -log_step_src = 10 -save_step_src = 20 -eval_step_src = 20 - -# params for training dann - -## for digit -# num_epochs = 400 -# log_step = 100 -# save_step = 20 -# eval_step = 20 - -## for office -num_epochs = 1000 -log_step = 10 # iters -save_step = 500 -eval_step = 5 # epochs - -manual_seed = 8888 -alpha = 0 - -# params for optimizing models -lr = 2e-4 \ No newline at end of file diff --git a/svhn_mnist.py b/svhn_mnist.py new file mode 100644 index 0000000..a51d4e6 --- /dev/null +++ b/svhn_mnist.py @@ -0,0 +1,75 @@ +import os + +from models.model import SVHNmodel +from core.dann import train_dann +from utils import get_data_loader, init_model, init_random_seed + + +class Config(object): + # params for path + dataset_root = os.path.expanduser(os.path.join('~', 'Datasets')) + model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN')) + + # params for datasets and data loader + batch_size = 128 + + # params for source dataset + src_dataset = "svhn" + src_model_trained = True + src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt') + + # params for target dataset + tgt_dataset = "mnist" + tgt_model_trained = True + dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt') + + # params for pretrain + num_epochs_src = 100 + log_step_src = 10 + save_step_src = 50 + eval_step_src = 20 
+ + # params for training dann + + ## for digit + num_epochs = 200 + log_step = 20 + save_step = 50 + eval_step = 5 + + ## for office + # num_epochs = 1000 + # log_step = 10 # iters + # save_step = 500 + # eval_step = 5 # epochs + + manual_seed = 8888 + alpha = 0 + + # params for optimizing models + lr = 2e-4 + +params = Config() + +# init random seed +init_random_seed(params.manual_seed) + +# load dataset +src_data_loader = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=True) +src_data_loader_eval = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=False) +tgt_data_loader = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size, train=True) +tgt_data_loader_eval = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size, train=False) + +# load dann model +dann = init_model(net=SVHNmodel(), restore=None) + +# train dann model +print("Training dann model") +if not (dann.restored and params.dann_restore): + dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval) + +# eval dann model +print("Evaluating dann for source domain {}".format(params.src_dataset)) +eval(dann, src_data_loader_eval) +print("Evaluating dann for target domain {}".format(params.tgt_dataset)) +eval(dann, tgt_data_loader_eval) \ No newline at end of file diff --git a/utils.py b/utils.py index 6cdb9f0..2c8b7df 100644 --- a/utils.py +++ b/utils.py @@ -7,7 +7,6 @@ import torch import torch.backends.cudnn as cudnn from torch.autograd import Variable -import params from datasets import get_mnist, get_mnistm, get_svhn from datasets.office import get_office from datasets.officecaltech import get_officecaltech @@ -57,20 +56,20 @@ def init_random_seed(manual_seed): torch.cuda.manual_seed_all(seed) -def get_data_loader(name, train=True): +def get_data_loader(name, dataset_root, batch_size, train=True): """Get data loader by name.""" - if name == "MNIST": - return get_mnist(train) - elif name == "MNISTM": - return get_mnistm(train) - elif name == "SVHN": - return get_svhn(train) + if name == "mnist": + return get_mnist(dataset_root, batch_size, train) + elif name == "mnistm": + return get_mnistm(dataset_root, batch_size, train) + elif name == "svhn": + return get_svhn(dataset_root, batch_size, train) elif name == "amazon31": - return get_office(train, 'amazon') + return get_office(dataset_root, batch_size, 'amazon') elif name == "webcam31": - return get_office(train, 'webcam') + return get_office(dataset_root, batch_size, 'webcam') elif name == "webcam10": - return get_officecaltech(train, 'webcam') + return get_officecaltech(dataset_root, batch_size, 'webcam') def init_model(net, restore): """Init models with cuda and weights.""" @@ -92,11 +91,10 @@ def init_model(net, restore): return net - -def save_model(net, filename): +def save_model(net, model_root, filename): """Save trained model.""" - if not os.path.exists(params.model_root): - os.makedirs(params.model_root) + if not os.path.exists(model_root): + os.makedirs(model_root) torch.save(net.state_dict(), - os.path.join(params.model_root, filename)) - print("save pretrained model to: {}".format(os.path.join(params.model_root, filename))) \ No newline at end of file + os.path.join(model_root, filename)) + print("save pretrained model to: {}".format(os.path.join(model_root, filename))) \ No newline at end of file
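
Note on the two schedules introduced in core/dann.py: train_dann now computes the gradient-reversal coefficient as alpha = 2 / (1 + exp(-10 * p)) - 1 and calls adjust_learning_rate, which anneals the SGD learning rate as lr = lr_0 / (1 + 10 * p) ** 0.75 with lr_0 = 0.01, where p is the training progress in [0, 1]. The standalone sketch below only restates those two formulas for reference; the function names grl_coefficient and annealed_lr are mine and are not part of the patch.

import numpy as np

def grl_coefficient(p):
    # ramps from 0 toward 1 as training progress p goes from 0 to 1
    return 2. / (1. + np.exp(-10 * p)) - 1

def annealed_lr(p, lr_0=0.01, alpha=10, beta=0.75):
    # same decay as adjust_learning_rate in core/dann.py
    return lr_0 / (1 + alpha * p) ** beta

if __name__ == '__main__':
    for p in (0.0, 0.25, 0.5, 0.75, 1.0):
        print('p={:.2f}  grl={:.3f}  lr={:.5f}'.format(p, grl_coefficient(p), annealed_lr(p)))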
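Both MNISTmodel and SVHNmodel in models/model.py pass their features through ReverseLayerF (models/functions.py) before the domain discriminator; this patch only strips a debug print from its backward pass. For readers unfamiliar with the trick, a minimal gradient-reversal layer looks roughly like the sketch below: the backward pass matches the one shown in the diff, while the identity forward pass is an assumption on my part, since it is not included in this patch.

from torch.autograd import Function

class GradReverse(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        # assumed: identity in the forward direction, remembering alpha for backward
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # as in models/functions.py: negate and scale the incoming gradient
        output = grad_output.neg() * ctx.alpha
        return output, None

# usage mirrors the models: reverse_feature = GradReverse.apply(feature, alpha)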