
[WIP]

master
fnakamura 6 years ago
commit a56c77f190
5 changed files with 386 additions:

  experiment.py   +80
  main.py         +23
  models.py       +77
  trainer.py      +174
  utils.py        +32

experiment.py
@@ -0,0 +1,80 @@
import os

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import SVHN, MNIST
from torchvision import transforms

from models import CNN, Discriminator
from trainer import train_source_cnn, train_target_cnn


def run(args):
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)

    # data
    source_transform = transforms.Compose([
        # transforms.Grayscale(),
        transforms.ToTensor(),
    ])
    target_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1))  # 1-channel MNIST -> 3 channels
    ])
    source_dataset_train = SVHN(
        './input', 'train', transform=source_transform, download=True)
    source_dataset_test = SVHN(
        './input', 'test', transform=source_transform, download=True)
    # MNIST takes a boolean `train` flag, not a split string like SVHN
    target_dataset_train = MNIST(
        './input', train=True, transform=target_transform, download=True)
    target_dataset_test = MNIST(
        './input', train=False, transform=target_transform, download=True)
    source_train_loader = DataLoader(
        source_dataset_train, args.batch_size, shuffle=True,
        drop_last=True,
        num_workers=args.n_workers)
    source_test_loader = DataLoader(
        source_dataset_test, args.batch_size, shuffle=False,
        num_workers=args.n_workers)
    target_train_loader = DataLoader(
        target_dataset_train, args.batch_size, shuffle=True,
        drop_last=True,
        num_workers=args.n_workers)
    target_test_loader = DataLoader(
        target_dataset_test, args.batch_size, shuffle=False,
        num_workers=args.n_workers)

    # train source CNN (or load a pretrained checkpoint)
    source_cnn = CNN(in_channels=args.in_channels).to(args.device)
    if os.path.isfile(args.trained):
        c = torch.load(args.trained)
        source_cnn.load_state_dict(c['model'])
        print('Loaded `{}`'.format(args.trained))
    else:
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(
            source_cnn.parameters(),
            lr=args.lr, weight_decay=args.weight_decay)
        source_cnn = train_source_cnn(
            source_cnn, source_train_loader, source_test_loader,
            criterion, optimizer, args=args)

    # train target CNN (adversarial adaptation of the encoder)
    target_cnn = CNN(in_channels=args.in_channels, target=True).to(args.device)
    target_cnn.load_state_dict(source_cnn.state_dict())
    discriminator = Discriminator(args=args).to(args.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.RMSprop(  # optim.Adam(
        target_cnn.encoder.parameters(),
        lr=args.lr, weight_decay=args.weight_decay)
    d_optimizer = optim.RMSprop(  # optim.Adam(
        discriminator.parameters(),
        lr=args.lr, weight_decay=args.weight_decay)
    train_target_cnn(
        source_cnn, target_cnn, discriminator,
        criterion, optimizer, d_optimizer,
        source_train_loader, target_train_loader, target_test_loader,
        args=args)

main.py
@@ -0,0 +1,23 @@
import argparse

import experiment


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # NN
    parser.add_argument('--in_channels', type=int, default=3)  # SVHN is RGB; MNIST is repeated to 3 channels in experiment.py
    parser.add_argument('--n_classes', type=int, default=10)
    parser.add_argument('--trained', type=str, default='')
    parser.add_argument('--slope', type=float, default=0.2)
    # train
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=0.)
    parser.add_argument('--epochs', type=int, default=512)
    parser.add_argument('--batch_size', type=int, default=256)
    # misc
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--n_workers', type=int, default=0)
    parser.add_argument('--logdir', type=str, default='outputs/garbage')
    parser.add_argument('--message', '-m', type=str, default='')
    args, unknown = parser.parse_known_args()

    experiment.run(args)
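
A possible invocation (hypothetical flag values; all flags are defined above):

    python main.py --epochs 50 --batch_size 128 --device cuda:0 --logdir outputs/svhn_to_mnist

Passing --trained with the path of a previously saved checkpoint (e.g. a best_model.pt written by utils.save) skips source training and only runs the adversarial adaptation step.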

models.py
@@ -0,0 +1,77 @@
from torch import nn
import torch.nn.functional as F


class Encoder(nn.Module):
    def __init__(self, in_channels=1, h=256, dropout=0.5):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 8, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=5, stride=1)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=4, stride=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()
        self.dropout1 = nn.Dropout2d(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.fc = nn.Linear(480, 500)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        bs = x.size(0)
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # x = self.dropout1(self.relu(self.conv3(x)))
        x = self.relu(self.conv3(x))
        x = x.view(bs, -1)
        x = self.dropout2(self.fc(x))
        return x


class Classifier(nn.Module):
    def __init__(self, n_classes, dropout=0.5):
        super(Classifier, self).__init__()
        self.l1 = nn.Linear(500, n_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        x = self.l1(x)
        return x


class CNN(nn.Module):
    def __init__(self, in_channels=1, n_classes=10, target=False):
        super(CNN, self).__init__()
        self.encoder = Encoder(in_channels=in_channels)
        self.classifier = Classifier(n_classes)
        if target:
            # the target model only adapts its encoder; keep the classifier frozen
            for param in self.classifier.parameters():
                param.requires_grad = False

    def forward(self, x):
        x = self.encoder(x)
        x = self.classifier(x)
        return x


class Discriminator(nn.Module):
    def __init__(self, h=500, args=None):
        super(Discriminator, self).__init__()
        self.l1 = nn.Linear(500, h)
        self.l2 = nn.Linear(h, h)
        self.l3 = nn.Linear(h, 2)
        self.slope = args.slope

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        x = F.leaky_relu(self.l1(x), self.slope)
        x = F.leaky_relu(self.l2(x), self.slope)
        x = self.l3(x)
        return x

trainer.py
@@ -0,0 +1,174 @@
from time import time

import numpy as np
from sklearn.metrics import accuracy_score
import torch

from utils import AverageMeter, save


def train_source_cnn(
    source_cnn, train_loader, test_loader, criterion, optimizer,
    args=None
):
    best_score = None
    for epoch_i in range(1, 1 + args.epochs):
        start_time = time()
        training = train(
            source_cnn, train_loader, criterion, optimizer, args=args)
        validation = validate(
            source_cnn, test_loader, criterion, args=args)
        log = 'Epoch {}/{} '.format(epoch_i, args.epochs)
        log += 'Train/Loss {:.3f} Train/Acc {:.3f} '.format(
            training['loss'], training['acc'])
        log += 'Val/Loss {:.3f} Val/Acc {:.3f} '.format(
            validation['loss'], validation['acc'])
        log += 'Time {:.2f}s'.format(time() - start_time)
        print(log)

        # save
        is_best = (best_score is None or validation['acc'] > best_score)
        best_score = validation['acc'] if is_best else best_score
        state_dict = {
            'model': source_cnn.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch_i,
            'val/acc': best_score,
        }
        save(args.logdir, state_dict, is_best)
    return source_cnn


def train_target_cnn(
    source_cnn, target_cnn, discriminator,
    criterion, optimizer, d_optimizer,
    source_train_loader, target_train_loader, target_test_loader,
    args=None
):
    # baseline: the source CNN evaluated on the target test set, without adaptation
    validation = validate(source_cnn, target_test_loader, criterion, args=args)
    log_source = 'Source/Val/Acc {:.3f} '.format(validation['acc'])

    # best_score = None
    for epoch_i in range(1, 1 + args.epochs):
        start_time = time()
        training = adversarial(
            source_cnn, target_cnn, discriminator,
            source_train_loader, target_train_loader,
            criterion, criterion,
            optimizer, d_optimizer,
            args=args
        )
        validation = validate(
            target_cnn, target_test_loader, criterion, args=args)
        log = 'Epoch {}/{} '.format(epoch_i, args.epochs)
        log += 'D/Loss {:.3f} Target/Loss {:.3f} '.format(
            training['d/loss'], training['target/loss'])
        log += 'Target/Val/Loss {:.3f} Target/Val/Acc {:.3f} '.format(
            validation['loss'], validation['acc'])
        log += log_source
        log += 'Time {:.2f}s'.format(time() - start_time)
        print(log)


def adversarial(
    source_cnn, target_cnn, discriminator,
    source_loader, target_loader,
    criterion, d_criterion,
    optimizer, d_optimizer,
    args=None
):
    source_cnn.eval()
    target_cnn.train()
    discriminator.train()
    losses, d_losses = AverageMeter(), AverageMeter()
    n_iters = min(len(source_loader), len(target_loader))
    source_iter, target_iter = iter(source_loader), iter(target_loader)
    for iter_i in range(n_iters):
        source_data, source_target = next(source_iter)
        target_data, target_target = next(target_iter)
        source_data = source_data.to(args.device)
        target_data = target_data.to(args.device)
        bs = source_data.size(0)

        D_input_source = source_cnn.encoder(source_data)
        D_input_target = target_cnn.encoder(target_data)
        D_target_source = torch.tensor(
            [0] * bs, dtype=torch.long).to(args.device)
        D_target_target = torch.tensor(
            [1] * bs, dtype=torch.long).to(args.device)

        # train Discriminator: label source features 0, target features 1
        D_output_source = discriminator(D_input_source)
        D_output_target = discriminator(D_input_target)
        d_loss_source = d_criterion(D_output_source, D_target_source)
        d_loss_target = d_criterion(D_output_target, D_target_target)
        d_loss = 0.5 * (d_loss_source + d_loss_target)
        d_optimizer.zero_grad()
        d_loss.backward(retain_graph=True)
        d_optimizer.step()
        d_losses.update(d_loss.item(), bs)

        # train Target encoder: fool the discriminator into predicting "source"
        # (the commented block is the variant that recomputes the outputs
        #  after the discriminator update)
        '''
        D_input_target = target_cnn.encoder(target_data)
        D_output_target = discriminator(D_input_target)
        '''
        loss = criterion(D_output_target, D_target_source)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.update(loss.item(), bs)
    return {'d/loss': d_losses.avg, 'target/loss': losses.avg}


def step(model, data, target, criterion, args):
    data, target = data.to(args.device), target.to(args.device)
    output = model(data)
    loss = criterion(output, target)
    return output, loss


def train(model, dataloader, criterion, optimizer, args=None):
    model.train()
    losses = AverageMeter()
    targets, probas = [], []
    for i, (data, target) in enumerate(dataloader):
        bs = target.size(0)
        output, loss = step(model, data, target, criterion, args)
        output = torch.softmax(output, dim=1)  # NOTE
        losses.update(loss.item(), bs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        targets.extend(target.cpu().detach().numpy().tolist())
        probas.extend(output.cpu().detach().numpy().tolist())
    probas = np.asarray(probas)
    preds = np.argmax(probas, axis=1)
    acc = accuracy_score(targets, preds)
    return {
        'loss': losses.avg, 'acc': acc,
    }


def validate(model, dataloader, criterion, args=None):
    model.eval()
    losses = AverageMeter()
    targets, probas = [], []
    with torch.no_grad():
        for iter_i, (data, target) in enumerate(dataloader):
            bs = target.size(0)
            output, loss = step(model, data, target, criterion, args)
            output = torch.softmax(output, dim=1)  # NOTE: check
            losses.update(loss.item(), bs)
            targets.extend(target.cpu().numpy().tolist())
            probas.extend(output.cpu().numpy().tolist())
    probas = np.asarray(probas)
    preds = np.argmax(probas, axis=1)
    acc = accuracy_score(targets, preds)
    return {
        'loss': losses.avg, 'acc': acc,
    }

utils.py
@@ -0,0 +1,32 @@
import os
import shutil

import torch


def save(log_dir, state_dict, is_best):
    checkpoint_path = os.path.join(log_dir, 'checkpoint.pt')
    torch.save(state_dict, checkpoint_path)
    if is_best:
        best_model_path = os.path.join(log_dir, 'best_model.pt')
        shutil.copyfile(checkpoint_path, best_model_path)


class AverageMeter(object):
    """Computes and stores the average and current value

    https://github.com/pytorch/examples/blob/master/imagenet/main.py#L296
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count