wogong
5 years ago
2 changed files with 0 additions and 106 deletions
@ -1,61 +0,0 @@ |
|||
"""Train classifier for source dataset.""" |
|||
|
|||
import torch.nn as nn |
|||
import torch.optim as optim |
|||
|
|||
from utils.utils import save_model |
|||
from core.test import test |
|||
|
|||
|
|||
def train_src(model, params, data_loader, device): |
|||
"""Train classifier for source domain.""" |
|||
#################### |
|||
# 1. setup network # |
|||
#################### |
|||
|
|||
# set train state for Dropout and BN layers |
|||
model.train() |
|||
|
|||
# setup criterion and optimizer |
|||
optimizer = optim.Adam(model.parameters(), lr=params.lr) |
|||
loss_class = nn.NLLLoss() |
|||
|
|||
#################### |
|||
# 2. train network # |
|||
#################### |
|||
|
|||
for epoch in range(params.num_epochs_src): |
|||
for step, (images, labels) in enumerate(data_loader): |
|||
# make images and labels variable |
|||
images = images.to(device) |
|||
labels = labels.squeeze_().to(device) |
|||
|
|||
# zero gradients for optimizer |
|||
optimizer.zero_grad() |
|||
|
|||
# compute loss for critic |
|||
preds_class, _ = model(images) |
|||
loss = loss_class(preds_class, labels) |
|||
|
|||
# optimize source classifier |
|||
loss.backward() |
|||
optimizer.step() |
|||
|
|||
# print step info |
|||
if ((step + 1) % params.log_step_src == 0): |
|||
print("Epoch [{}/{}] Step [{}/{}]: loss={}".format(epoch + 1, params.num_epochs_src, step + 1, |
|||
len(data_loader), loss.data[0])) |
|||
|
|||
# eval model on test set |
|||
if ((epoch + 1) % params.eval_step_src == 0): |
|||
test(model, data_loader, flag='source') |
|||
model.train() |
|||
|
|||
# save model parameters |
|||
if ((epoch + 1) % params.save_step_src == 0): |
|||
save_model(model, params.src_dataset + "-source-classifier-{}.pt".format(epoch + 1)) |
|||
|
|||
# save final model |
|||
save_model(model, params.src_dataset + "-source-classifier-final.pt") |
|||
|
|||
return model |
@ -1,45 +0,0 @@ |
|||
"""Dataset setting and data loader for GTSRB.""" |
|||
|
|||
import os |
|||
import torch |
|||
from torchvision import datasets, transforms |
|||
import torch.utils.data as data |
|||
from torch.utils.data.sampler import SubsetRandomSampler |
|||
import numpy as np |
|||
|
|||
def get_gtsrb(dataset_root, batch_size, train): |
|||
"""Get GTSRB datasets loader.""" |
|||
shuffle_dataset = True |
|||
random_seed = 42 |
|||
train_size = 31367 |
|||
|
|||
# image pre-processing |
|||
pre_process = transforms.Compose([ |
|||
transforms.Resize((40, 40)), |
|||
transforms.ToTensor(), |
|||
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) |
|||
]) |
|||
|
|||
# datasets and data_loader |
|||
gtsrb_dataset = datasets.ImageFolder( |
|||
os.path.join(dataset_root, 'Final_Training', 'Images'), transform=pre_process) |
|||
|
|||
dataset_size = len(gtsrb_dataset) |
|||
indices = list(range(dataset_size)) |
|||
if shuffle_dataset: |
|||
np.random.seed(random_seed) |
|||
np.random.shuffle(indices) |
|||
train_indices, val_indices = indices[:train_size], indices[train_size:] |
|||
|
|||
# Creating PT data samplers and loaders: |
|||
train_sampler = SubsetRandomSampler(train_indices) |
|||
valid_sampler = SubsetRandomSampler(val_indices) |
|||
|
|||
if train: |
|||
gtsrb_dataloader_train = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, |
|||
sampler=train_sampler) |
|||
return gtsrb_dataloader_train |
|||
else: |
|||
gtsrb_dataloader_test = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, |
|||
sampler=valid_sampler) |
|||
return gtsrb_dataloader_test |
Loading…
Reference in new issue