
add syndigits-svhn experiment.

branch: master
author: wogong, 5 years ago
commit: 0565b9d151
4 changed files:

1. datasets/syndigits.py (+27)
2. experiments/syndigits_svhn.py (+82)
3. models/model.py (6 lines changed)
4. utils/utils.py (+3)

datasets/syndigits.py (+27)

@@ -0,0 +1,27 @@
+"""Dataset setting and data loader for syn-digits."""
+
+import os
+import torch
+from torchvision import datasets, transforms
+import torch.utils.data as data
+
+
+def get_syndigits(dataset_root, batch_size, train):
+    """Get the Synthetic Digits dataset loader."""
+    # image pre-processing
+    pre_process = transforms.Compose([
+        transforms.Resize(32),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+    ])
+
+    # dataset and data loader
+    if train:
+        syndigits_dataset = datasets.ImageFolder(os.path.join(dataset_root, 'TRAIN_separate_dirs'), transform=pre_process)
+    else:
+        syndigits_dataset = datasets.ImageFolder(os.path.join(dataset_root, 'TEST_separate_dirs'), transform=pre_process)
+
+    syndigits_dataloader = torch.utils.data.DataLoader(
+        dataset=syndigits_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
+
+    return syndigits_dataloader
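
The loader is a thin wrapper around torchvision's ImageFolder, so it assumes the Synthetic Digits data has already been unpacked into one subdirectory per digit class under TRAIN_separate_dirs/ and TEST_separate_dirs/. A minimal smoke test, with a placeholder dataset path:

from datasets.syndigits import get_syndigits

# Placeholder path; expects <root>/TRAIN_separate_dirs/<class>/*.png and the
# matching TEST_separate_dirs layout that ImageFolder requires.
loader = get_syndigits('/home/wogong/datasets/syndigits', batch_size=128, train=True)
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([128, 3, 32, 32]) after Resize(32)/ToTensor()
print(labels[:8])    # digit class indices assigned by ImageFolder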

experiments/syndigits_svhn.py (+82)

@@ -0,0 +1,82 @@
+import os
+import sys
+import datetime
+
+from tensorboardX import SummaryWriter
+import torch
+
+sys.path.append('../')
+from models.model import SVHNmodel
+from core.train import train_dann
+from utils.utils import get_data_loader, init_model, init_random_seed
+
+
+class Config(object):
+    # params for path
+    model_name = "syndigits-svhn"
+    model_base = '/home/wogong/models/pytorch-dann'
+    note = 'default'
+    model_root = os.path.join(model_base, model_name, note + '_' + datetime.datetime.now().strftime('%m%d_%H%M%S'))
+    os.makedirs(model_root)
+    config = os.path.join(model_root, 'config.txt')
+    finetune_flag = False
+    lr_adjust_flag = 'simple'
+    src_only_flag = False
+
+    # params for datasets and data loader
+    batch_size = 128
+
+    # params for source dataset
+    src_dataset = "syndigits"
+    src_image_root = os.path.join('/home/wogong/datasets', 'syndigits')
+    src_model_trained = True
+    src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt')
+
+    # params for target dataset
+    tgt_dataset = "svhn"
+    tgt_image_root = os.path.join('/home/wogong/datasets', 'svhn')
+    tgt_model_trained = True
+    dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')
+
+    # params for GPU device
+    gpu_id = '0'
+
+    # params for training (digit experiments)
+    num_epochs = 200
+    log_step = 200
+    save_step = 100
+    eval_step = 1
+    manual_seed = 42
+    alpha = 0
+
+    # params for SGD optimizer
+    lr = 0.01
+    momentum = 0.9
+    weight_decay = 1e-6
+
+    def __init__(self):
+        public_props = (name for name in dir(self) if not name.startswith('_'))
+        with open(self.config, 'w') as f:
+            for name in public_props:
+                f.write(name + ': ' + str(getattr(self, name)) + '\n')
+
+
+params = Config()
+logger = SummaryWriter(params.model_root)
+device = torch.device("cuda:" + params.gpu_id if torch.cuda.is_available() else "cpu")
+
+# init random seed
+init_random_seed(params.manual_seed)
+
+# load dataset
+src_data_loader = get_data_loader(params.src_dataset, params.src_image_root, params.batch_size, train=True)
+src_data_loader_eval = get_data_loader(params.src_dataset, params.src_image_root, params.batch_size, train=False)
+tgt_data_loader = get_data_loader(params.tgt_dataset, params.tgt_image_root, params.batch_size, train=True)
+tgt_data_loader_eval = get_data_loader(params.tgt_dataset, params.tgt_image_root, params.batch_size, train=False)
+
+# load dann model
+dann = init_model(net=SVHNmodel(), restore=None)
+
+# train dann model
+print("Training dann model")
+if not (dann.restored and params.dann_restore):
+    dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger)
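
Note that alpha is fixed at 0 in the config. In DANN the gradient-reversal coefficient is normally annealed inside the training loop, and core/train.py is not part of this commit, so the actual behavior of train_dann is not shown here. For reference only, a sketch of the standard schedule from Ganin and Lempitsky (2015), not a confirmed copy of the repository's implementation:

import math

def grl_lambda(p, gamma=10.0):
    # p is training progress in [0, 1]; the coefficient ramps smoothly from 0
    # to 1, damping the noisy domain-classifier gradient early in training.
    return 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0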

models/model.py (6 lines changed)

@@ -160,12 +160,12 @@ class SVHNmodel(nn.Module):
             nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(5, 5)),  # 28
             nn.BatchNorm2d(64),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2)),  # 14
-            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(5, 5)),  # 10
+            nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2)),  # 13
+            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(5, 5)),  # 9
             nn.BatchNorm2d(64),
             nn.Dropout2d(),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2)),  # 5
+            nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2)),  # 4
             nn.ReLU(inplace=True),
             nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(4, 4)),  # 1
         )
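
The comment corrections follow from the usual conv/pool size arithmetic with no padding: MaxPool2d(3, stride 2) maps 28 to floor((28 - 3) / 2) + 1 = 13 rather than 14, the 5x5 convolution maps 13 to 9, the second pool maps 9 to 4, and the final 4x4 convolution maps 4 to 1. A quick standalone check (the layer stack is copied from the diff; the rest of SVHNmodel is not reproduced):

import torch
import torch.nn as nn

feature = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=(5, 5)),             # 32 -> 28
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2)),  # 28 -> 13
    nn.Conv2d(64, 64, kernel_size=(5, 5)),            # 13 -> 9
    nn.BatchNorm2d(64),
    nn.Dropout2d(),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2)),  # 9 -> 4
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 128, kernel_size=(4, 4)),           # 4 -> 1
)

feature.eval()  # eval mode so BatchNorm/Dropout behave deterministically
print(feature(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 128, 1, 1])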

utils/utils.py (+3)

@@ -7,6 +7,7 @@ import torch.backends.cudnn as cudnn
 from datasets import get_mnist, get_mnistm, get_svhn
 from datasets.office import get_office
 from datasets.officecaltech import get_officecaltech
+from datasets.syndigits import get_syndigits
 from datasets.synsigns import get_synsigns
 from datasets.gtsrb import get_gtsrb

@@ -61,6 +62,8 @@ def get_data_loader(name, dataset_root, batch_size, train=True):
         return get_office(dataset_root, batch_size, 'webcam')
     elif name == "webcam10":
         return get_officecaltech(dataset_root, batch_size, 'webcam')
+    elif name == "syndigits":
+        return get_syndigits(dataset_root, batch_size, train)
     elif name == "synsigns":
         return get_synsigns(dataset_root, batch_size, train)
     elif name == "gtsrb":
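
With the dispatch entry in place, the experiment script reaches the new loader purely by dataset name, for example (placeholder path):

from utils.utils import get_data_loader

# train=False routes to the TEST_separate_dirs split inside get_syndigits.
eval_loader = get_data_loader("syndigits", "/home/wogong/datasets/syndigits", 128, train=False)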
