Merge pull request #1 from Fujiki-Nakamura/wip/repro

reproduction: SVHN -> MNIST
Branch: master
Author: fnakamura, 6 years ago (committed by GitHub)
Parent commit: db0012d698
Changed files:

1. README.md (+21)
2. adversarial.png (BIN)
3. experiment.py (+70)
4. main.py (+24)
5. models.py (+80)
6. target_domain.png (BIN)
7. train_source.py (+65)
8. trainer.py (+200)
9. utils.py (+48)

README.md (+21)

@@ -1,2 +1,23 @@
# ADDA.PyTorch
Implementation of Adversarial Discriminative Domain Adaptation (ADDA) in PyTorch.
## Example
```
$ python train_source.py --logdir outputs
$ python main.py --logdir outputs --trained outputs/best_model.pt --slope 0.2
```
## Result
### SVHN -> MNIST
| Method | Paper | This repro |
| --- | --- | --- |
| Source only | 0.601 | 0.659 |
| ADDA | 0.760 | ~0.83 |
![adversarial](adversarial.png)
![target_domain](target_domain.png)
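## Checkpoints
`utils.save` writes `checkpoint.pt` every epoch and copies the best one (by validation accuracy) to `best_model.pt` under the log directory. A minimal evaluation sketch, assuming the `outputs` logdir from the example above:
```
import torch
from models import CNN

# the checkpoint dict stores the weights under the 'model' key (see trainer.py / utils.save)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = CNN(in_channels=3, n_classes=10, target=True).to(device)
checkpoint = torch.load('outputs/best_model.pt', map_location=device)
model.load_state_dict(checkpoint['model'])
model.eval()
```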
## Resources
- https://arxiv.org/pdf/1702.05464.pdf

adversarial.png (binary image added, 100 KiB; not shown)

experiment.py (+70)

@@ -0,0 +1,70 @@
import os

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import SVHN, MNIST
from torchvision import transforms

from models import CNN, Discriminator
from trainer import train_target_cnn
from utils import get_logger


def run(args):
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)
    logger = get_logger(os.path.join(args.logdir, 'main.log'))
    logger.info(args)

    # data
    source_transform = transforms.Compose([
        # transforms.Grayscale(),
        transforms.ToTensor(),
    ])
    # resize MNIST to 32x32 and repeat the grayscale channel so it matches SVHN's 3x32x32
    target_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
    ])
    source_dataset_train = SVHN(
        './input', 'train', transform=source_transform, download=True)
    # MNIST selects its split with a boolean `train` flag, not a split string
    target_dataset_train = MNIST(
        './input', train=True, transform=target_transform, download=True)
    target_dataset_test = MNIST(
        './input', train=False, transform=target_transform, download=True)
    source_train_loader = DataLoader(
        source_dataset_train, args.batch_size, shuffle=True,
        drop_last=True,
        num_workers=args.n_workers)
    target_train_loader = DataLoader(
        target_dataset_train, args.batch_size, shuffle=True,
        drop_last=True,
        num_workers=args.n_workers)
    target_test_loader = DataLoader(
        target_dataset_test, args.batch_size, shuffle=False,
        num_workers=args.n_workers)

    # source CNN, pretrained on SVHN by train_source.py
    source_cnn = CNN(in_channels=args.in_channels).to(args.device)
    if os.path.isfile(args.trained):
        c = torch.load(args.trained)
        source_cnn.load_state_dict(c['model'])
        logger.info('Loaded `{}`'.format(args.trained))

    # target CNN, initialized from the source weights and adapted adversarially
    target_cnn = CNN(in_channels=args.in_channels, target=True).to(args.device)
    target_cnn.load_state_dict(source_cnn.state_dict())
    discriminator = Discriminator(args=args).to(args.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        target_cnn.encoder.parameters(),
        lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
    d_optimizer = optim.Adam(
        discriminator.parameters(),
        lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
    train_target_cnn(
        source_cnn, target_cnn, discriminator,
        criterion, optimizer, d_optimizer,
        source_train_loader, target_train_loader, target_test_loader,
        args=args)
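The `target_transform` above is what lets MNIST share the SVHN-trained encoder: resize to 32×32, then repeat the single grayscale channel three times. A quick shape check under the same transform (a sketch, not part of the repository):
```
from torchvision import transforms
from torchvision.datasets import MNIST

# MNIST resized to 32x32 with the channel repeated should match SVHN's 3x32x32 tensors
target_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
])
x, y = MNIST('./input', train=True, transform=target_transform, download=True)[0]
print(x.shape)  # torch.Size([3, 32, 32])
```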

main.py (+24)

@@ -0,0 +1,24 @@
import argparse

import experiment


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # NN
    parser.add_argument('--in_channels', type=int, default=3)
    parser.add_argument('--n_classes', type=int, default=10)
    parser.add_argument('--trained', type=str, default='')
    parser.add_argument('--slope', type=float, default=0.1)
    # train
    parser.add_argument('--lr', type=float, default=2e-4)
    parser.add_argument('--weight_decay', type=float, default=2.5e-5)
    parser.add_argument('--epochs', type=int, default=512)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--betas', type=float, nargs='+', default=(.5, .999))
    # misc
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--n_workers', type=int, default=0)
    parser.add_argument('--logdir', type=str, default='outputs/garbage')
    parser.add_argument('--message', '-m', type=str, default='')
    args, unknown = parser.parse_known_args()
    experiment.run(args)

models.py (+80)

@@ -0,0 +1,80 @@
from torch import nn
import torch.nn.functional as F


class Encoder(nn.Module):
    def __init__(self, in_channels=1, h=256, dropout=0.5):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1)
        self.bn1 = nn.BatchNorm2d(20)
        self.bn2 = nn.BatchNorm2d(50)
        # self.conv3 = nn.Conv2d(16, 120, kernel_size=4, stride=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()
        # self.dropout1 = nn.Dropout2d(dropout)
        self.dropout = nn.Dropout(dropout)
        # 50 channels * 5 * 5 spatial positions for 32x32 inputs
        self.fc = nn.Linear(1250, 500)
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        bs = x.size(0)
        x = self.pool(self.relu(self.bn1(self.conv1(x))))
        x = self.pool(self.relu(self.bn2(self.conv2(x))))
        # x = self.dropout1(self.relu(self.conv3(x)))
        # x = self.relu(self.conv3(x))
        x = x.view(bs, -1)
        x = self.dropout(x)
        x = self.fc(x)
        return x


class Classifier(nn.Module):
    def __init__(self, n_classes, dropout=0.5):
        super(Classifier, self).__init__()
        self.l1 = nn.Linear(500, n_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        x = self.l1(x)
        return x


class CNN(nn.Module):
    def __init__(self, in_channels=1, n_classes=10, target=False):
        super(CNN, self).__init__()
        self.encoder = Encoder(in_channels=in_channels)
        self.classifier = Classifier(n_classes)
        if target:
            # the target model only adapts its encoder; the classifier stays frozen
            for param in self.classifier.parameters():
                param.requires_grad = False

    def forward(self, x):
        x = self.encoder(x)
        x = self.classifier(x)
        return x


class Discriminator(nn.Module):
    def __init__(self, h=500, args=None):
        super(Discriminator, self).__init__()
        self.l1 = nn.Linear(500, h)
        self.l2 = nn.Linear(h, h)
        self.l3 = nn.Linear(h, 2)
        self.slope = args.slope
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        x = F.leaky_relu(self.l1(x), self.slope)
        x = F.leaky_relu(self.l2(x), self.slope)
        x = self.l3(x)
        return x
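The hard-coded `1250` in `Encoder.fc` follows from the 32×32 inputs used throughout this repro: conv1 (5×5) gives 28×28, pooling 14×14, conv2 (5×5) 10×10, pooling 5×5, so the flattened feature is 50 × 5 × 5 = 1250. A small shape check (illustrative only):
```
import torch
from models import Encoder

enc = Encoder(in_channels=3).eval()
with torch.no_grad():
    features = enc(torch.zeros(2, 3, 32, 32))
print(features.shape)  # torch.Size([2, 500]): 1250 flattened features projected to 500
```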

target_domain.png (binary image added, 145 KiB; not shown)

train_source.py (+65)

@@ -0,0 +1,65 @@
import argparse
import os

from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import SVHN
from torchvision import transforms

from models import CNN
from trainer import train_source_cnn
from utils import get_logger


def main(args):
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)
    logger = get_logger(os.path.join(args.logdir, 'train_source.log'))
    logger.info(args)

    # data
    source_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    source_dataset_train = SVHN(
        './input', 'train', transform=source_transform, download=True)
    source_dataset_test = SVHN(
        './input', 'test', transform=source_transform, download=True)
    source_train_loader = DataLoader(
        source_dataset_train, args.batch_size, shuffle=True,
        drop_last=True,
        num_workers=args.n_workers)
    source_test_loader = DataLoader(
        source_dataset_test, args.batch_size, shuffle=False,
        num_workers=args.n_workers)

    # train source CNN
    source_cnn = CNN(in_channels=args.in_channels).to(args.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        source_cnn.parameters(),
        lr=args.lr, weight_decay=args.weight_decay)
    source_cnn = train_source_cnn(
        source_cnn, source_train_loader, source_test_loader,
        criterion, optimizer, args=args)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # NN
    parser.add_argument('--in_channels', type=int, default=3)
    parser.add_argument('--n_classes', type=int, default=10)
    parser.add_argument('--trained', type=str, default='')
    parser.add_argument('--slope', type=float, default=0.2)
    # train
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=2.5e-5)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=128)
    # misc
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--n_workers', type=int, default=0)
    parser.add_argument('--logdir', type=str, default='outputs/garbage')
    parser.add_argument('--message', '-m', type=str, default='')
    args, unknown = parser.parse_known_args()
    main(args)

trainer.py (+200)

@@ -0,0 +1,200 @@
from logging import getLogger
from time import time

import numpy as np
from sklearn.metrics import accuracy_score
from tensorboardX import SummaryWriter
import torch

from utils import AverageMeter, save

logger = getLogger('adda.trainer')


def train_source_cnn(
    source_cnn, train_loader, test_loader, criterion, optimizer,
    args=None
):
    best_score = None
    for epoch_i in range(1, 1 + args.epochs):
        start_time = time()
        training = train(
            source_cnn, train_loader, criterion, optimizer, args=args)
        validation = validate(
            source_cnn, test_loader, criterion, args=args)
        log = 'Epoch {}/{} '.format(epoch_i, args.epochs)
        log += '| Train/Loss {:.3f} Acc {:.3f} '.format(
            training['loss'], training['acc'])
        log += '| Val/Loss {:.3f} Acc {:.3f} '.format(
            validation['loss'], validation['acc'])
        log += 'Time {:.2f}s'.format(time() - start_time)
        logger.info(log)

        # save
        is_best = (best_score is None or validation['acc'] > best_score)
        best_score = validation['acc'] if is_best else best_score
        state_dict = {
            'model': source_cnn.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch_i,
            'val/acc': best_score,
        }
        save(args.logdir, state_dict, is_best)
    return source_cnn


def train_target_cnn(
    source_cnn, target_cnn, discriminator,
    criterion, optimizer, d_optimizer,
    source_train_loader, target_train_loader, target_test_loader,
    args=None
):
    # source-only baseline on the target test set
    validation = validate(source_cnn, target_test_loader, criterion, args=args)
    log_source = 'Source/Acc {:.3f} '.format(validation['acc'])

    writer = SummaryWriter(args.logdir)
    best_score = None
    for epoch_i in range(1, 1 + args.epochs):
        start_time = time()
        training = adversarial(
            source_cnn, target_cnn, discriminator,
            source_train_loader, target_train_loader,
            criterion, criterion,
            optimizer, d_optimizer,
            args=args
        )
        validation = validate(
            target_cnn, target_test_loader, criterion, args=args)
        validation2 = validate(
            target_cnn, target_train_loader, criterion, args=args)
        log = 'Epoch {}/{} '.format(epoch_i, args.epochs)
        log += 'D/Loss {:.3f} Target/Loss {:.3f} '.format(
            training['d/loss'], training['target/loss'])
        log += '[Val] Target/Loss {:.3f} Target/Acc {:.3f} '.format(
            validation['loss'], validation['acc'])
        log += log_source
        log += '[Train] Target/Loss {:.3f} Target/Acc {:.3f} '.format(
            validation2['loss'], validation2['acc'])
        log += 'Time {:.2f}s'.format(time() - start_time)
        logger.info(log)

        # save
        is_best = (best_score is None or validation['acc'] > best_score)
        best_score = validation['acc'] if is_best else best_score
        state_dict = {
            'model': target_cnn.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch_i,
            'val/acc': best_score,
        }
        save(args.logdir, state_dict, is_best)

        # tensorboard
        writer.add_scalar('Adv/D/Loss', training['d/loss'], epoch_i)
        writer.add_scalar('Adv/Target/Loss', training['target/loss'], epoch_i)
        writer.add_scalar('Val/Target/Loss', validation['loss'], epoch_i)
        writer.add_scalar('Val/Target/Acc', validation['acc'], epoch_i)
        writer.add_scalar('Train/Target/Acc', validation2['acc'], epoch_i)
    writer.close()


def adversarial(
    source_cnn, target_cnn, discriminator,
    source_loader, target_loader,
    criterion, d_criterion,
    optimizer, d_optimizer,
    args=None
):
    source_cnn.eval()
    target_cnn.encoder.train()
    discriminator.train()

    losses, d_losses = AverageMeter(), AverageMeter()
    n_iters = min(len(source_loader), len(target_loader))
    source_iter, target_iter = iter(source_loader), iter(target_loader)
    for iter_i in range(n_iters):
        source_data, source_target = next(source_iter)
        target_data, target_target = next(target_iter)
        source_data = source_data.to(args.device)
        target_data = target_data.to(args.device)
        bs = source_data.size(0)

        # domain labels: 0 = source, 1 = target
        D_input_source = source_cnn.encoder(source_data)
        D_input_target = target_cnn.encoder(target_data)
        D_target_source = torch.tensor(
            [0] * bs, dtype=torch.long).to(args.device)
        D_target_target = torch.tensor(
            [1] * bs, dtype=torch.long).to(args.device)

        # train the discriminator to tell source features from target features
        D_output_source = discriminator(D_input_source)
        D_output_target = discriminator(D_input_target)
        D_output = torch.cat([D_output_source, D_output_target], dim=0)
        D_target = torch.cat([D_target_source, D_target_target], dim=0)
        d_loss = d_criterion(D_output, D_target)
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        d_losses.update(d_loss.item(), bs)

        # train the target encoder with inverted labels: target features should look like source
        D_input_target = target_cnn.encoder(target_data)
        D_output_target = discriminator(D_input_target)
        loss = criterion(D_output_target, D_target_source)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.update(loss.item(), bs)
    return {'d/loss': d_losses.avg, 'target/loss': losses.avg}


def step(model, data, target, criterion, args):
    data, target = data.to(args.device), target.to(args.device)
    output = model(data)
    loss = criterion(output, target)
    return output, loss


def train(model, dataloader, criterion, optimizer, args=None):
    model.train()
    losses = AverageMeter()
    targets, probas = [], []
    for i, (data, target) in enumerate(dataloader):
        bs = target.size(0)
        output, loss = step(model, data, target, criterion, args)
        output = torch.softmax(output, dim=1)  # probabilities, used only for accuracy
        losses.update(loss.item(), bs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        targets.extend(target.cpu().detach().numpy().tolist())
        probas.extend(output.cpu().detach().numpy().tolist())
    probas = np.asarray(probas)
    preds = np.argmax(probas, axis=1)
    acc = accuracy_score(targets, preds)
    return {
        'loss': losses.avg, 'acc': acc,
    }


def validate(model, dataloader, criterion, args=None):
    model.eval()
    losses = AverageMeter()
    targets, probas = [], []
    with torch.no_grad():
        for iter_i, (data, target) in enumerate(dataloader):
            bs = target.size(0)
            output, loss = step(model, data, target, criterion, args)
            output = torch.softmax(output, dim=1)  # probabilities, used only for accuracy
            losses.update(loss.item(), bs)
            targets.extend(target.cpu().numpy().tolist())
            probas.extend(output.cpu().numpy().tolist())
    probas = np.asarray(probas)
    preds = np.argmax(probas, axis=1)
    acc = accuracy_score(targets, preds)
    return {
        'loss': losses.avg, 'acc': acc,
    }

utils.py (+48)

@@ -0,0 +1,48 @@
import os
import shutil

import torch


def save(log_dir, state_dict, is_best):
    checkpoint_path = os.path.join(log_dir, 'checkpoint.pt')
    torch.save(state_dict, checkpoint_path)
    if is_best:
        best_model_path = os.path.join(log_dir, 'best_model.pt')
        shutil.copyfile(checkpoint_path, best_model_path)


def get_logger(log_file):
    from logging import getLogger, FileHandler, StreamHandler, Formatter, DEBUG, INFO  # noqa
    fh = FileHandler(log_file)
    fh.setLevel(DEBUG)
    sh = StreamHandler()
    sh.setLevel(INFO)
    for handler in [fh, sh]:
        formatter = Formatter('%(asctime)s - %(message)s')
        handler.setFormatter(formatter)
    logger = getLogger('adda')
    logger.setLevel(INFO)
    logger.addHandler(fh)
    logger.addHandler(sh)
    return logger


class AverageMeter(object):
    """Computes and stores the average and current value

    https://github.com/pytorch/examples/blob/master/imagenet/main.py#L296
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
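`AverageMeter` keeps a size-weighted running mean, which is how the trainer aggregates per-batch losses into per-epoch values. A tiny usage example (the numbers are made up):
```
from utils import AverageMeter

meter = AverageMeter()
meter.update(0.9, n=128)  # mean loss 0.9 over a batch of 128
meter.update(0.5, n=64)   # mean loss 0.5 over a batch of 64
print(meter.avg)          # (0.9 * 128 + 0.5 * 64) / 192 ≈ 0.767
```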