wogong
5 years ago
7 changed files with 307 additions and 10 deletions
@@ -0,0 +1,44 @@
"""Dataset setting and data loader for GTSRB."""

import os

import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler


def get_gtsrb(dataset_root, batch_size, train):
    """Get GTSRB dataset loaders.

    Returns both the training and the validation loader; the `train`
    argument is accepted for interface consistency with the other
    dataset loaders but is not used here.
    """
    shuffle_dataset = True
    random_seed = 42
    train_size = 31367  # number of samples kept for the training split

    # image pre-processing
    pre_process = transforms.Compose([
        transforms.Resize((40, 40)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    # dataset and data loaders
    gtsrb_dataset = datasets.ImageFolder(
        os.path.join(dataset_root, 'Final_Training', 'Images'),
        transform=pre_process)

    # shuffle the indices with a fixed seed so the train/val split
    # is reproducible across runs
    dataset_size = len(gtsrb_dataset)
    indices = list(range(dataset_size))
    if shuffle_dataset:
        np.random.seed(random_seed)
        np.random.shuffle(indices)
    train_indices, val_indices = indices[:train_size], indices[train_size:]

    # creating PyTorch data samplers and loaders
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)

    gtsrb_dataloader_train = torch.utils.data.DataLoader(
        gtsrb_dataset, batch_size=batch_size, sampler=train_sampler)
    gtsrb_dataloader_test = torch.utils.data.DataLoader(
        gtsrb_dataset, batch_size=batch_size, sampler=valid_sampler)

    return gtsrb_dataloader_train, gtsrb_dataloader_test
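
A minimal usage sketch for this loader, assuming the GTSRB training archive is unpacked so that <dataset_root>/Final_Training/Images/<class_id>/ matches the ImageFolder layout expected above; the dataset path and the module path below are hypothetical:

    # hypothetical module path within this repo
    from datasets.gtsrb import get_gtsrb

    train_loader, val_loader = get_gtsrb(
        dataset_root='/data/gtsrb', batch_size=128, train=True)
    images, labels = next(iter(train_loader))
    print(images.shape)  # torch.Size([128, 3, 40, 40]) after Resize/ToTensor
    print(labels[:5])    # class indices assigned by ImageFolder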
@@ -0,0 +1,62 @@
"""Dataset setting and data loader for syn-signs."""

import os

import torch
import torch.utils.data as data
from torchvision import transforms
from PIL import Image


class GetLoader(data.Dataset):
    """Dataset backed by a label file of `image_path label` lines."""

    def __init__(self, data_root, data_list, transform=None):
        self.root = data_root
        self.transform = transform

        with open(data_list, 'r') as f:
            lines = f.readlines()

        self.n_data = len(lines)

        self.img_paths = []
        self.img_labels = []

        for line in lines:
            fields = line.split(' ')
            self.img_paths.append(fields[0])
            self.img_labels.append(fields[1])

    def __getitem__(self, item):
        img_path, label = self.img_paths[item], self.img_labels[item]
        img = Image.open(os.path.join(self.root, img_path)).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        label = int(label)

        return img, label

    def __len__(self):
        return self.n_data


def get_synsigns(dataset_root, batch_size, train):
    """Get Synthetic Signs dataset loader."""
    # image pre-processing
    pre_process = transforms.Compose([
        transforms.Resize((40, 40)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    # dataset and data loader
    # the training set is whatever train_labelling.txt lists
    # (per the original note, the first 90K samples)
    train_list = os.path.join(dataset_root, 'train_labelling.txt')
    synsigns_dataset = GetLoader(
        data_root=dataset_root,
        data_list=train_list,
        transform=pre_process)

    synsigns_dataloader = torch.utils.data.DataLoader(
        dataset=synsigns_dataset, batch_size=batch_size, shuffle=True,
        num_workers=0)

    return synsigns_dataloader
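
GetLoader expects a plain-text label file with one space-separated `relative/path class_id` record per line; that format is inferred from the parsing code above, and the file contents and paths in this sketch are made up:

    # train_labelling.txt (assumed format):
    #   images/00001.png 0
    #   images/00002.png 17
    from datasets.synsigns import get_synsigns  # hypothetical module path

    loader = get_synsigns(dataset_root='/data/synsigns', batch_size=128, train=True)
    imgs, labels = next(iter(loader))
    print(imgs.shape, labels.shape)  # torch.Size([128, 3, 40, 40]), torch.Size([128])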
@@ -0,0 +1,77 @@
import os
import sys
import datetime

import torch
from tensorboardX import SummaryWriter

sys.path.append('../')
from models.model import GTSRBmodel
from core.dann import train_dann
from utils.utils import get_data_loader, init_model, init_random_seed


class Config(object):
    # params for path
    dataset_root = os.path.expanduser(os.path.join('~', 'Datasets'))
    model_name = "synsigns-gtsrb"
    model_base = '/home/wogong/models/pytorch-dann'
    note = ''
    now = datetime.datetime.now().strftime('%m%d_%H%M%S')
    model_root = os.path.join(model_base, model_name, note + '_' + now)
    finetune_flag = False

    # params for datasets and data loader
    batch_size = 128

    # params for source dataset
    src_dataset = "synsigns"
    source_image_root = os.path.join('/home/wogong/datasets', 'synsigns')
    src_model_trained = True
    src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt')

    # params for target dataset
    tgt_dataset = "gtsrb"
    target_image_root = os.path.join('/home/wogong/datasets', 'gtsrb')
    tgt_model_trained = True
    dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')

    # params for pretraining
    num_epochs_src = 100
    log_step_src = 10
    save_step_src = 50
    eval_step_src = 20

    # params for training dann
    gpu_id = '0'

    # schedule carried over from the digit experiments
    num_epochs = 200
    log_step = 50
    save_step = 100
    eval_step = 5

    manual_seed = None
    alpha = 0

    # params for optimizing models
    lr = 2e-4


params = Config()
logger = SummaryWriter(params.model_root)
device = torch.device("cuda:" + params.gpu_id if torch.cuda.is_available() else "cpu")

# init random seed
init_random_seed(params.manual_seed)

# load datasets: get_synsigns returns a single loader,
# get_gtsrb returns a (train, eval) pair
src_data_loader = get_data_loader(params.src_dataset, params.source_image_root, params.batch_size, train=True)
src_data_loader_eval = get_data_loader(params.src_dataset, params.source_image_root, params.batch_size, train=False)
tgt_data_loader, tgt_data_loader_eval = get_data_loader(params.tgt_dataset, params.target_image_root, params.batch_size, train=True)

# load dann model
dann = init_model(net=GTSRBmodel(), restore=None)

# train dann model
print("Training dann model")
if not (dann.restored and params.dann_restore):
    dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger)
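
The adversarial training itself happens in core.dann.train_dann. The defining ingredient of any DANN implementation is the gradient reversal layer: identity on the forward pass, negated and alpha-scaled gradient on the backward pass, so the feature extractor is pushed to make source and target features indistinguishable to the domain classifier. A minimal sketch of that layer for orientation, not necessarily how this repository implements it:

    from torch.autograd import Function

    class GradReverse(Function):
        """Identity forward; negated, alpha-scaled gradient backward."""

        @staticmethod
        def forward(ctx, x, alpha):
            ctx.alpha = alpha
            return x.view_as(x)

        @staticmethod
        def backward(ctx, grad_output):
            # reverse the gradient flowing back into the feature extractor
            return grad_output.neg() * ctx.alpha, None

    # illustrative use inside a model's forward pass:
    #   reversed_feat = GradReverse.apply(feat, alpha)

The `alpha = 0` in Config above is presumably just the initial coefficient; DANN training typically anneals alpha from 0 toward 1 over the course of training.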