Commit 029a40e43e by wogong (6 years ago, branch: master)

reorganize file.
11 changed files:
  1. README.md (29 lines changed)
  2. core/dann.py (73)
  3. core/pretrain.py (21)
  4. core/test.py (95)
  5. datasets/office.py (19)
  6. experiments/mnist_mnistm.py (4)
  7. experiments/office.py (21)
  8. experiments/office31_10.py (81)
  9. experiments/svhn_mnist.py (4)
  10. models/model.py (4)
  11. utils/utils.py (14)

README.md (29 lines changed)

@@ -15,31 +15,20 @@ A pytorch implementation for paper *[Unsupervised Domain Adaptation by Backpropagation]
 ## Note
 - `Config()` holds the configuration parameters for a specific task.
-- `MNISTmodel()` follows the architecture in the paper exactly, except that `Dropout2d()` is added
-to the feature part; experiments show that whether `Dropout2d()` is added has a large effect on final
-performance. The reproduced result is higher than the paper's because extra tricks were used; this is worth further investigation.
+- `MNISTmodel()` follows the architecture in the paper exactly, except that `Dropout2d()` is added to the feature part; experiments show that whether `Dropout2d()` is added has a large effect on final performance. The reproduced result is higher than the paper's because extra tricks were used; this is worth further investigation.
 - `SVHNmodel()`: the architecture proposed in the paper is hard to make sense of, so a custom architecture is used here. The reproduced result matches the paper perfectly.
-## Result
-|             | MNIST-MNISTM | SVHN-MNIST | Amazon-Webcam |
-| :---------: | :----------: | :--------: | :-----------: |
-| Source Only |    0.5225    |   0.5490   |    0.6420     |
-|    DANN     |    0.7666    |   0.7385   |    0.7300     |
-|  This Repo  |    0.8400    |   0.7339   |    0.6428     |
 - MNIST-MNISTM: `python mnist_mnistm.py`
 - SVHN-MNIST: `python svhn_mnist.py`
-- Amazon-Webcam: not reproduced successfully
+- Amazon-Webcam: `python office.py` (not reproduced successfully)
-## Other implementations
-- authors (caffe), <https://github.com/ddtm/caffe>
-- TensorFlow, <https://github.com/pumpikano/tf-dann>
-- Theano, <https://github.com/shucunt/domain_adaptation>
-- PyTorch, <https://github.com/fungtion/DANN>
-- numpy, <https://github.com/GRAAL-Research/domain_adversarial_neural_network>
-- lua, <https://github.com/gmarceaucaron/dann>
+## Result
+|                       | MNIST-MNISTM | SVHN-MNIST | Amazon-Webcam | Amazon-Webcam10 |
+| :-------------------: | :----------: | :--------: | :-----------: | :-------------: |
+|      Source Only      |    0.5225    |   0.5490   |    0.6420     |       0.        |
+|     DANN (paper)      |    0.7666    |   0.7385   |    0.7300     |       0.        |
+| This Repo Source Only |      -       |     -      |       -       |       0.        |
+|       This Repo       |    0.8400    |   0.7339   |    0.6528     |       0.        |
 ## Credit
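The `Dropout2d()` note above refers to channel-wise dropout inside the CNN feature extractor. A toy sketch of that kind of block (the layer sizes are illustrative only, not the repo's actual `MNISTmodel` definition):

```python
import torch.nn as nn

# Illustrative feature extractor with channel-wise dropout; Dropout2d zeroes
# entire feature maps rather than individual activations.
features = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=5),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2),
    nn.Dropout2d(),
    nn.Conv2d(32, 48, kernel_size=5),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2),
)
```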

core/dann.py (73 lines changed)

@@ -1,17 +1,19 @@
 """Train dann."""
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from utils import make_variable, save_model
-import numpy as np
-from core.test import eval, eval_src
+from core.test import eval
+from utils.utils import save_model
 import torch.backends.cudnn as cudnn
 cudnn.benchmark = True

-def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
+def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device):
     """Train dann."""
     ####################
     # 1. setup network #
@@ -24,20 +26,23 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
         optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
     else:
         print("training office task")
-        parameter_list = [
-            {"params": model.features.parameters(), "lr": 0.001},
-            {"params": model.fc.parameters(), "lr": 0.001},
-            {"params": model.bottleneck.parameters()},
-            {"params": model.classifier.parameters()},
-            {"params": model.discriminator.parameters()}
-        ]
+        parameter_list = [{
+            "params": model.features.parameters(),
+            "lr": 0.001
+        }, {
+            "params": model.fc.parameters(),
+            "lr": 0.001
+        }, {
+            "params": model.bottleneck.parameters()
+        }, {
+            "params": model.classifier.parameters()
+        }, {
+            "params": model.discriminator.parameters()
+        }]
         optimizer = optim.SGD(parameter_list, lr=0.01, momentum=0.9)

     criterion = nn.CrossEntropyLoss()

+    for p in model.parameters():
+        p.requires_grad = True

     ####################
     # 2. train network #
     ####################
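The hunk above only reformats the per-group optimizer setup, but the behaviour is worth spelling out: groups that specify `"lr"` (the pretrained `features` and `fc`) override the default passed to `optim.SGD`, while `bottleneck`, `classifier`, and `discriminator` inherit `lr=0.01`. A self-contained sketch of that behaviour, with toy modules standing in for `AlexModel`'s parts:

```python
import torch.nn as nn
import torch.optim as optim

features = nn.Linear(8, 8)     # stands in for the pretrained backbone
classifier = nn.Linear(8, 2)   # stands in for a freshly initialized head

optimizer = optim.SGD(
    [
        {"params": features.parameters(), "lr": 0.001},  # explicit, small lr
        {"params": classifier.parameters()},             # inherits default lr
    ],
    lr=0.01,
    momentum=0.9)

print([group["lr"] for group in optimizer.param_groups])  # [0.001, 0.01]
```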
@@ -50,9 +55,9 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
         data_zip = enumerate(zip(src_data_loader, tgt_data_loader))
         for step, ((images_src, class_src), (images_tgt, _)) in data_zip:
-            p = float(step + epoch * len_dataloader) / params.num_epochs / len_dataloader
+            p = float(step + epoch * len_dataloader) / \
+                params.num_epochs / len_dataloader
             alpha = 2. / (1. + np.exp(-10 * p)) - 1
-            alpha = 2 * alpha

             if params.src_dataset == 'mnist' or params.tgt_dataset == 'mnist':
                 adjust_learning_rate(optimizer, p)
@@ -62,13 +67,13 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
             # prepare domain label
             size_src = len(images_src)
             size_tgt = len(images_tgt)
-            label_src = make_variable(torch.zeros(size_src).long())  # source 0
-            label_tgt = make_variable(torch.ones(size_tgt).long())  # target 1
+            label_src = torch.zeros(size_src).long().to(device)  # source 0
+            label_tgt = torch.ones(size_tgt).long().to(device)  # target 1

             # make images variable
-            class_src = make_variable(class_src)
-            images_src = make_variable(images_src)
-            images_tgt = make_variable(images_tgt)
+            class_src = class_src.to(device)
+            images_src = images_src.to(device)
+            images_tgt = images_tgt.to(device)

             # zero gradients for optimizer
             optimizer.zero_grad()
@@ -90,32 +95,29 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
             # print step info
             if ((step + 1) % params.log_step == 0):
-                print("Epoch [{:4d}/{}] Step [{:2d}/{}]: src_loss_class={:.6f}, src_loss_domain={:.6f}, tgt_loss_domain={:.6f}, loss={:.6f}"
-                      .format(epoch + 1,
-                              params.num_epochs,
-                              step + 1,
-                              len_dataloader,
-                              src_loss_class.data[0],
-                              src_loss_domain.data[0],
-                              tgt_loss_domain.data[0],
-                              loss.data[0]))
+                print(
+                    "Epoch [{:4d}/{}] Step [{:2d}/{}]: src_loss_class={:.6f}, src_loss_domain={:.6f}, tgt_loss_domain={:.6f}, loss={:.6f}"
+                    .format(epoch + 1, params.num_epochs, step + 1, len_dataloader, src_loss_class.data.item(),
+                            src_loss_domain.data.item(), tgt_loss_domain.data.item(), loss.data.item()))

-        # eval model on test set
+        # eval model
         if ((epoch + 1) % params.eval_step == 0):
             print("eval on target domain")
-            eval(model, tgt_data_loader)
+            eval(model, tgt_data_loader, device, flag='target')
             print("eval on source domain")
-            eval_src(model, src_data_loader)
+            eval(model, src_data_loader, device, flag='source')

         # save model parameters
         if ((epoch + 1) % params.save_step == 0):
-            save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-{}.pt".format(epoch + 1))
+            save_model(model, params.model_root,
+                       params.src_dataset + '-' + params.tgt_dataset + "-dann-{}.pt".format(epoch + 1))

     # save final model
     save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")
     return model

 def adjust_learning_rate(optimizer, p):
     lr_0 = 0.01
     alpha = 10
@@ -124,6 +126,7 @@ def adjust_learning_rate(optimizer, p):
     for param_group in optimizer.param_groups:
         param_group['lr'] = lr

 def adjust_learning_rate_office(optimizer, p):
     lr_0 = 0.001
     alpha = 10
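Both schedules in this file follow the DANN paper: the domain-loss weight ramps up as `alpha = 2 / (1 + exp(-10 p)) - 1` for training progress `p` in [0, 1], and `adjust_learning_rate` anneals the rate as `lr_0 / (1 + alpha * p) ** beta` with `alpha = 10`. A standalone sketch of both curves; the `beta = 0.75` exponent is the paper's value and is an assumption here, since the hunk cuts off before it:

```python
import numpy as np

def grl_alpha(p):
    """Weight on the domain loss / reversed gradient: 0 at the start, approaches 1."""
    return 2.0 / (1.0 + np.exp(-10.0 * p)) - 1.0

def annealed_lr(p, lr_0=0.01, alpha=10.0, beta=0.75):
    """Learning-rate schedule from the DANN paper (beta=0.75 assumed)."""
    return lr_0 / (1.0 + alpha * p) ** beta

for p in (0.0, 0.5, 1.0):  # p = fraction of training completed
    print("p={:.1f}  alpha={:.3f}  lr={:.5f}".format(p, grl_alpha(p), annealed_lr(p)))
```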

core/pretrain.py (21 lines changed)

@@ -3,10 +3,11 @@
 import torch.nn as nn
 import torch.optim as optim
-from utils import make_variable, save_model
-from core.test import eval_src
+from utils.utils import save_model
+from core.test import eval

-def train_src(model, params, data_loader):
+def train_src(model, params, data_loader, device):
     """Train classifier for source domain."""
     ####################
     # 1. setup network #
@@ -26,8 +27,8 @@ def train_src(model, params, data_loader):
     for epoch in range(params.num_epochs_src):
         for step, (images, labels) in enumerate(data_loader):
             # make images and labels variable
-            images = make_variable(images)
-            labels = make_variable(labels.squeeze_())
+            images = images.to(device)
+            labels = labels.squeeze_().to(device)

             # zero gradients for optimizer
             optimizer.zero_grad()
@@ -42,16 +43,12 @@ def train_src(model, params, data_loader):
             # print step info
             if ((step + 1) % params.log_step_src == 0):
-                print("Epoch [{}/{}] Step [{}/{}]: loss={}"
-                      .format(epoch + 1,
-                              params.num_epochs_src,
-                              step + 1,
-                              len(data_loader),
-                              loss.data[0]))
+                print("Epoch [{}/{}] Step [{}/{}]: loss={}".format(epoch + 1, params.num_epochs_src, step + 1,
+                                                                   len(data_loader), loss.data[0]))

         # eval model on test set
         if ((epoch + 1) % params.eval_step_src == 0):
-            eval_src(model, data_loader)
+            eval(model, data_loader, flag='source')

         model.train()

     # save model parameters
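The `labels.squeeze_()` call in the hunk above handles loaders that yield labels shaped `(N, 1)`; `nn.CrossEntropyLoss` expects a flat `(N,)` target (the commented-out `labels.squeeze(1)` in core/test.py hints at the same issue). A toy illustration, assuming that shape:

```python
import torch

labels = torch.tensor([[1], [0], [3]])  # shape (3, 1), as some loaders yield
labels = labels.squeeze_()              # in-place squeeze -> shape (3,)
print(labels.shape)                     # torch.Size([3])
```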

core/test.py (95 lines changed)

@@ -1,9 +1,8 @@
 import torch.utils.data
 import torch.nn as nn
-from utils import make_variable

-def test_from_save(model, saved_model, data_loader):
+def test_from_save(model, saved_model, data_loader, device):
     """Evaluate classifier for source domain."""
     # set eval state for Dropout and BN layers
     classifier = model.load_state_dict(torch.load(saved_model))
@@ -18,13 +17,13 @@ def test_from_save(model, saved_model, data_loader):
     # evaluate network
     for (images, labels) in data_loader:
-        images = make_variable(images, volatile=True)
-        labels = make_variable(labels)  # labels = labels.squeeze(1)
+        images = images.to(device)
+        labels = labels.to(device)  # labels = labels.squeeze(1)

         preds = classifier(images)
         criterion(preds, labels)
-        loss += criterion(preds, labels).data[0]
+        loss += criterion(preds, labels).data.item()

         pred_cls = preds.data.max(1)[1]
         acc += pred_cls.eq(labels.data).cpu().sum()
@@ -34,43 +33,8 @@ def test_from_save(model, saved_model, data_loader):
     print("Avg Loss = {}, Avg Accuracy = {:.2%}".format(loss, acc))

-def eval(model, data_loader):
-    """Evaluate model for dataset."""
-    # set eval state for Dropout and BN layers
-    model.eval()
-    # init loss and accuracy
-    loss = 0.0
-    acc = 0.0
-    acc_domain = 0.0
-    # set loss function
-    criterion = nn.CrossEntropyLoss()
-    # evaluate network
-    for (images, labels) in data_loader:
-        images = make_variable(images, volatile=True)
-        labels = make_variable(labels)  # labels = labels.squeeze(1)
-        size_tgt = len(labels)
-        labels_domain = make_variable(torch.ones(size_tgt).long())
-        preds, domain = model(images, alpha=0)
-        loss += criterion(preds, labels).data[0]
-        pred_cls = preds.data.max(1)[1]
-        pred_domain = domain.data.max(1)[1]
-        acc += pred_cls.eq(labels.data).cpu().sum()
-        acc_domain += pred_domain.eq(labels_domain.data).cpu().sum()
-    loss /= len(data_loader)
-    acc /= len(data_loader.dataset)
-    acc_domain /= len(data_loader.dataset)
-    print("Avg Loss = {:.6f}, Avg Accuracy = {:.2%}, Avg Domain Accuracy = {:2%}".format(loss, acc, acc_domain))

-def eval_src(model, data_loader):
+def eval(model, data_loader, device, flag):
     """Evaluate model for dataset."""
     # set eval state for Dropout and BN layers
     model.eval()
@@ -85,54 +49,25 @@ def eval_src(model, data_loader):
     # evaluate network
     for (images, labels) in data_loader:
-        images = make_variable(images, volatile=True)
-        labels = make_variable(labels)  # labels = labels.squeeze(1)
-        size_tgt = len(labels)
-        labels_domain = make_variable(torch.zeros(size_tgt).long())
+        images = images.to(device)
+        labels = labels.to(device)  # labels = labels.squeeze(1)
+        size = len(labels)
+        if flag == 'target':
+            labels_domain = torch.ones(size).long().to(device)
+        else:
+            labels_domain = torch.zeros(size).long().to(device)

         preds, domain = model(images, alpha=0)
-        loss += criterion(preds, labels).data[0]
+        loss += criterion(preds, labels).data.item()

         pred_cls = preds.data.max(1)[1]
         pred_domain = domain.data.max(1)[1]
-        acc += pred_cls.eq(labels.data).cpu().sum()
-        acc_domain += pred_domain.eq(labels_domain.data).cpu().sum()
+        acc += pred_cls.eq(labels.data).sum().item()
+        acc_domain += pred_domain.eq(labels_domain.data).sum().item()

     loss /= len(data_loader)
     acc /= len(data_loader.dataset)
     acc_domain /= len(data_loader.dataset)
     print("Avg Loss = {:.6f}, Avg Accuracy = {:.2%}, Avg Domain Accuracy = {:2%}".format(loss, acc, acc_domain))

-def eval_src_(model, data_loader):
-    """Evaluate classifier for source domain."""
-    # set eval state for Dropout and BN layers
-    model.eval()
-    # init loss and accuracy
-    loss = 0.0
-    acc = 0.0
-    # set loss function
-    criterion = nn.NLLLoss()
-    # evaluate network
-    for (images, labels) in data_loader:
-        images = make_variable(images, volatile=True)
-        labels = make_variable(labels)  # labels = labels.squeeze(1)
-        preds = model(images)
-        criterion(preds, labels)
-        loss += criterion(preds, labels).data[0]
-        pred_cls = preds.data.max(1)[1]
-        acc += pred_cls.eq(labels.data).cpu().sum()
-    loss /= len(data_loader)
-    acc /= len(data_loader.dataset)
-    print("Avg Loss = {:.6f}, Avg Accuracy = {:.2%}".format(loss, acc))

datasets/office.py (19 lines changed)

@@ -1,30 +1,25 @@
 """Dataset setting and data loader for Office."""
+import os

 import torch
 from torchvision import datasets, transforms
 import torch.utils.data as data
-import os

 def get_office(dataset_root, batch_size, category):
     """Get Office datasets loader."""
     # image pre-processing
-    pre_process = transforms.Compose([transforms.Resize(227),
+    pre_process = transforms.Compose([
+        transforms.Resize(227),
         transforms.ToTensor(),
-        transforms.Normalize(
-            mean=(0.485, 0.456, 0.406),
-            std=(0.229, 0.224, 0.225)
-        )])
+        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
+    ])

     # datasets and data_loader
     office_dataset = datasets.ImageFolder(
-        os.path.join(dataset_root, 'office', category, 'images'),
-        transform=pre_process)
+        os.path.join(dataset_root, 'office', category, 'images'), transform=pre_process)

     office_dataloader = torch.utils.data.DataLoader(
-        dataset=office_dataset,
-        batch_size=batch_size,
-        shuffle=True,
-        num_workers=4)
+        dataset=office_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

     return office_dataloader
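For context, a hypothetical call to the loader above. It assumes the ImageFolder layout `~/Datasets/office/<category>/images/<class>/<img>.jpg`; note also that `num_workers=0` now keeps data loading in the main process, trading speed for fewer multiprocessing issues:

```python
import os

# Hypothetical usage; the dataset path layout is an assumption.
loader = get_office(os.path.expanduser(os.path.join('~', 'Datasets')),
                    batch_size=32, category='amazon')
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([32, 3, 227, 227]) for square source images
```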

mnist_mnistm.py → experiments/mnist_mnistm.py (4 lines changed)

@@ -1,8 +1,10 @@
 import os
+import sys
+sys.path.append('../')
 from models.model import MNISTmodel
 from core.dann import train_dann
-from utils import get_data_loader, init_model, init_random_seed
+from utils.utils import get_data_loader, init_model, init_random_seed

 class Config(object):
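The scripts moved into experiments/ keep their top-level imports working through the `sys.path.append('../')` shim, which only resolves correctly when the current working directory is experiments/. A slightly more robust variant (an alternative, not what the commit uses) anchors on the file's own location:

```python
import os
import sys

# Make the repo root importable regardless of the current working directory.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))
```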

office.py → experiments/office.py (21 lines changed)

@@ -1,15 +1,19 @@
 import os
+import sys
+import torch
+sys.path.append('../')
 from core.dann import train_dann
 from core.test import eval
 from models.model import AlexModel
-from utils import get_data_loader, init_model, init_random_seed
+from utils.utils import get_data_loader, init_model, init_random_seed

 class Config(object):
     # params for path
     dataset_root = os.path.expanduser(os.path.join('~', 'Datasets'))
-    model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN'))
+    model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-dann'))

     # params for datasets and data loader
     batch_size = 32
@@ -28,12 +32,13 @@ class Config(object):
     num_epochs_src = 100
     log_step_src = 5
     save_step_src = 50
-    eval_step_src = 20
+    eval_step_src = 10

     # params for training dann
+    gpu_id = '0'
     ## for office
-    num_epochs = 2000
+    num_epochs = 1000
     log_step = 10  # iters
     save_step = 500
     eval_step = 5  # epochs
@@ -44,11 +49,15 @@ class Config(object):
     # params for optimizing models
     lr = 2e-4

 params = Config()

 # init random seed
 init_random_seed(params.manual_seed)

+# init device
+device = torch.device("cuda:" + params.gpu_id if torch.cuda.is_available() else "cpu")
+
 # load dataset
 src_data_loader = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size)
 tgt_data_loader = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size)
@@ -60,7 +69,7 @@ dann = init_model(net=AlexModel(), restore=None)
 print("Start training dann model.")
 if not (dann.restored and params.dann_restore):
-    dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader)
+    dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader, device)

 # eval dann model
 print("Evaluating dann for source domain")

experiments/office31_10.py (new file, 81 lines)

@@ -0,0 +1,81 @@
+import os
+import sys
+sys.path.append('../')
+from core.dann import train_dann
+from core.test import eval
+from models.model import AlexModel
+from utils.utils import get_data_loader, init_model, init_random_seed
+
+class Config(object):
+    # params for path
+    dataset_root = os.path.expanduser(os.path.join('~', 'Datasets'))
+    model_root = os.path.expanduser(
+        os.path.join('~', 'Models', 'pytorch-DANN'))
+
+    # params for datasets and data loader
+    batch_size = 32
+
+    # params for source dataset
+    src_dataset = "amazon31"
+    src_model_trained = True
+    src_classifier_restore = os.path.join(
+        model_root, src_dataset + '-source-classifier-final.pt')
+
+    # params for target dataset
+    tgt_dataset = "webcam10"
+    tgt_model_trained = True
+    dann_restore = os.path.join(
+        model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')
+
+    # params for pretrain
+    num_epochs_src = 100
+    log_step_src = 5
+    save_step_src = 50
+    eval_step_src = 20
+
+    # params for training dann
+    # for office
+    num_epochs = 1000
+    log_step = 10  # iters
+    save_step = 500
+    eval_step = 5  # epochs
+
+    manual_seed = 8888
+    alpha = 0
+
+    # params for optimizing models
+    lr = 2e-4
+
+params = Config()
+
+# init random seed
+init_random_seed(params.manual_seed)
+
+# load dataset
+src_data_loader = get_data_loader(
+    params.src_dataset, params.dataset_root, params.batch_size)
+tgt_data_loader = get_data_loader(
+    params.tgt_dataset, params.dataset_root, params.batch_size)
+
+# load dann model
+dann = init_model(net=AlexModel(), restore=None)
+
+# train dann model
+print("Start training dann model.")
+if not (dann.restored and params.dann_restore):
+    dann = train_dann(dann, params, src_data_loader,
+                      tgt_data_loader, tgt_data_loader)
+
+# eval dann model
+print("Evaluating dann for source domain")
+eval(dann, src_data_loader)
+print("Evaluating dann for target domain")
+eval(dann, tgt_data_loader)
+
+print('done')

svhn_mnist.py → experiments/svhn_mnist.py (4 lines changed)

@@ -1,8 +1,10 @@
 import os
+import sys
+sys.path.append('../')
 from models.model import SVHNmodel
 from core.dann import train_dann
-from utils import get_data_loader, init_model, init_random_seed
+from utils.utils import get_data_loader, init_model, init_random_seed

 class Config(object):

models/model.py (4 lines changed)

@@ -179,10 +179,10 @@ class AlexModel(nn.Module):
         self.discriminator = nn.Sequential(
             nn.Linear(2048, 1024),
-            nn.ReLU(),
+            nn.ReLU(inplace=True),
             nn.Dropout(),
             nn.Linear(1024, 1024),
-            nn.ReLU(),
+            nn.ReLU(inplace=True),
             nn.Dropout(),
             nn.Linear(1024, 2),
         )
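`AlexModel`'s forward is called elsewhere in this commit as `model(images, alpha=0)`, so the discriminator above sits behind a gradient reversal layer whose strength is `alpha`. The repo's own GRL is not shown in this diff; a minimal sketch of the standard implementation pattern:

```python
import torch

class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by -alpha backward."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; the alpha argument gets None.
        return grad_output.neg() * ctx.alpha, None

def grad_reverse(x, alpha=1.0):
    return GradReverse.apply(x, alpha)

# Inside a DANN-style forward (sketch):
#   feat = self.features(x)
#   domain_logits = self.discriminator(grad_reverse(feat, alpha))
```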

utils.py → utils/utils.py (14 lines changed)

@@ -1,5 +1,3 @@
-"""Utilities for ADDA."""
-
 import os
 import random
@@ -12,13 +10,6 @@ from datasets.office import get_office
 from datasets.officecaltech import get_officecaltech

-def make_variable(tensor, volatile=False):
-    """Convert Tensor to Variable."""
-    if torch.cuda.is_available():
-        tensor = tensor.cuda()
-    return Variable(tensor, volatile=volatile)

 def make_cuda(tensor):
     """Use CUDA if it's available."""
     if torch.cuda.is_available():
@@ -71,6 +62,7 @@ def get_data_loader(name, dataset_root, batch_size, train=True):
     elif name == "webcam10":
         return get_officecaltech(dataset_root, batch_size, 'webcam')

 def init_model(net, restore):
     """Init models with cuda and weights."""
     # init weights of model
@@ -91,10 +83,10 @@ def init_model(net, restore):
     return net

 def save_model(net, model_root, filename):
     """Save trained model."""
     if not os.path.exists(model_root):
         os.makedirs(model_root)
-    torch.save(net.state_dict(),
-               os.path.join(model_root, filename))
+    torch.save(net.state_dict(), os.path.join(model_root, filename))
     print("save pretrained model to: {}".format(os.path.join(model_root, filename)))