You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
96 lines
2.5 KiB
96 lines
2.5 KiB
7 years ago
|
"""Utilities for ADDA."""
|
||
|
|
||
|
import os
|
||
|
import random
|
||
|
|
||
|
import torch
|
||
|
import torch.backends.cudnn as cudnn
|
||
|
from torch.autograd import Variable
|
||
|
|
||
|
import params
|
||
|
from datasets import get_mnist, get_mnistm, get_svhn
|
||
|
|
||
|
|
||
|
def make_variable(tensor, volatile=False):
    """Wrap ``tensor`` in an autograd ``Variable``, moving it to the GPU first if CUDA is available.

    Args:
        tensor: the tensor to wrap.
        volatile: forwarded unchanged to ``Variable``.
            NOTE(review): ``volatile`` is a pre-0.4 PyTorch flag; newer releases
            ignore it and emit a warning (``torch.no_grad()`` is the replacement).

    Returns:
        The ``Variable`` built from the (possibly GPU-resident) tensor.
    """
    source = tensor.cuda() if torch.cuda.is_available() else tensor
    return Variable(source, volatile=volatile)
|
||
|
|
||
|
|
||
|
def make_cuda(tensor):
    """Return ``tensor`` on the GPU when CUDA is available, otherwise return it unchanged."""
    if not torch.cuda.is_available():
        return tensor
    return tensor.cuda()
|
||
|
|
||
|
|
||
|
def denormalize(x, std, mean):
    """Undo ``(x - mean) / std`` normalisation and clip the result to the range [0, 1].

    Args:
        x: normalised tensor (anything supporting ``*``, ``+`` and ``.clamp``).
        std: scale that was divided out during normalisation.
        mean: offset that was subtracted during normalisation.

    Returns:
        ``x * std + mean`` clamped to [0, 1].
    """
    restored = x * std + mean
    return restored.clamp(0, 1)
|
||
|
|
||
|
|
||
|
def init_weights(layer):
    """Initialise Conv and BatchNorm layer parameters in place, following the original paper.

    Intended to be passed to ``nn.Module.apply``. Matching is by class name:

    * classes whose name contains "Conv": weights drawn from N(0.0, 0.02);
    * classes whose name contains "BatchNorm": weights drawn from N(1.0, 0.02),
      bias set to 0.

    All other layers are left untouched. Parameters that are ``None`` (for
    example BatchNorm created with ``affine=False``) are skipped rather than
    raising ``AttributeError``.

    Args:
        layer: the module to initialise.
    """
    layer_name = layer.__class__.__name__
    if "Conv" in layer_name:
        if getattr(layer, "weight", None) is not None:
            layer.weight.data.normal_(0.0, 0.02)
    elif "BatchNorm" in layer_name:
        # Non-affine normalisation layers have weight/bias set to None.
        if getattr(layer, "weight", None) is not None:
            layer.weight.data.normal_(1.0, 0.02)
        if getattr(layer, "bias", None) is not None:
            layer.bias.data.fill_(0)
|
||
|
|
||
|
|
||
|
def init_random_seed(manual_seed):
    """Seed Python's ``random`` module and PyTorch (all CUDA devices too, when available).

    Args:
        manual_seed: the seed to use; if None, a seed is drawn with
            ``random.randint(1, 10000)``.

    Returns:
        The seed actually used, so callers can log it for reproducibility.
    """
    seed = random.randint(1, 10000) if manual_seed is None else manual_seed
    print("use random seed: {}".format(seed))
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    return seed
|
||
|
|
||
|
|
||
|
def get_data_loader(name, train=True):
    """Get the data loader for the dataset called ``name``.

    Args:
        name: one of "MNIST", "MNISTM" or "SVHN" (case-sensitive).
        train: forwarded to the dataset factory (``get_mnist`` / ``get_mnistm`` /
            ``get_svhn``); presumably selects the training split — confirm in
            ``datasets``.

    Returns:
        Whatever the matching dataset factory returns.

    Raises:
        ValueError: if ``name`` is not one of the supported datasets.
    """
    if name == "MNIST":
        return get_mnist(train)
    elif name == "MNISTM":
        return get_mnistm(train)
    elif name == "SVHN":
        return get_svhn(train)
    # Fail loudly here instead of returning None and crashing later,
    # far from the typo that caused it.
    raise ValueError(
        "Unknown dataset name: {!r} (expected 'MNIST', 'MNISTM' or 'SVHN')".format(name))
|
||
|
|
||
|
|
||
|
def init_model(net, restore):
    """Initialise ``net``'s weights, optionally restore a checkpoint, and move it to the GPU.

    Args:
        net: the model; every submodule is passed through ``init_weights`` via
            ``net.apply``.
        restore: path to a saved ``state_dict``. Ignored (training starts from
            scratch) when it is None or the file does not exist.

    Returns:
        ``net`` itself. When a checkpoint is loaded, ``net.restored`` is set to True.
    """
    # init weights of model
    net.apply(init_weights)

    # restore model weights
    if restore is not None and os.path.exists(restore):
        # Deserialise onto the CPU first. A checkpoint written on a GPU machine
        # would otherwise fail to load on a CPU-only host. The model is moved
        # to the GPU below when one is available.
        # NOTE: torch.load unpickles the file — only restore trusted checkpoints.
        state_dict = torch.load(restore, map_location=lambda storage, loc: storage)
        net.load_state_dict(state_dict)
        net.restored = True
        print("Restore model from: {}".format(os.path.abspath(restore)))
    else:
        print("No trained model, train from scratch.")

    # check if cuda is available
    if torch.cuda.is_available():
        cudnn.benchmark = True
        net.cuda()

    return net
|
||
|
|
||
|
|
||
|
def save_model(net, filename):
    """Save ``net``'s state dict as ``filename`` inside ``params.model_root``.

    The directory ``params.model_root`` is created first if it does not exist.

    Args:
        net: the model whose ``state_dict()`` is written.
        filename: file name, relative to ``params.model_root``.
    """
    root = params.model_root
    if not os.path.exists(root):
        os.makedirs(root)
    target = os.path.join(root, filename)
    torch.save(net.state_dict(), target)
    print("save pretrained model to: {}".format(target))
|