wogong
5 years ago
2 changed files with 0 additions and 106 deletions
@ -1,61 +0,0 @@ |
|||||
"""Train classifier for source dataset.""" |
|
||||
|
|
||||
import torch.nn as nn |
|
||||
import torch.optim as optim |
|
||||
|
|
||||
from utils.utils import save_model |
|
||||
from core.test import test |
|
||||
|
|
||||
|
|
||||
def train_src(model, params, data_loader, device): |
|
||||
"""Train classifier for source domain.""" |
|
||||
#################### |
|
||||
# 1. setup network # |
|
||||
#################### |
|
||||
|
|
||||
# set train state for Dropout and BN layers |
|
||||
model.train() |
|
||||
|
|
||||
# setup criterion and optimizer |
|
||||
optimizer = optim.Adam(model.parameters(), lr=params.lr) |
|
||||
loss_class = nn.NLLLoss() |
|
||||
|
|
||||
#################### |
|
||||
# 2. train network # |
|
||||
#################### |
|
||||
|
|
||||
for epoch in range(params.num_epochs_src): |
|
||||
for step, (images, labels) in enumerate(data_loader): |
|
||||
# make images and labels variable |
|
||||
images = images.to(device) |
|
||||
labels = labels.squeeze_().to(device) |
|
||||
|
|
||||
# zero gradients for optimizer |
|
||||
optimizer.zero_grad() |
|
||||
|
|
||||
# compute loss for critic |
|
||||
preds_class, _ = model(images) |
|
||||
loss = loss_class(preds_class, labels) |
|
||||
|
|
||||
# optimize source classifier |
|
||||
loss.backward() |
|
||||
optimizer.step() |
|
||||
|
|
||||
# print step info |
|
||||
if ((step + 1) % params.log_step_src == 0): |
|
||||
print("Epoch [{}/{}] Step [{}/{}]: loss={}".format(epoch + 1, params.num_epochs_src, step + 1, |
|
||||
len(data_loader), loss.data[0])) |
|
||||
|
|
||||
# eval model on test set |
|
||||
if ((epoch + 1) % params.eval_step_src == 0): |
|
||||
test(model, data_loader, flag='source') |
|
||||
model.train() |
|
||||
|
|
||||
# save model parameters |
|
||||
if ((epoch + 1) % params.save_step_src == 0): |
|
||||
save_model(model, params.src_dataset + "-source-classifier-{}.pt".format(epoch + 1)) |
|
||||
|
|
||||
# save final model |
|
||||
save_model(model, params.src_dataset + "-source-classifier-final.pt") |
|
||||
|
|
||||
return model |
|
@ -1,45 +0,0 @@ |
|||||
"""Dataset setting and data loader for GTSRB.""" |
|
||||
|
|
||||
import os |
|
||||
import torch |
|
||||
from torchvision import datasets, transforms |
|
||||
import torch.utils.data as data |
|
||||
from torch.utils.data.sampler import SubsetRandomSampler |
|
||||
import numpy as np |
|
||||
|
|
||||
def get_gtsrb(dataset_root, batch_size, train): |
|
||||
"""Get GTSRB datasets loader.""" |
|
||||
shuffle_dataset = True |
|
||||
random_seed = 42 |
|
||||
train_size = 31367 |
|
||||
|
|
||||
# image pre-processing |
|
||||
pre_process = transforms.Compose([ |
|
||||
transforms.Resize((40, 40)), |
|
||||
transforms.ToTensor(), |
|
||||
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) |
|
||||
]) |
|
||||
|
|
||||
# datasets and data_loader |
|
||||
gtsrb_dataset = datasets.ImageFolder( |
|
||||
os.path.join(dataset_root, 'Final_Training', 'Images'), transform=pre_process) |
|
||||
|
|
||||
dataset_size = len(gtsrb_dataset) |
|
||||
indices = list(range(dataset_size)) |
|
||||
if shuffle_dataset: |
|
||||
np.random.seed(random_seed) |
|
||||
np.random.shuffle(indices) |
|
||||
train_indices, val_indices = indices[:train_size], indices[train_size:] |
|
||||
|
|
||||
# Creating PT data samplers and loaders: |
|
||||
train_sampler = SubsetRandomSampler(train_indices) |
|
||||
valid_sampler = SubsetRandomSampler(val_indices) |
|
||||
|
|
||||
if train: |
|
||||
gtsrb_dataloader_train = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, |
|
||||
sampler=train_sampler) |
|
||||
return gtsrb_dataloader_train |
|
||||
else: |
|
||||
gtsrb_dataloader_test = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, |
|
||||
sampler=valid_sampler) |
|
||||
return gtsrb_dataloader_test |
|
Loading…
Reference in new issue