You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
92 lines
2.7 KiB
92 lines
2.7 KiB
import os
|
|
import random
|
|
|
|
import torch
|
|
import torch.backends.cudnn as cudnn
|
|
from torch.autograd import Variable
|
|
|
|
from datasets import get_mnist, get_mnistm, get_svhn
|
|
from datasets.office import get_office
|
|
from datasets.officecaltech import get_officecaltech
|
|
|
|
|
|
def make_cuda(tensor):
    """Move ``tensor`` to the GPU when CUDA is available.

    Args:
        tensor: Object exposing a ``.cuda()`` method (e.g. a torch tensor).

    Returns:
        The result of ``tensor.cuda()`` when CUDA is available, otherwise
        the input object unchanged.
    """
    if not torch.cuda.is_available():
        # CPU-only host: hand the input back untouched.
        return tensor
    return tensor.cuda()
|
|
|
|
|
|
def denormalize(x, std, mean):
    """Undo a ``(x - mean) / std`` normalization and clip to ``[0, 1]``.

    Note the argument order: ``std`` comes before ``mean``.

    Args:
        x: Normalized values.
        std: Scale that is multiplied back in.
        mean: Offset that is added back after scaling.

    Returns:
        ``x * std + mean`` with every element clamped to ``[0, 1]``.
    """
    rescaled = x * std
    shifted = rescaled + mean
    return shifted.clamp(0, 1)
|
|
|
|
|
|
def init_weights(layer):
    """Init weights for layers w.r.t. the original paper.

    Layers are matched by class name:

    * names containing ``"Conv"``: weight drawn from N(0.0, 0.02);
    * names containing ``"BatchNorm"``: weight drawn from N(1.0, 0.02)
      and bias filled with zeros.

    Any other layer is left untouched. Layers that match by name but have
    no ``weight``/``bias`` parameter (attribute missing or ``None``) are
    skipped instead of raising, so the function can be applied to every
    sub-module of a network regardless of how its classes are named.

    Args:
        layer: The module to initialise in place.
    """
    layer_name = layer.__class__.__name__
    # The match is purely name-based, so it can hit modules without
    # parameters; fetch them defensively.
    weight = getattr(layer, "weight", None)
    bias = getattr(layer, "bias", None)

    if "Conv" in layer_name:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif "BatchNorm" in layer_name:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)
|
|
|
|
|
|
def init_random_seed(manual_seed):
    """Init random seed.

    Seeds Python's ``random`` module and torch (plus all CUDA devices when
    CUDA is available) with the same value.

    Args:
        manual_seed: Seed to use. When ``None``, a seed is drawn with
            ``random.randint(1, 10000)``.

    Returns:
        The seed that was applied, so callers can record it for
        reproducibility.
    """
    if manual_seed is None:
        seed = random.randint(1, 10000)
    else:
        seed = manual_seed
    print("use random seed: {}".format(seed))

    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    return seed
|
|
|
|
|
|
def get_data_loader(name, dataset_root, batch_size, train=True):
    """Get data loader by name.

    Supported names: ``"mnist"``, ``"mnistm"``, ``"svhn"``, ``"amazon31"``,
    ``"webcam31"`` and ``"webcam10"``.

    Args:
        name: Dataset key, one of the names listed above.
        dataset_root: Root directory handed to the dataset helper.
        batch_size: Batch size handed to the dataset helper.
        train: Forwarded to the MNIST / MNIST-M / SVHN helpers only; the
            Office and Office-Caltech helpers are called without it.

    Returns:
        Whatever the matching dataset helper returns.

    Raises:
        ValueError: If ``name`` is not a supported dataset key. An unknown
            name used to fall through and return ``None`` silently.
    """
    if name == "mnist":
        return get_mnist(dataset_root, batch_size, train)
    if name == "mnistm":
        return get_mnistm(dataset_root, batch_size, train)
    if name == "svhn":
        return get_svhn(dataset_root, batch_size, train)
    if name == "amazon31":
        return get_office(dataset_root, batch_size, 'amazon')
    if name == "webcam31":
        return get_office(dataset_root, batch_size, 'webcam')
    if name == "webcam10":
        return get_officecaltech(dataset_root, batch_size, 'webcam')
    raise ValueError("unknown dataset name: {!r}".format(name))
|
|
|
|
|
|
def init_model(net, restore):
    """Init models with cuda and weights.

    Args:
        net: The model to prepare.
        restore: Path to a saved state dict, or ``None``. Weights are
            loaded only when the path is given and exists on disk.

    Returns:
        ``net``, with restored weights when available (in which case
        ``net.restored`` is set to ``True``) and moved to the GPU when
        CUDA is available.
    """
    # Custom weight initialisation is currently disabled:
    # net.apply(init_weights)

    # restore model weights
    if restore is not None and os.path.exists(restore):
        # Load onto CPU first so a checkpoint saved from a GPU can still be
        # restored on a CPU-only host; the model is moved to the GPU below.
        # NOTE: torch.load relies on pickle -- only restore checkpoints
        # that come from a trusted source.
        state = torch.load(restore, map_location="cpu")
        net.load_state_dict(state)
        net.restored = True
        print("Restore model from: {}".format(os.path.abspath(restore)))
    else:
        print("No trained model, train from scratch.")

    # check if cuda is available
    if torch.cuda.is_available():
        cudnn.benchmark = True
        net.cuda()

    return net
|
|
|
|
|
|
def save_model(net, model_root, filename):
    """Save trained model.

    Writes ``net.state_dict()`` to ``model_root/filename``, creating
    ``model_root`` (including missing parents) when needed.

    Args:
        net: Model whose state dict is saved.
        model_root: Target directory. An empty string means the current
            working directory.
        filename: File name of the checkpoint inside ``model_root``.
    """
    if model_root:
        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test.
        os.makedirs(model_root, exist_ok=True)
    target = os.path.join(model_root, filename)
    torch.save(net.state_dict(), target)
    print("save pretrained model to: {}".format(target))
|