
remove legacy code.

Branch: master
Committed by wogong, 5 years ago
Commit 0c3226fac6
Changed files (0 additions, 106 deletions):
  1. core/pretrain.py (61 deletions)
  2. datasets/gtsrb_legacy.py (45 deletions)

core/pretrain.py (deleted file, 61 lines)

@@ -1,61 +0,0 @@
"""Train classifier for source dataset."""
import torch.nn as nn
import torch.optim as optim
from utils.utils import save_model
from core.test import test
def train_src(model, params, data_loader, device):
"""Train classifier for source domain."""
####################
# 1. setup network #
####################
# set train state for Dropout and BN layers
model.train()
# setup criterion and optimizer
optimizer = optim.Adam(model.parameters(), lr=params.lr)
loss_class = nn.NLLLoss()
####################
# 2. train network #
####################
for epoch in range(params.num_epochs_src):
for step, (images, labels) in enumerate(data_loader):
# make images and labels variable
images = images.to(device)
labels = labels.squeeze_().to(device)
# zero gradients for optimizer
optimizer.zero_grad()
# compute loss for critic
preds_class, _ = model(images)
loss = loss_class(preds_class, labels)
# optimize source classifier
loss.backward()
optimizer.step()
# print step info
if ((step + 1) % params.log_step_src == 0):
print("Epoch [{}/{}] Step [{}/{}]: loss={}".format(epoch + 1, params.num_epochs_src, step + 1,
len(data_loader), loss.data[0]))
# eval model on test set
if ((epoch + 1) % params.eval_step_src == 0):
test(model, data_loader, flag='source')
model.train()
# save model parameters
if ((epoch + 1) % params.save_step_src == 0):
save_model(model, params.src_dataset + "-source-classifier-{}.pt".format(epoch + 1))
# save final model
save_model(model, params.src_dataset + "-source-classifier-final.pt")
return model
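
For reference, a minimal driver sketch for the removed train_src helper. Everything below is illustrative: ToySourceNet, the SimpleNamespace config fields, and the random tensors are hypothetical stand-ins rather than names from this repository, and running it assumes the repository's utils.utils.save_model (called at the end of train_src) is importable.

# hypothetical driver for train_src; assumed config field names, toy model and data
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


class ToySourceNet(nn.Module):
    """Toy model mirroring the (log-probs, features) tuple that train_src unpacks."""

    def __init__(self, in_dim=32, num_classes=10):
        super().__init__()
        self.feature = nn.Linear(in_dim, 16)
        self.classifier = nn.Linear(16, num_classes)

    def forward(self, x):
        feat = torch.relu(self.feature(x))
        return F.log_softmax(self.classifier(feat), dim=1), feat


# assumed config fields; eval/save steps are set beyond num_epochs_src so only
# the final save_model call inside train_src fires
params = SimpleNamespace(lr=1e-3, num_epochs_src=2, log_step_src=4,
                         eval_step_src=10, save_step_src=10, src_dataset="toy")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# random stand-in data: 256 samples, 32 features, 10 classes
loader = DataLoader(TensorDataset(torch.randn(256, 32), torch.randint(0, 10, (256,))),
                    batch_size=32, shuffle=True)

model = train_src(ToySourceNet().to(device), params, loader, device)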

datasets/gtsrb_legacy.py (deleted file, 45 lines)

@@ -1,45 +0,0 @@
"""Dataset setting and data loader for GTSRB."""
import os
import torch
from torchvision import datasets, transforms
import torch.utils.data as data
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
def get_gtsrb(dataset_root, batch_size, train):
"""Get GTSRB datasets loader."""
shuffle_dataset = True
random_seed = 42
train_size = 31367
# image pre-processing
pre_process = transforms.Compose([
transforms.Resize((40, 40)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
# datasets and data_loader
gtsrb_dataset = datasets.ImageFolder(
os.path.join(dataset_root, 'Final_Training', 'Images'), transform=pre_process)
dataset_size = len(gtsrb_dataset)
indices = list(range(dataset_size))
if shuffle_dataset:
np.random.seed(random_seed)
np.random.shuffle(indices)
train_indices, val_indices = indices[:train_size], indices[train_size:]
# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
if train:
gtsrb_dataloader_train = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size,
sampler=train_sampler)
return gtsrb_dataloader_train
else:
gtsrb_dataloader_test = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size,
sampler=valid_sampler)
return gtsrb_dataloader_test
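
For context, a short sketch of how this removed loader was typically consumed. The dataset_root path below is a placeholder and the batch size is an arbitrary choice; the split itself is deterministic because both calls reuse the same fixed seed and train_size.

# placeholder path; point at the directory containing Final_Training/Images
gtsrb_root = "/path/to/GTSRB"

train_loader = get_gtsrb(gtsrb_root, batch_size=128, train=True)
val_loader = get_gtsrb(gtsrb_root, batch_size=128, train=False)

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([128, 3, 40, 40]) after Resize((40, 40))
print(len(train_loader.sampler), len(val_loader.sampler))  # 31367 train indices, the rest held out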