diff --git a/.gitignore b/.gitignore
index 7bbc71c..569aa65 100644
--- a/.gitignore
+++ b/.gitignore
@@ -99,3 +99,8 @@ ENV/
 
 # mypy
 .mypy_cache/
+
+# personal
+.idea
+.DS_Store
+main_legacy.py
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..7824ade
--- /dev/null
+++ b/README.md
@@ -0,0 +1,27 @@
+# PyTorch-DANN
+
+A PyTorch implementation of the paper *[Unsupervised Domain Adaptation by Backpropagation](http://sites.skoltech.ru/compvision/projects/grl/)*
+
+    InProceedings (icml2015-ganin15)
+    Ganin, Y. & Lempitsky, V.
+    Unsupervised Domain Adaptation by Backpropagation
+    Proceedings of the 32nd International Conference on Machine Learning, 2015
+
+## Environment
+
+- Python 2.7
+- PyTorch 0.3.1
+
+## Results
+
+Results with the default `params.py`:
+
+|                   | MNIST (Source) | USPS (Target) |
+| :---------------: | :------------: | :-----------: |
+| Source Classifier |   99.140000%   |  83.978495%   |
+| DANN              |                |  97.634409%   |
+
+## Credit
+
+-
+-
\ No newline at end of file
diff --git a/core/__init__.py b/core/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/core/dann.py b/core/dann.py
new file mode 100644
index 0000000..80b6239
--- /dev/null
+++ b/core/dann.py
@@ -0,0 +1,96 @@
+"""Train dann."""
+
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+import torch.nn as nn
+import torch.optim as optim
+
+import params
+from core.test import eval
+from utils import make_variable, save_model
+
+cudnn.benchmark = True
+
+
+def train_dann(dann, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
+    """Train dann."""
+    ####################
+    # 1. setup network #
+    ####################
+
+    # set train state for Dropout and BN layers
+    dann.train()
+
+    # setup criterion and optimizer
+    optimizer = optim.Adam(dann.parameters(), lr=params.lr)
+    criterion = nn.NLLLoss()
+
+    for p in dann.parameters():
+        p.requires_grad = True
+
+    ####################
+    # 2. train network #
+    ####################
+
+    # prepare domain labels
+    label_src = make_variable(torch.zeros(params.batch_size).long())  # source 0
+    label_tgt = make_variable(torch.ones(params.batch_size).long())   # target 1
+
+    for epoch in range(params.num_epochs):
+        # zip source and target data pairs
+        len_dataloader = min(len(src_data_loader), len(tgt_data_loader))
+        data_zip = enumerate(zip(src_data_loader, tgt_data_loader))
+        for step, ((images_src, class_src), (images_tgt, _)) in data_zip:
+            # lambda schedule from the paper: alpha ramps from 0 to 1
+            # over the course of training
+            p = float(step + epoch * len_dataloader) / params.num_epochs / len_dataloader
+            alpha = 2. / (1. + np.exp(-10 * p)) - 1
+
+            # make images variable
+            class_src = make_variable(class_src)
+            images_src = make_variable(images_src)
+            images_tgt = make_variable(images_tgt)
+
+            # zero gradients for optimizer
+            optimizer.zero_grad()
+
+            # train on source domain
+            src_class_output, src_domain_output = dann(input_data=images_src, alpha=alpha)
+            src_loss_class = criterion(src_class_output, class_src)
+            src_loss_domain = criterion(src_domain_output, label_src)
+
+            # train on target domain
+            _, tgt_domain_output = dann(input_data=images_tgt, alpha=alpha)
+            tgt_loss_domain = criterion(tgt_domain_output, label_tgt)
+
+            loss = src_loss_class + src_loss_domain + tgt_loss_domain
+
+            # optimize dann
+            loss.backward()
+            optimizer.step()
+
+            # print step info
+            if ((step + 1) % params.log_step == 0):
+                print("Epoch [{}/{}] Step [{}/{}]: src_loss_class={}, src_loss_domain={}, tgt_loss_domain={}, loss={}"
+                      .format(epoch + 1,
+                              params.num_epochs,
+                              step + 1,
+                              len_dataloader,
+                              src_loss_class.data[0],
+                              src_loss_domain.data[0],
+                              tgt_loss_domain.data[0],
+                              loss.data[0]))
+
+        # eval model on test set
+        if ((epoch + 1) % params.eval_step == 0):
+            eval(dann, tgt_data_loader_eval)
+            dann.train()
+
+        # save model parameters
+        if ((epoch + 1) % params.save_step == 0):
+            save_model(dann, params.src_dataset + '-' + params.tgt_dataset + "-dann-{}.pt".format(epoch + 1))
+
+    # save final model
+    save_model(dann, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")
+
+    return dann
\ No newline at end of file
diff --git a/core/pretrain.py b/core/pretrain.py
new file mode 100644
index 0000000..7d80fd4
--- /dev/null
+++ b/core/pretrain.py
@@ -0,0 +1,65 @@
+"""Train classifier for source dataset."""
+
+import torch.nn as nn
+import torch.optim as optim
+
+import params
+from core.test import eval_src
+from utils import make_variable, save_model
+
+
+def train_src(model, data_loader):
+    """Train classifier for source domain."""
+    ####################
+    # 1. setup network #
+    ####################
+
+    # set train state for Dropout and BN layers
+    model.train()
+
+    # setup criterion and optimizer
+    optimizer = optim.Adam(model.parameters(), lr=params.lr)
+    loss_class = nn.NLLLoss()
+
+    ####################
+    # 2. train network #
+    ####################
+
+    for epoch in range(params.num_epochs_src):
+        for step, (images, labels) in enumerate(data_loader):
+            # make images and labels variable
+            images = make_variable(images)
+            labels = make_variable(labels.squeeze_())
+
+            # zero gradients for optimizer
+            optimizer.zero_grad()
+
+            # compute loss for the source classifier
+            preds = model(images)
+            loss = loss_class(preds, labels)
+
+            # optimize source classifier
+            loss.backward()
+            optimizer.step()
+
+            # print step info
+            if ((step + 1) % params.log_step_src == 0):
+                print("Epoch [{}/{}] Step [{}/{}]: loss={}"
+                      .format(epoch + 1,
+                              params.num_epochs_src,
+                              step + 1,
+                              len(data_loader),
+                              loss.data[0]))
+
+        # eval model on test set
+        if ((epoch + 1) % params.eval_step_src == 0):
+            eval_src(model, data_loader)
+            model.train()
+
+        # save model parameters
+        if ((epoch + 1) % params.save_step_src == 0):
+            save_model(model, params.src_dataset + "-source-classifier-{}.pt".format(epoch + 1))
+
+    # save final model
+    save_model(model, params.src_dataset + "-source-classifier-final.pt")
+
+    return model
\ No newline at end of file
diff --git a/core/test.py b/core/test.py
new file mode 100644
index 0000000..020feaa
--- /dev/null
+++ b/core/test.py
@@ -0,0 +1,159 @@
+import os
+
+import torch
+import torch.backends.cudnn as cudnn
+import torch.nn as nn
+import torch.utils.data
+from torch.autograd import Variable
+
+import params
+from utils import make_variable
+
+
+def test(dataloader, epoch):
+    """Evaluate a saved dann model on the target data loader."""
+    cuda = torch.cuda.is_available()
+    if cuda:
+        cudnn.benchmark = True
+
+    my_net = torch.load(os.path.join(
+        params.model_root, params.src_dataset + '_' + params.tgt_dataset + '_model_epoch_' + str(epoch) + '.pth'
+    ))
+    my_net = my_net.eval()
+
+    if cuda:
+        my_net = my_net.cuda()
+
+    len_dataloader = len(dataloader)
+    data_target_iter = iter(dataloader)
+
+    i = 0
+    n_total = 0
+    n_correct = 0
+
+    while i < len_dataloader:
+        # test model using target data
+        data_target = next(data_target_iter)
+        t_img, t_label = data_target
+
+        batch_size = len(t_label)
+
+        input_img = torch.FloatTensor(batch_size, 3, params.image_size, params.image_size)
+        class_label = torch.LongTensor(batch_size)
+
+        if cuda:
+            t_img = t_img.cuda()
+            t_label = t_label.cuda()
+            input_img = input_img.cuda()
+            class_label = class_label.cuda()
+
+        input_img.resize_as_(t_img).copy_(t_img)
+        class_label.resize_as_(t_label).copy_(t_label)
+        inputv_img = Variable(input_img)
+        classv_label = Variable(class_label)
+
+        class_output, _ = my_net(input_data=inputv_img, alpha=params.alpha)
+        pred = class_output.data.max(1, keepdim=True)[1]
+        n_correct += pred.eq(classv_label.data.view_as(pred)).cpu().sum()
+        n_total += batch_size
+
+        i += 1
+
+    accu = n_correct * 1.0 / n_total
+
+    print('epoch: %d, accuracy: %f' % (epoch, accu))
+
+
+def test_from_save(model, saved_model, data_loader):
+    """Evaluate classifier for source domain."""
+    # load_state_dict() restores weights in place and returns None,
+    # so keep the reference to the model itself
+    model.load_state_dict(torch.load(saved_model))
+    classifier = model
+
+    # set eval state for Dropout and BN layers
+    classifier.eval()
+
+    # init loss and accuracy
+    loss = 0.0
+    acc = 0.0
+
+    # set loss function
+    criterion = nn.NLLLoss()
+
+    # evaluate network
+    for (images, labels) in data_loader:
+        images = make_variable(images, volatile=True)
+        labels = make_variable(labels)
+        preds = classifier(images)
+
+        loss += criterion(preds, labels).data[0]
+
+        pred_cls = preds.data.max(1)[1]
+        acc += pred_cls.eq(labels.data).cpu().sum()
+
+    loss /= len(data_loader)
+    acc /= len(data_loader.dataset)
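+
+    # loss is averaged over batches, accuracy over individual samples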
print("Avg Loss = {}, Avg Accuracy = {:2%}".format(loss, acc)) + + +def eval(model, data_loader): + """Evaluate model for dataset.""" + # set eval state for Dropout and BN layers + model.eval() + + # init loss and accuracy + loss = 0.0 + acc = 0.0 + + # set loss function + criterion = nn.NLLLoss() + + # evaluate network + for (images, labels) in data_loader: + images = make_variable(images, volatile=True) + labels = make_variable(labels) #labels = labels.squeeze(1) + preds, _ = model(images, alpha=0) + + criterion(preds, labels) + + loss += criterion(preds, labels).data[0] + + pred_cls = preds.data.max(1)[1] + acc += pred_cls.eq(labels.data).cpu().sum() + + loss /= len(data_loader) + acc /= len(data_loader.dataset) + + print("Avg Loss = {}, Avg Accuracy = {:2%}".format(loss, acc)) + + +def eval_src(model, data_loader): + """Evaluate classifier for source domain.""" + # set eval state for Dropout and BN layers + model.eval() + + # init loss and accuracy + loss = 0.0 + acc = 0.0 + + # set loss function + criterion = nn.NLLLoss() + + # evaluate network + for (images, labels) in data_loader: + images = make_variable(images, volatile=True) + labels = make_variable(labels) #labels = labels.squeeze(1) + preds = model(images) + + criterion(preds, labels) + + loss += criterion(preds, labels).data[0] + + pred_cls = preds.data.max(1)[1] + acc += pred_cls.eq(labels.data).cpu().sum() + + + loss /= len(data_loader) + acc /= len(data_loader.dataset) + + print("Avg Loss = {}, Avg Accuracy = {:2%}".format(loss, acc)) \ No newline at end of file diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000..9288485 --- /dev/null +++ b/datasets/__init__.py @@ -0,0 +1,5 @@ +from .mnist import get_mnist +from .mnistm import get_mnistm +from .svhn import get_svhn + +__all__ = (get_mnist, get_svhn, get_mnistm) diff --git a/datasets/mnist.py b/datasets/mnist.py new file mode 100644 index 0000000..30dbe56 --- /dev/null +++ b/datasets/mnist.py @@ -0,0 +1,31 @@ +"""Dataset setting and data loader for MNIST.""" + + +import torch +from torchvision import datasets, transforms +import os + +import params + +def get_mnist(train): + """Get MNIST datasets loader.""" + # image pre-processing + pre_process = transforms.Compose([transforms.ToTensor(), + transforms.Normalize( + mean=params.dataset_mean, + std=params.dataset_std)]) + + # datasets and data loader + mnist_dataset = datasets.MNIST(root=os.path.join(params.dataset_root,'mnist'), + train=train, + transform=pre_process, + download=False) + + + mnist_data_loader = torch.utils.data.DataLoader( + dataset=mnist_dataset, + batch_size=params.batch_size, + shuffle=True, + drop_last=True) + + return mnist_data_loader \ No newline at end of file diff --git a/datasets/mnistm.py b/datasets/mnistm.py new file mode 100644 index 0000000..f8ac0dd --- /dev/null +++ b/datasets/mnistm.py @@ -0,0 +1,70 @@ +"""Dataset setting and data loader for MNIST_M.""" + +import torch +from torchvision import datasets, transforms +import torch.utils.data as data +from PIL import Image +import os +import params + +class GetLoader(data.Dataset): + def __init__(self, data_root, data_list, transform=None): + self.root = data_root + self.transform = transform + + f = open(data_list, 'r') + data_list = f.readlines() + f.close() + + self.n_data = len(data_list) + + self.img_paths = [] + self.img_labels = [] + + for data in data_list: + self.img_paths.append(data[:-3]) + self.img_labels.append(data[-2]) + + def __getitem__(self, item): + img_paths, labels = 
self.img_paths[item], self.img_labels[item] + imgs = Image.open(os.path.join(self.root, img_paths)).convert('RGB') + + if self.transform is not None: + imgs = self.transform(imgs) + labels = int(labels) + + return imgs, labels + + def __len__(self): + return self.n_data + +def get_mnistm(train): + """Get MNISTM datasets loader.""" + # image pre-processing + pre_process = transforms.Compose([transforms.Resize(params.image_size), + transforms.ToTensor(), + transforms.Normalize( + mean=params.dataset_mean, + std=params.dataset_std)]) + + # datasets and data_loader + if train: + train_list = os.path.join(params.dataset_root, 'mnist_m','mnist_m_train_labels.txt') + mnistm_dataset = GetLoader( + data_root=os.path.join(params.dataset_root, 'mnist_m', 'mnist_m_train'), + data_list=train_list, + transform=pre_process) + else: + train_list = os.path.join(params.dataset_root, 'mnist_m', 'mnist_m_test_labels.txt') + mnistm_dataset = GetLoader( + data_root=os.path.join(params.dataset_root, 'mnist_m', 'mnist_m_test'), + data_list=train_list, + transform=pre_process) + + mnistm_dataloader = torch.utils.data.DataLoader( + dataset=mnistm_dataset, + batch_size=params.batch_size, + shuffle=True, + num_workers=8) + + return mnistm_dataloader \ No newline at end of file diff --git a/datasets/svhn.py b/datasets/svhn.py new file mode 100644 index 0000000..5b21de3 --- /dev/null +++ b/datasets/svhn.py @@ -0,0 +1,39 @@ +"""Dataset setting and data loader for SVHN.""" + + +import torch +from torchvision import datasets, transforms +import os + +import params + + +def get_svhn(train): + """Get SVHN datasets loader.""" + # image pre-processing + pre_process = transforms.Compose([transforms.Grayscale(), + transforms.Resize(params.image_size), + transforms.ToTensor(), + transforms.Normalize( + mean=params.dataset_mean, + std=params.dataset_std)]) + + # datasets and data loader + if train: + svhn_dataset = datasets.SVHN(root=os.path.join(params.dataset_root,'svhn'), + split='train', + transform=pre_process, + download=True) + else: + svhn_dataset = datasets.SVHN(root=os.path.join(params.dataset_root,'svhn'), + split='test', + transform=pre_process, + download=True) + + svhn_data_loader = torch.utils.data.DataLoader( + dataset=svhn_dataset, + batch_size=params.batch_size, + shuffle=True, + drop_last=True) + + return svhn_data_loader diff --git a/main.py b/main.py new file mode 100644 index 0000000..eb1de77 --- /dev/null +++ b/main.py @@ -0,0 +1,50 @@ +from models.model import CNNModel +from models.classifier import Classifier + +from core.dann import train_dann +from core.test import eval, eval_src +from core.pretrain import train_src + +import params +from utils import get_data_loader, init_model, init_random_seed + +# init random seed +init_random_seed(params.manual_seed) + +# load dataset +src_data_loader = get_data_loader(params.src_dataset) +src_data_loader_eval = get_data_loader(params.src_dataset, train=False) +tgt_data_loader = get_data_loader(params.tgt_dataset) +tgt_data_loader_eval = get_data_loader(params.tgt_dataset, train=False) + +# load source classifier +src_classifier = init_model(net=Classifier(), restore=params.src_classifier_restore) + +# train source model +print("=== Training classifier for source domain ===") + +if not (src_classifier.restored and params.src_model_trained): + src_classifier = train_src(src_classifier, src_data_loader) + +# eval source model on both source and target domain +print("=== Evaluating source classifier for source domain ===") +eval_src(src_classifier, 
+print("=== Evaluating source classifier for target domain ===")
+eval_src(src_classifier, tgt_data_loader_eval)
+
+# load dann model
+dann = init_model(net=CNNModel(), restore=params.dann_restore)
+
+# train dann model
+print("=== Training dann model ===")
+if not (dann.restored and params.tgt_model_trained):
+    dann = train_dann(dann, src_data_loader, tgt_data_loader, tgt_data_loader_eval)
+
+# eval dann model
+print("=== Evaluating dann for source domain ===")
+eval(dann, src_data_loader_eval)
+print("=== Evaluating dann for target domain ===")
+eval(dann, tgt_data_loader_eval)
+
+print('done')
\ No newline at end of file
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/models/classifier.py b/models/classifier.py
new file mode 100644
index 0000000..8390d23
--- /dev/null
+++ b/models/classifier.py
@@ -0,0 +1,39 @@
+"""Classifier for source domain."""
+
+import torch.nn as nn
+
+
+class Classifier(nn.Module):
+
+    def __init__(self):
+        super(Classifier, self).__init__()
+        self.restored = False
+
+        self.feature = nn.Sequential()
+        self.feature.add_module('f_conv1', nn.Conv2d(1, 64, kernel_size=5))
+        self.feature.add_module('f_bn1', nn.BatchNorm2d(64))
+        self.feature.add_module('f_pool1', nn.MaxPool2d(2))
+        self.feature.add_module('f_relu1', nn.ReLU(True))
+        self.feature.add_module('f_conv2', nn.Conv2d(64, 50, kernel_size=5))
+        self.feature.add_module('f_bn2', nn.BatchNorm2d(50))
+        self.feature.add_module('f_drop1', nn.Dropout2d())
+        self.feature.add_module('f_pool2', nn.MaxPool2d(2))
+        self.feature.add_module('f_relu2', nn.ReLU(True))
+
+        # the classifier head operates on flattened 2-D activations,
+        # so use the 1-D variants of BatchNorm and Dropout
+        self.class_classifier = nn.Sequential()
+        self.class_classifier.add_module('c_fc1', nn.Linear(50 * 4 * 4, 100))
+        self.class_classifier.add_module('c_bn1', nn.BatchNorm1d(100))
+        self.class_classifier.add_module('c_relu1', nn.ReLU(True))
+        self.class_classifier.add_module('c_drop1', nn.Dropout())
+        self.class_classifier.add_module('c_fc2', nn.Linear(100, 100))
+        self.class_classifier.add_module('c_bn2', nn.BatchNorm1d(100))
+        self.class_classifier.add_module('c_relu2', nn.ReLU(True))
+        self.class_classifier.add_module('c_fc3', nn.Linear(100, 10))
+        self.class_classifier.add_module('c_softmax', nn.LogSoftmax(dim=1))
+
+    def forward(self, input_data):
+        input_data = input_data.expand(input_data.data.shape[0], 1, 28, 28)
+        feature = self.feature(input_data)
+        feature = feature.view(-1, 50 * 4 * 4)
+        class_output = self.class_classifier(feature)
+
+        return class_output
diff --git a/models/functions.py b/models/functions.py
new file mode 100644
index 0000000..fc0eab1
--- /dev/null
+++ b/models/functions.py
@@ -0,0 +1,18 @@
+from torch.autograd import Function
+
+
+class ReverseLayerF(Function):
+    """Gradient reversal layer: identity on the forward pass,
+    scales gradients by -alpha on the backward pass."""
+
+    @staticmethod
+    def forward(ctx, x, alpha):
+        ctx.alpha = alpha
+
+        return x.view_as(x)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        output = grad_output.neg() * ctx.alpha
+
+        return output, None
diff --git a/models/model.py b/models/model.py
new file mode 100644
index 0000000..86394f7
--- /dev/null
+++ b/models/model.py
@@ -0,0 +1,50 @@
+"""DANN model."""
+
+import torch.nn as nn
+
+from .functions import ReverseLayerF
+
+
+class CNNModel(nn.Module):
+
+    def __init__(self):
+        super(CNNModel, self).__init__()
+        self.restored = False
+
+        self.feature = nn.Sequential()
+        self.feature.add_module('f_conv1', nn.Conv2d(1, 64, kernel_size=5))
+        self.feature.add_module('f_bn1', nn.BatchNorm2d(64))
+        self.feature.add_module('f_pool1', nn.MaxPool2d(2))
+        self.feature.add_module('f_relu1', nn.ReLU(True))
+        self.feature.add_module('f_conv2', nn.Conv2d(64, 50, kernel_size=5))
+        self.feature.add_module('f_bn2', nn.BatchNorm2d(50))
+        self.feature.add_module('f_drop1', nn.Dropout2d())
+        self.feature.add_module('f_pool2', nn.MaxPool2d(2))
+        self.feature.add_module('f_relu2', nn.ReLU(True))
+
+        # both heads operate on flattened 2-D activations,
+        # so use the 1-D variants of BatchNorm and Dropout
+        self.class_classifier = nn.Sequential()
+        self.class_classifier.add_module('c_fc1', nn.Linear(50 * 4 * 4, 100))
+        self.class_classifier.add_module('c_bn1', nn.BatchNorm1d(100))
+        self.class_classifier.add_module('c_relu1', nn.ReLU(True))
+        self.class_classifier.add_module('c_drop1', nn.Dropout())
+        self.class_classifier.add_module('c_fc2', nn.Linear(100, 100))
+        self.class_classifier.add_module('c_bn2', nn.BatchNorm1d(100))
+        self.class_classifier.add_module('c_relu2', nn.ReLU(True))
+        self.class_classifier.add_module('c_fc3', nn.Linear(100, 10))
+        self.class_classifier.add_module('c_softmax', nn.LogSoftmax(dim=1))
+
+        self.domain_classifier = nn.Sequential()
+        self.domain_classifier.add_module('d_fc1', nn.Linear(50 * 4 * 4, 100))
+        self.domain_classifier.add_module('d_bn1', nn.BatchNorm1d(100))
+        self.domain_classifier.add_module('d_relu1', nn.ReLU(True))
+        self.domain_classifier.add_module('d_fc2', nn.Linear(100, 2))
+        self.domain_classifier.add_module('d_softmax', nn.LogSoftmax(dim=1))
+
+    def forward(self, input_data, alpha):
+        input_data = input_data.expand(input_data.data.shape[0], 1, 28, 28)
+        feature = self.feature(input_data)
+        feature = feature.view(-1, 50 * 4 * 4)
+        reverse_feature = ReverseLayerF.apply(feature, alpha)
+        class_output = self.class_classifier(feature)
+        domain_output = self.domain_classifier(reverse_feature)
+
+        return class_output, domain_output
diff --git a/params.py b/params.py
new file mode 100644
index 0000000..8868212
--- /dev/null
+++ b/params.py
@@ -0,0 +1,43 @@
+"""Params for DANN."""
+
+import os
+
+# params for path
+dataset_root = os.path.expanduser(os.path.join('~', 'Datasets'))
+model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN'))
+
+# params for datasets and data loader
+dataset_mean_value = 0.5
+dataset_std_value = 0.5
+dataset_mean = (dataset_mean_value, dataset_mean_value, dataset_mean_value)
+dataset_std = (dataset_std_value, dataset_std_value, dataset_std_value)
+batch_size = 128
+image_size = 28
+
+# params for source dataset
+src_dataset = "SVHN"
+src_model_trained = True
+src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt')
+
+# params for target dataset
+tgt_dataset = "MNIST"
+tgt_model_trained = True
+dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')
+
+# params for pretrain
+num_epochs_src = 100
+log_step_src = 10
+save_step_src = 20
+eval_step_src = 20
+
+# params for training dann
+num_epochs = 400
+log_step = 50
+save_step = 50
+eval_step = 20
+
+manual_seed = 8888
+alpha = 0
+
+# params for optimizing models
+lr = 2e-4
\ No newline at end of file
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..d8279fe
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,96 @@
+"""Utilities for DANN."""
+
+import os
+import random
+
+import torch
+import torch.backends.cudnn as cudnn
+from torch.autograd import Variable
+
+import params
+from datasets import get_mnist, get_mnistm, get_svhn
+
+
+def make_variable(tensor, volatile=False):
+    """Convert Tensor to Variable."""
+    if torch.cuda.is_available():
+        tensor = tensor.cuda()
+    return Variable(tensor, volatile=volatile)
+
+
+def make_cuda(tensor):
+    """Use CUDA if it's available."""
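+    # move the tensor to the GPU when one is available; no-op otherwise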
+    if torch.cuda.is_available():
+        tensor = tensor.cuda()
+    return tensor
+
+
+def denormalize(x, std, mean):
+    """Invert normalization, and then convert array into image."""
+    out = x * std + mean
+    return out.clamp(0, 1)
+
+
+def init_weights(layer):
+    """Init weights for layers w.r.t. the original paper."""
+    layer_name = layer.__class__.__name__
+    if layer_name.find("Conv") != -1:
+        layer.weight.data.normal_(0.0, 0.02)
+    elif layer_name.find("BatchNorm") != -1:
+        layer.weight.data.normal_(1.0, 0.02)
+        layer.bias.data.fill_(0)
+
+
+def init_random_seed(manual_seed):
+    """Init random seed."""
+    if manual_seed is None:
+        seed = random.randint(1, 10000)
+    else:
+        seed = manual_seed
+    print("use random seed: {}".format(seed))
+    random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+
+
+def get_data_loader(name, train=True):
+    """Get data loader by name."""
+    if name == "MNIST":
+        return get_mnist(train)
+    elif name == "MNISTM":
+        return get_mnistm(train)
+    elif name == "SVHN":
+        return get_svhn(train)
+    else:
+        raise ValueError("unknown dataset name: {}".format(name))
+
+
+def init_model(net, restore):
+    """Init model with cuda and weights."""
+    # init weights of model
+    net.apply(init_weights)
+
+    # restore model weights
+    if restore is not None and os.path.exists(restore):
+        net.load_state_dict(torch.load(restore))
+        net.restored = True
+        print("Restore model from: {}".format(os.path.abspath(restore)))
+    else:
+        print("No trained model, train from scratch.")
+
+    # check if cuda is available
+    if torch.cuda.is_available():
+        cudnn.benchmark = True
+        net.cuda()
+
+    return net
+
+
+def save_model(net, filename):
+    """Save trained model."""
+    if not os.path.exists(params.model_root):
+        os.makedirs(params.model_root)
+    torch.save(net.state_dict(),
+               os.path.join(params.model_root, filename))
+    print("save pretrained model to: {}".format(os.path.join(params.model_root,
+                                                             filename)))
\ No newline at end of file
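
As a quick sanity check of the gradient reversal layer in `models/functions.py`, here is a minimal standalone sketch. It uses the modern PyTorch tensor API (0.4+) rather than the 0.3.1 `Variable` API this repo targets, and the `alpha` value of 0.5 is arbitrary; the point is that the forward pass is the identity while gradients come back negated and scaled, which is what lets the feature extractor be trained *against* the domain classifier.

```python
import torch
from torch.autograd import Function


class ReverseLayerF(Function):
    """Identity on the forward pass; multiplies incoming gradients by -alpha."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # no gradient flows to alpha, hence the trailing None
        return grad_output.neg() * ctx.alpha, None


x = torch.ones(3, requires_grad=True)
ReverseLayerF.apply(x, 0.5).sum().backward()
print(x.grad)  # tensor([-0.5000, -0.5000, -0.5000]) -- reversed and scaled
```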