wogong
5 years ago
7 changed files with 307 additions and 10 deletions
@@ -0,0 +1,44 @@
"""Dataset setting and data loader for GTSRB."""

import os

import numpy as np
import torch
import torch.utils.data
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms


def get_gtsrb(dataset_root, batch_size, train):
    """Get GTSRB datasets loader."""
    shuffle_dataset = True
    random_seed = 42
    train_size = 31367  # remaining samples form the validation set

    # image pre-processing
    pre_process = transforms.Compose([
        transforms.Resize((40, 40)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    # datasets and data_loader
    gtsrb_dataset = datasets.ImageFolder(
        os.path.join(dataset_root, 'Final_Training', 'Images'), transform=pre_process)

    dataset_size = len(gtsrb_dataset)
    indices = list(range(dataset_size))
    if shuffle_dataset:
        np.random.seed(random_seed)
        np.random.shuffle(indices)
    train_indices, val_indices = indices[:train_size], indices[train_size:]

    # creating PyTorch data samplers and loaders
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)

    gtsrb_dataloader_train = torch.utils.data.DataLoader(
        gtsrb_dataset, batch_size=batch_size, sampler=train_sampler)
    gtsrb_dataloader_test = torch.utils.data.DataLoader(
        gtsrb_dataset, batch_size=batch_size, sampler=valid_sampler)

    return gtsrb_dataloader_train, gtsrb_dataloader_test
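Note that both loaders draw from the same ImageFolder split by SubsetRandomSampler, so the `train` argument is effectively ignored and callers get both splits from one call. A minimal usage sketch, with a hypothetical dataset path and an assumed module layout:

import os
from datasets.gtsrb import get_gtsrb  # module path assumed from the repo layout

train_loader, val_loader = get_gtsrb(os.path.expanduser('~/Datasets/gtsrb'), batch_size=128, train=True)
images, labels = next(iter(train_loader))  # images: [128, 3, 40, 40], normalized to [-1, 1]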
@@ -0,0 +1,62 @@
"""Dataset setting and data loader for syn-signs."""

import os

import torch
import torch.utils.data as data
from PIL import Image
from torchvision import transforms


class GetLoader(data.Dataset):
    """Dataset backed by a label-list file of `<image path> <label> ...` lines."""

    def __init__(self, data_root, data_list, transform=None):
        self.root = data_root
        self.transform = transform

        with open(data_list, 'r') as f:
            lines = f.readlines()

        self.n_data = len(lines)

        self.img_paths = []
        self.img_labels = []

        for line in lines:
            fields = line.split(' ')
            self.img_paths.append(fields[0])
            self.img_labels.append(fields[1])

    def __getitem__(self, item):
        img_path, label = self.img_paths[item], self.img_labels[item]
        img = Image.open(os.path.join(self.root, img_path)).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        label = int(label)

        return img, label

    def __len__(self):
        return self.n_data


def get_synsigns(dataset_root, batch_size, train):
    """Get Synthetic Signs datasets loader."""
    # image pre-processing
    pre_process = transforms.Compose([
        transforms.Resize((40, 40)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    # datasets and data_loader: every sample listed in train_labelling.txt is used
    train_list = os.path.join(dataset_root, 'train_labelling.txt')
    synsigns_dataset = GetLoader(
        data_root=dataset_root,
        data_list=train_list,
        transform=pre_process)

    synsigns_dataloader = torch.utils.data.DataLoader(
        dataset=synsigns_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    return synsigns_dataloader
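GetLoader expects each line of the label file to contain a relative image path followed by its integer class label; only the first two space-separated fields are kept, so any trailing fields are ignored. A minimal usage sketch, with a hypothetical dataset root:

loader = get_synsigns('/home/user/datasets/synsigns', batch_size=128, train=True)
imgs, labels = next(iter(loader))  # imgs: [128, 3, 40, 40]; int labels are collated into a LongTensor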
@@ -0,0 +1,77 @@
import os
import sys
import datetime

import torch
from tensorboardX import SummaryWriter

sys.path.append('../')
from models.model import GTSRBmodel
from core.dann import train_dann
from utils.utils import get_data_loader, init_model, init_random_seed


class Config(object):
    # params for path
    dataset_root = os.path.expanduser(os.path.join('~', 'Datasets'))
    model_name = "synsigns-gtsrb"
    model_base = '/home/wogong/models/pytorch-dann'
    note = ''
    now = datetime.datetime.now().strftime('%m%d_%H%M%S')
    model_root = os.path.join(model_base, model_name, note + '_' + now)
    finetune_flag = False

    # params for datasets and data loader
    batch_size = 128

    # params for source dataset
    src_dataset = "synsigns"
    source_image_root = os.path.join('/home/wogong/datasets', 'synsigns')
    src_model_trained = True
    src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt')

    # params for target dataset
    tgt_dataset = "gtsrb"
    target_image_root = os.path.join('/home/wogong/datasets', 'gtsrb')
    tgt_model_trained = True
    dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')

    # params for pretrain
    num_epochs_src = 100
    log_step_src = 10
    save_step_src = 50
    eval_step_src = 20

    # params for training dann
    gpu_id = '0'

    ## for digit
    num_epochs = 200
    log_step = 50
    save_step = 100
    eval_step = 5

    manual_seed = None
    alpha = 0

    # params for optimizing models
    lr = 2e-4


params = Config()
logger = SummaryWriter(params.model_root)
device = torch.device("cuda:" + params.gpu_id if torch.cuda.is_available() else "cpu")

# init random seed
init_random_seed(params.manual_seed)

# load dataset
src_data_loader = get_data_loader(params.src_dataset, params.source_image_root, params.batch_size, train=True)
src_data_loader_eval = get_data_loader(params.src_dataset, params.source_image_root, params.batch_size, train=False)
tgt_data_loader, tgt_data_loader_eval = get_data_loader(params.tgt_dataset, params.target_image_root, params.batch_size, train=True)

# load dann model
dann = init_model(net=GTSRBmodel(), restore=None)

# train dann model
print("Training dann model")
if not (dann.restored and params.dann_restore):
    dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger)
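The `alpha = 0` config value and the call signature suggest the gradient-reversal coefficient is handled inside train_dann rather than fixed here. For reference, a sketch of the standard annealing schedule from Ganin & Lempitsky (2016), which core.dann is assumed (not confirmed by this diff) to follow:

import numpy as np

def grl_lambda(p):
    # p is training progress in [0, 1]; lambda ramps smoothly from 0 to 1
    return 2. / (1. + np.exp(-10 * p)) - 1.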