commit 232135eed0 (master)
wogong, 5 years ago

    add synsigns_gtsrb experiment, update core codes

7 changed files:

    core/dann.py                    28
    core/test.py                     2
    datasets/gtsrb.py               44
    datasets/synsigns.py            62
    experiments/synsigns_gtsrb.py   77
    models/model.py                 96
    utils/utils.py                   8

core/dann.py  (28 changes)

@@ -5,15 +5,13 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.optim as optim
 from core.test import test
 from utils.utils import save_model
 import torch.backends.cudnn as cudnn
 cudnn.benchmark = True

-def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device):
+def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger):
     """Train dann."""
     ####################
     # 1. setup network #
@@ -21,8 +19,8 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger):
     # setup criterion and optimizer
-    if params.src_dataset == 'mnist' or params.tgt_dataset == 'mnist':
-        print("training mnist task")
+    if not params.finetune_flag:
+        print("training non-office task")
         optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
     else:
         print("training office task")
@@ -46,7 +44,7 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger):
     ####################
     # 2. train network #
     ####################
+    global_step = 0
     for epoch in range(params.num_epochs):
         # set train state for Dropout and BN layers
         model.train()
@ -93,7 +91,14 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
loss.backward() loss.backward()
optimizer.step() optimizer.step()
global_step += 1
# print step info # print step info
logger.add_scalar('src_loss_class', src_loss_class.item(), global_step)
logger.add_scalar('src_loss_domain', src_loss_domain.item(), global_step)
logger.add_scalar('tgt_loss_domain', tgt_loss_domain.item(), global_step)
logger.add_scalar('loss', loss.item(), global_step)
if ((step + 1) % params.log_step == 0): if ((step + 1) % params.log_step == 0):
print( print(
"Epoch [{:4d}/{}] Step [{:2d}/{}]: src_loss_class={:.6f}, src_loss_domain={:.6f}, tgt_loss_domain={:.6f}, loss={:.6f}" "Epoch [{:4d}/{}] Step [{:2d}/{}]: src_loss_class={:.6f}, src_loss_domain={:.6f}, tgt_loss_domain={:.6f}, loss={:.6f}"
@@ -103,9 +108,16 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger):
         # eval model
         if ((epoch + 1) % params.eval_step == 0):
             print("eval on target domain")
-            test(model, tgt_data_loader, device, flag='target')
+            tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader, device, flag='target')
             print("eval on source domain")
-            test(model, src_data_loader, device, flag='source')
+            src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source')
+            logger.add_scalar('src_test_loss', src_test_loss, global_step)
+            logger.add_scalar('src_acc', src_acc, global_step)
+            logger.add_scalar('src_acc_domain', src_acc_domain, global_step)
+            logger.add_scalar('tgt_test_loss', tgt_test_loss, global_step)
+            logger.add_scalar('tgt_acc', tgt_acc, global_step)
+            logger.add_scalar('tgt_acc_domain', tgt_acc_domain, global_step)

         # save model parameters
         if ((epoch + 1) % params.save_step == 0):
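The new logger argument is a tensorboardX SummaryWriter. A minimal sketch of the
add_scalar pattern used above ('runs/demo' is a made-up log directory):

    # minimal logging sketch; assumes tensorboardX is installed
    from tensorboardX import SummaryWriter

    logger = SummaryWriter('runs/demo')   # hypothetical run directory
    for global_step in range(3):
        # add_scalar(tag, value, step) writes one scalar point per call, keyed by step
        logger.add_scalar('loss', 1.0 / (global_step + 1), global_step)
    logger.close()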

core/test.py  (2 changes)

@@ -71,3 +71,5 @@ def test(model, data_loader, device, flag):
     acc_domain /= len(data_loader.dataset)
     print("Avg Loss = {:.6f}, Avg Accuracy = {:.2%}, Avg Domain Accuracy = {:.2%}".format(loss, acc, acc_domain))
+
+    return loss, acc, acc_domain

datasets/gtsrb.py  (44 lines, new file)

@@ -0,0 +1,44 @@
+"""Dataset setting and data loader for GTSRB."""
+
+import os
+import torch
+from torchvision import datasets, transforms
+import torch.utils.data as data
+from torch.utils.data.sampler import SubsetRandomSampler
+import numpy as np
+
+
+def get_gtsrb(dataset_root, batch_size, train):
+    """Get GTSRB datasets loader."""
+    shuffle_dataset = True
+    random_seed = 42
+    train_size = 31367
+
+    # image pre-processing
+    pre_process = transforms.Compose([
+        transforms.Resize((40, 40)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+    ])
+
+    # datasets and data_loader
+    gtsrb_dataset = datasets.ImageFolder(
+        os.path.join(dataset_root, 'Final_Training', 'Images'), transform=pre_process)
+
+    dataset_size = len(gtsrb_dataset)
+    indices = list(range(dataset_size))
+    if shuffle_dataset:
+        np.random.seed(random_seed)
+        np.random.shuffle(indices)
+    train_indices, val_indices = indices[:train_size], indices[train_size:]
+
+    # creating PyTorch data samplers and loaders
+    train_sampler = SubsetRandomSampler(train_indices)
+    valid_sampler = SubsetRandomSampler(val_indices)
+
+    gtsrb_dataloader_train = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size,
+                                                         sampler=train_sampler)
+    gtsrb_dataloader_test = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size,
+                                                        sampler=valid_sampler)
+
+    return gtsrb_dataloader_train, gtsrb_dataloader_test
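GTSRB's Final_Training set contains 39,209 images, so the fixed train_size of
31,367 leaves 7,842 images (about 20%) for the held-out evaluation sampler. A
hedged usage sketch (the dataset path is an assumption; Final_Training/Images
must use the one-folder-per-class layout that ImageFolder expects, and the
repository root must be on the import path):

    # usage sketch; dataset location is an assumption
    import os
    from datasets.gtsrb import get_gtsrb

    root = os.path.expanduser(os.path.join('~', 'Datasets', 'gtsrb'))
    train_loader, val_loader = get_gtsrb(root, batch_size=128, train=True)
    images, labels = next(iter(train_loader))
    print(images.shape)  # expected: torch.Size([128, 3, 40, 40])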

datasets/synsigns.py  (62 lines, new file)

@@ -0,0 +1,62 @@
+"""Dataset setting and data loader for syn-signs."""
+
+import os
+import torch
+from torchvision import transforms
+import torch.utils.data as data
+from PIL import Image
+
+
+class GetLoader(data.Dataset):
+    def __init__(self, data_root, data_list, transform=None):
+        self.root = data_root
+        self.transform = transform
+
+        with open(data_list, 'r') as f:
+            lines = f.readlines()
+        self.n_data = len(lines)
+
+        self.img_paths = []
+        self.img_labels = []
+        for line in lines:
+            fields = line.split(' ')
+            self.img_paths.append(fields[0])
+            self.img_labels.append(fields[1])
+
+    def __getitem__(self, item):
+        img_path, label = self.img_paths[item], self.img_labels[item]
+        img = Image.open(os.path.join(self.root, img_path)).convert('RGB')
+        if self.transform is not None:
+            img = self.transform(img)
+        return img, int(label)
+
+    def __len__(self):
+        return self.n_data
+
+
+def get_synsigns(dataset_root, batch_size, train):
+    """Get Synthetic Signs datasets loader."""
+    # image pre-processing
+    pre_process = transforms.Compose([
+        transforms.Resize((40, 40)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+    ])
+
+    # datasets and data_loader
+    # the full labelling list is used as the training set
+    train_list = os.path.join(dataset_root, 'train_labelling.txt')
+
+    synsigns_dataset = GetLoader(
+        data_root=dataset_root,
+        data_list=train_list,
+        transform=pre_process)
+
+    synsigns_dataloader = torch.utils.data.DataLoader(
+        dataset=synsigns_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
+
+    return synsigns_dataloader
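GetLoader assumes each line of train_labelling.txt is space-separated, with an
image path relative to data_root in the first field and the integer class label
in the second (any further fields on a line are ignored). A usage sketch, with
an assumed dataset location:

    # usage sketch; the dataset path is an assumption
    import os
    from datasets.synsigns import get_synsigns

    root = os.path.expanduser(os.path.join('~', 'Datasets', 'synsigns'))
    loader = get_synsigns(root, batch_size=128, train=True)
    images, labels = next(iter(loader))
    print(images.shape, labels[:4])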

experiments/synsigns_gtsrb.py  (77 lines, new file)

@@ -0,0 +1,77 @@
+import os
+import sys
+import datetime
+
+from tensorboardX import SummaryWriter
+import torch
+
+sys.path.append('../')
+from models.model import GTSRBmodel
+from core.dann import train_dann
+from utils.utils import get_data_loader, init_model, init_random_seed
+
+
+class Config(object):
+    # params for path
+    dataset_root = os.path.expanduser(os.path.join('~', 'Datasets'))
+    model_name = "synsigns-gtsrb"
+    model_base = '/home/wogong/models/pytorch-dann'
+    note = ''
+    now = datetime.datetime.now().strftime('%m%d_%H%M%S')
+    model_root = os.path.join(model_base, model_name, note + '_' + now)
+    finetune_flag = False
+
+    # params for datasets and data loader
+    batch_size = 128
+
+    # params for source dataset
+    src_dataset = "synsigns"
+    source_image_root = os.path.join('/home/wogong/datasets', 'synsigns')
+    src_model_trained = True
+    src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt')
+
+    # params for target dataset
+    tgt_dataset = "gtsrb"
+    target_image_root = os.path.join('/home/wogong/datasets', 'gtsrb')
+    tgt_model_trained = True
+    dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')
+
+    # params for pretrain
+    num_epochs_src = 100
+    log_step_src = 10
+    save_step_src = 50
+    eval_step_src = 20
+
+    # params for training dann
+    gpu_id = '0'
+    num_epochs = 200
+    log_step = 50
+    save_step = 100
+    eval_step = 5
+    manual_seed = None
+    alpha = 0
+
+    # params for optimizing models
+    lr = 2e-4
+
+
+params = Config()
+logger = SummaryWriter(params.model_root)
+device = torch.device("cuda:" + params.gpu_id if torch.cuda.is_available() else "cpu")
+
+# init random seed
+init_random_seed(params.manual_seed)
+
+# load dataset
+src_data_loader = get_data_loader(params.src_dataset, params.source_image_root, params.batch_size, train=True)
+src_data_loader_eval = get_data_loader(params.src_dataset, params.source_image_root, params.batch_size, train=False)
+tgt_data_loader, tgt_data_loader_eval = get_data_loader(params.tgt_dataset, params.target_image_root, params.batch_size, train=True)
+
+# load dann model
+dann = init_model(net=GTSRBmodel(), restore=None)
+
+# train dann model
+print("Training dann model")
+if not (dann.restored and params.dann_restore):
+    dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger)
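Two details of the wiring above: sys.path.append('../') makes the repository
root importable when the script is run from experiments/, and the gtsrb branch
of get_data_loader returns a (train, eval) loader pair while the synsigns
branch returns a single loader, which is why the target line unpacks a tuple.
A sketch of that dispatch (the dataset paths are assumptions):

    # loader wiring sketch; dataset paths are assumptions
    from utils.utils import get_data_loader

    src_loader = get_data_loader('synsigns', '/data/synsigns', 128, train=True)             # one loader
    tgt_loader, tgt_loader_eval = get_data_loader('gtsrb', '/data/gtsrb', 128, train=True)  # loader pair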

models/model.py  (96 changes)

@@ -96,6 +96,58 @@ class MNISTmodel(nn.Module):
         return class_output, domain_output


+class MNISTmodel_plain(nn.Module):
+    """MNIST architecture
+    with Dropout2d:    84% ~ 73%
+    without Dropout2d: 50% ~ 73%
+    """
+
+    def __init__(self):
+        super(MNISTmodel_plain, self).__init__()
+        self.restored = False
+
+        self.feature = nn.Sequential(
+            nn.Conv2d(in_channels=3, out_channels=32,
+                      kernel_size=(5, 5)),          # 3 x 28 x 28 -> 32 x 24 x 24
+            # nn.BatchNorm2d(32),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=(2, 2)),       # 32 x 12 x 12
+            nn.Conv2d(in_channels=32, out_channels=48,
+                      kernel_size=(5, 5)),          # 48 x 8 x 8
+            # nn.BatchNorm2d(48),
+            # nn.Dropout2d(),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=(2, 2)),       # 48 x 4 x 4
+        )
+
+        self.classifier = nn.Sequential(
+            nn.Linear(48 * 4 * 4, 100),
+            # nn.BatchNorm1d(100),
+            nn.ReLU(inplace=True),
+            nn.Linear(100, 100),
+            # nn.BatchNorm1d(100),
+            nn.ReLU(inplace=True),
+            nn.Linear(100, 10),
+        )
+
+        self.discriminator = nn.Sequential(
+            nn.Linear(48 * 4 * 4, 100),
+            # nn.BatchNorm1d(100),
+            nn.ReLU(inplace=True),
+            nn.Linear(100, 2),
+        )
+
+    def forward(self, input_data, alpha):
+        input_data = input_data.expand(input_data.data.shape[0], 3, 28, 28)
+        feature = self.feature(input_data)
+        feature = feature.view(-1, 48 * 4 * 4)
+        reverse_feature = ReverseLayerF.apply(feature, alpha)
+        class_output = self.classifier(feature)
+        domain_output = self.discriminator(reverse_feature)
+        return class_output, domain_output
+
+
 class SVHNmodel(nn.Module):
     """ SVHN architecture
     I don't know how to implement the paper's structure
@@ -152,6 +204,50 @@ class SVHNmodel(nn.Module):
         return class_output, domain_output


+class GTSRBmodel(nn.Module):
+    """GTSRB architecture"""
+
+    def __init__(self):
+        super(GTSRBmodel, self).__init__()
+        self.restored = False
+
+        self.feature = nn.Sequential(
+            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=(5, 5)),    # 40 -> 36
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),                  # 36 -> 18
+            nn.Conv2d(in_channels=96, out_channels=144, kernel_size=(3, 3)),  # 18 -> 16
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),                  # 16 -> 8
+            nn.Conv2d(in_channels=144, out_channels=256, kernel_size=(5, 5)), # 8 -> 4
+            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),                  # 4 -> 2
+        )
+
+        self.classifier = nn.Sequential(
+            nn.Linear(256 * 2 * 2, 512),
+            nn.ReLU(inplace=True),
+            nn.Linear(512, 43),
+        )
+
+        self.discriminator = nn.Sequential(
+            nn.Linear(256 * 2 * 2, 1024),
+            nn.ReLU(inplace=True),
+            nn.Linear(1024, 1024),
+            nn.ReLU(inplace=True),
+            nn.Linear(1024, 2),
+        )
+
+    def forward(self, input_data, alpha=1.0):
+        input_data = input_data.expand(input_data.data.shape[0], 3, 40, 40)
+        feature = self.feature(input_data)
+        feature = feature.view(-1, 256 * 2 * 2)
+        reverse_feature = ReverseLayerF.apply(feature, alpha)
+        class_output = self.classifier(feature)
+        domain_output = self.discriminator(reverse_feature)
+        return class_output, domain_output
+
+
 class AlexModel(nn.Module):
     """ AlexNet pretrained on imagenet for Office dataset"""

utils/utils.py  (8 changes)

@@ -3,12 +3,12 @@ import random
 import torch
 import torch.backends.cudnn as cudnn
-from torch.autograd import Variable

 from datasets import get_mnist, get_mnistm, get_svhn
 from datasets.office import get_office
 from datasets.officecaltech import get_officecaltech
+from datasets.synsigns import get_synsigns
+from datasets.gtsrb import get_gtsrb


 def make_cuda(tensor):
     """Use CUDA if it's available."""
@@ -61,6 +61,10 @@ def get_data_loader(name, dataset_root, batch_size, train=True):
         return get_office(dataset_root, batch_size, 'webcam')
     elif name == "webcam10":
         return get_officecaltech(dataset_root, batch_size, 'webcam')
+    elif name == "synsigns":
+        return get_synsigns(dataset_root, batch_size, train)
+    elif name == "gtsrb":
+        return get_gtsrb(dataset_root, batch_size, train)


 def init_model(net, restore):
