fnakamura committed 6 years ago (committed by GitHub)
9 changed files with 508 additions and 0 deletions
@@ -1,2 +1,23 @@ (README.md)
# ADDA.PyTorch

Implementation of Adversarial Discriminative Domain Adaptation (ADDA) in PyTorch.

## Example

```
$ python train_source.py --logdir outputs
$ python main.py --logdir outputs --trained outputs/best_model.pt --slope 0.2
```
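`train_source.py` pre-trains the source CNN on SVHN and writes `outputs/best_model.pt`; `main.py` then loads that checkpoint via `--trained` and runs the adversarial adaptation against MNIST. As a rough sketch (not part of this change), the best checkpoint left in the log directory can be restored the same way afterwards:

```
import torch
from models import CNN

ckpt = torch.load('outputs/best_model.pt', map_location='cpu')
model = CNN(in_channels=3, n_classes=10, target=True)
model.load_state_dict(ckpt['model'])  # the trainer stores the weights under the 'model' key
model.eval()
```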
## Result

### SVHN -> MNIST

|             | Paper | This Repro |
| ----------- | ----- | ---------- |
| Source only | 0.601 | 0.659      |
| ADDA        | 0.760 | ~0.83      |

![adversarial](adversarial.png)

![target_domain](target_domain.png)

## Resource

- https://arxiv.org/pdf/1702.05464.pdf
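For orientation, the adaptation step in `trainer.adversarial` follows the objectives from this paper; a sketch is given below using the label convention in the code (source features = class 0, target features = class 1), with M_s / M_t the source and target encoders and D the discriminator:

```
% discriminator update: classify source vs. target features
\mathcal{L}_{D} = \mathbb{E}_{x_s}\bigl[\ell_{\mathrm{CE}}(D(M_s(x_s)),\,0)\bigr]
                + \mathbb{E}_{x_t}\bigl[\ell_{\mathrm{CE}}(D(M_t(x_t)),\,1)\bigr]

% target-encoder update with inverted labels (M_s and the classifier stay frozen)
\mathcal{L}_{M_t} = \mathbb{E}_{x_t}\bigl[\ell_{\mathrm{CE}}(D(M_t(x_t)),\,0)\bigr]
```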
(binary image file added, 100 KiB)
@@ -0,0 +1,70 @@ (new file, presumably experiment.py: the module main.py imports and runs)
import os

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import SVHN, MNIST
from torchvision import transforms

from models import CNN, Discriminator
from trainer import train_target_cnn
from utils import get_logger


def run(args):
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)
    logger = get_logger(os.path.join(args.logdir, 'main.log'))
    logger.info(args)

    # data
    source_transform = transforms.Compose([
        # transforms.Grayscale(),
        transforms.ToTensor(),
    ])
    # MNIST digits are resized to 32x32 and the grey channel is repeated
    # three times so they match the 3x32x32 SVHN inputs.
    target_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
    ])
    source_dataset_train = SVHN(
        './input', 'train', transform=source_transform, download=True)
    # MNIST takes a boolean `train` flag (unlike SVHN's split string)
    target_dataset_train = MNIST(
        './input', train=True, transform=target_transform, download=True)
    target_dataset_test = MNIST(
        './input', train=False, transform=target_transform, download=True)
    source_train_loader = DataLoader(
        source_dataset_train, args.batch_size, shuffle=True,
        drop_last=True,
        num_workers=args.n_workers)
    target_train_loader = DataLoader(
        target_dataset_train, args.batch_size, shuffle=True,
        drop_last=True,
        num_workers=args.n_workers)
    target_test_loader = DataLoader(
        target_dataset_test, args.batch_size, shuffle=False,
        num_workers=args.n_workers)

    # source CNN (optionally initialized from a checkpoint written by train_source.py)
    source_cnn = CNN(in_channels=args.in_channels).to(args.device)
    if os.path.isfile(args.trained):
        c = torch.load(args.trained)
        source_cnn.load_state_dict(c['model'])
        logger.info('Loaded `{}`'.format(args.trained))

    # target CNN starts from the source weights; only its encoder is adapted
    target_cnn = CNN(in_channels=args.in_channels, target=True).to(args.device)
    target_cnn.load_state_dict(source_cnn.state_dict())
    discriminator = Discriminator(args=args).to(args.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        target_cnn.encoder.parameters(),
        lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
    d_optimizer = optim.Adam(
        discriminator.parameters(),
        lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
    train_target_cnn(
        source_cnn, target_cnn, discriminator,
        criterion, optimizer, d_optimizer,
        source_train_loader, target_train_loader, target_test_loader,
        args=args)
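As a quick aside (not part of the diff), the channel-matching trick in `target_transform` above can be sanity-checked on a dummy grey image; everything below is illustrative only:

```
from PIL import Image
from torchvision import transforms

# same pipeline as `target_transform` in the file above
t = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
])
digit = Image.new('L', (28, 28))  # stand-in for a 28x28 MNIST digit
print(t(digit).shape)             # torch.Size([3, 32, 32]), the same shape as an SVHN image
```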
@@ -0,0 +1,24 @@ (new file, presumably main.py: the CLI entry point used in the README example)
import argparse
import experiment


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # NN
    parser.add_argument('--in_channels', type=int, default=3)
    parser.add_argument('--n_classes', type=int, default=10)
    parser.add_argument('--trained', type=str, default='')
    parser.add_argument('--slope', type=float, default=0.1)
    # train
    parser.add_argument('--lr', type=float, default=2e-4)
    parser.add_argument('--weight_decay', type=float, default=2.5e-5)
    parser.add_argument('--epochs', type=int, default=512)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--betas', type=float, nargs='+', default=(.5, .999))
    # misc
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--n_workers', type=int, default=0)
    parser.add_argument('--logdir', type=str, default='outputs/garbage')
    parser.add_argument('--message', '-m', type=str, default='')
    args, unknown = parser.parse_known_args()
    experiment.run(args)
@@ -0,0 +1,80 @@ (new file, presumably models.py: Encoder, Classifier, CNN, Discriminator)
from torch import nn
import torch.nn.functional as F


class Encoder(nn.Module):
    def __init__(self, in_channels=1, h=256, dropout=0.5):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1)
        self.bn1 = nn.BatchNorm2d(20)
        self.bn2 = nn.BatchNorm2d(50)
        # self.conv3 = nn.Conv2d(16, 120, kernel_size=4, stride=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()
        # self.dropout1 = nn.Dropout2d(dropout)
        self.dropout = nn.Dropout(dropout)
        # for 32x32 inputs: conv5 -> 28, pool -> 14, conv5 -> 10, pool -> 5,
        # so the flattened feature size is 50 * 5 * 5 = 1250
        self.fc = nn.Linear(1250, 500)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        bs = x.size(0)
        x = self.pool(self.relu(self.bn1(self.conv1(x))))
        x = self.pool(self.relu(self.bn2(self.conv2(x))))
        # x = self.dropout1(self.relu(self.conv3(x)))
        # x = self.relu(self.conv3(x))
        x = x.view(bs, -1)
        x = self.dropout(x)
        x = self.fc(x)
        return x


class Classifier(nn.Module):
    def __init__(self, n_classes, dropout=0.5):
        super(Classifier, self).__init__()
        self.l1 = nn.Linear(500, n_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        x = self.l1(x)
        return x


class CNN(nn.Module):
    def __init__(self, in_channels=1, n_classes=10, target=False):
        super(CNN, self).__init__()
        self.encoder = Encoder(in_channels=in_channels)
        self.classifier = Classifier(n_classes)
        if target:
            # for the target network only the encoder is adapted;
            # the classifier inherited from the source stays frozen
            for param in self.classifier.parameters():
                param.requires_grad = False

    def forward(self, x):
        x = self.encoder(x)
        x = self.classifier(x)
        return x


class Discriminator(nn.Module):
    def __init__(self, h=500, args=None):
        super(Discriminator, self).__init__()
        self.l1 = nn.Linear(500, h)
        self.l2 = nn.Linear(h, h)
        self.l3 = nn.Linear(h, 2)
        self.slope = args.slope

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        x = F.leaky_relu(self.l1(x), self.slope)
        x = F.leaky_relu(self.l2(x), self.slope)
        x = self.l3(x)
        return x
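A minimal shape check (not included in the repo) for the encoder above, assuming the 3x32x32 inputs produced by the transforms in experiment.py:

```
import torch
from models import Encoder

enc = Encoder(in_channels=3)
feats = enc(torch.randn(4, 3, 32, 32))  # a batch of 4 SVHN-sized images
print(feats.shape)  # torch.Size([4, 500]); these 500-d features feed the Classifier and Discriminator
```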
(binary image file added, 145 KiB)
@@ -0,0 +1,65 @@ (new file, presumably train_source.py: pre-trains the source CNN on SVHN)
import argparse
import os

from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import SVHN
from torchvision import transforms

from models import CNN
from trainer import train_source_cnn
from utils import get_logger


def main(args):
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)
    logger = get_logger(os.path.join(args.logdir, 'train_source.log'))
    logger.info(args)

    # data
    source_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    source_dataset_train = SVHN(
        './input', 'train', transform=source_transform, download=True)
    source_dataset_test = SVHN(
        './input', 'test', transform=source_transform, download=True)
    source_train_loader = DataLoader(
        source_dataset_train, args.batch_size, shuffle=True,
        drop_last=True,
        num_workers=args.n_workers)
    source_test_loader = DataLoader(
        source_dataset_test, args.batch_size, shuffle=False,
        num_workers=args.n_workers)

    # train source CNN
    source_cnn = CNN(in_channels=args.in_channels).to(args.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        source_cnn.parameters(),
        lr=args.lr, weight_decay=args.weight_decay)
    source_cnn = train_source_cnn(
        source_cnn, source_train_loader, source_test_loader,
        criterion, optimizer, args=args)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # NN
    parser.add_argument('--in_channels', type=int, default=3)
    parser.add_argument('--n_classes', type=int, default=10)
    parser.add_argument('--trained', type=str, default='')
    parser.add_argument('--slope', type=float, default=0.2)
    # train
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=2.5e-5)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=128)
    # misc
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--n_workers', type=int, default=0)
    parser.add_argument('--logdir', type=str, default='outputs/garbage')
    parser.add_argument('--message', '-m', type=str, default='')
    args, unknown = parser.parse_known_args()
    main(args)
@@ -0,0 +1,200 @@ (new file, presumably trainer.py: source training, ADDA adaptation, and evaluation loops)
from logging import getLogger
from time import time

import numpy as np
from sklearn.metrics import accuracy_score
from tensorboardX import SummaryWriter
import torch

from utils import AverageMeter, save


logger = getLogger('adda.trainer')


def train_source_cnn(
    source_cnn, train_loader, test_loader, criterion, optimizer,
    args=None
):
    best_score = None
    for epoch_i in range(1, 1 + args.epochs):
        start_time = time()
        training = train(
            source_cnn, train_loader, criterion, optimizer, args=args)
        validation = validate(
            source_cnn, test_loader, criterion, args=args)
        log = 'Epoch {}/{} '.format(epoch_i, args.epochs)
        log += '| Train/Loss {:.3f} Acc {:.3f} '.format(
            training['loss'], training['acc'])
        log += '| Val/Loss {:.3f} Acc {:.3f} '.format(
            validation['loss'], validation['acc'])
        log += 'Time {:.2f}s'.format(time() - start_time)
        logger.info(log)

        # save
        is_best = (best_score is None or validation['acc'] > best_score)
        best_score = validation['acc'] if is_best else best_score
        state_dict = {
            'model': source_cnn.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch_i,
            'val/acc': best_score,
        }
        save(args.logdir, state_dict, is_best)

    return source_cnn


def train_target_cnn(
    source_cnn, target_cnn, discriminator,
    criterion, optimizer, d_optimizer,
    source_train_loader, target_train_loader, target_test_loader,
    args=None
):
    # baseline: source-only accuracy on the target test set
    validation = validate(source_cnn, target_test_loader, criterion, args=args)
    log_source = 'Source/Acc {:.3f} '.format(validation['acc'])

    writer = SummaryWriter(args.logdir)
    best_score = None
    for epoch_i in range(1, 1 + args.epochs):
        start_time = time()
        training = adversarial(
            source_cnn, target_cnn, discriminator,
            source_train_loader, target_train_loader,
            criterion, criterion,
            optimizer, d_optimizer,
            args=args
        )
        validation = validate(
            target_cnn, target_test_loader, criterion, args=args)
        validation2 = validate(
            target_cnn, target_train_loader, criterion, args=args)
        log = 'Epoch {}/{} '.format(epoch_i, args.epochs)
        log += 'D/Loss {:.3f} Target/Loss {:.3f} '.format(
            training['d/loss'], training['target/loss'])
        log += '[Val] Target/Loss {:.3f} Target/Acc {:.3f} '.format(
            validation['loss'], validation['acc'])
        log += log_source
        log += '[Train] Target/Loss {:.3f} Target/Acc {:.3f} '.format(
            validation2['loss'], validation2['acc'])
        log += 'Time {:.2f}s'.format(time() - start_time)
        logger.info(log)

        # save
        is_best = (best_score is None or validation['acc'] > best_score)
        best_score = validation['acc'] if is_best else best_score
        state_dict = {
            'model': target_cnn.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch_i,
            'val/acc': best_score,
        }
        save(args.logdir, state_dict, is_best)

        # tensorboard
        writer.add_scalar('Adv/D/Loss', training['d/loss'], epoch_i)
        writer.add_scalar('Adv/Target/Loss', training['target/loss'], epoch_i)
        writer.add_scalar('Val/Target/Loss', validation['loss'], epoch_i)
        writer.add_scalar('Val/Target/Acc', validation['acc'], epoch_i)
        writer.add_scalar('Train/Target/Acc', validation2['acc'], epoch_i)


def adversarial(
    source_cnn, target_cnn, discriminator,
    source_loader, target_loader,
    criterion, d_criterion,
    optimizer, d_optimizer,
    args=None
):
    source_cnn.eval()
    target_cnn.encoder.train()
    discriminator.train()

    losses, d_losses = AverageMeter(), AverageMeter()
    n_iters = min(len(source_loader), len(target_loader))
    source_iter, target_iter = iter(source_loader), iter(target_loader)
    for iter_i in range(n_iters):
        source_data, source_target = next(source_iter)
        target_data, target_target = next(target_iter)
        source_data = source_data.to(args.device)
        target_data = target_data.to(args.device)
        bs = source_data.size(0)

        # domain labels: source features -> 0, target features -> 1
        D_input_source = source_cnn.encoder(source_data)
        D_input_target = target_cnn.encoder(target_data)
        D_target_source = torch.tensor(
            [0] * bs, dtype=torch.long).to(args.device)
        D_target_target = torch.tensor(
            [1] * bs, dtype=torch.long).to(args.device)

        # train Discriminator to tell source and target features apart
        D_output_source = discriminator(D_input_source)
        D_output_target = discriminator(D_input_target)
        D_output = torch.cat([D_output_source, D_output_target], dim=0)
        D_target = torch.cat([D_target_source, D_target_target], dim=0)
        d_loss = d_criterion(D_output, D_target)
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        d_losses.update(d_loss.item(), bs)

        # train target encoder with inverted labels so its features look like source
        D_input_target = target_cnn.encoder(target_data)
        D_output_target = discriminator(D_input_target)
        loss = criterion(D_output_target, D_target_source)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.update(loss.item(), bs)
    return {'d/loss': d_losses.avg, 'target/loss': losses.avg}


def step(model, data, target, criterion, args):
    data, target = data.to(args.device), target.to(args.device)
    output = model(data)
    loss = criterion(output, target)
    return output, loss


def train(model, dataloader, criterion, optimizer, args=None):
    model.train()
    losses = AverageMeter()
    targets, probas = [], []
    for i, (data, target) in enumerate(dataloader):
        bs = target.size(0)
        output, loss = step(model, data, target, criterion, args)
        output = torch.softmax(output, dim=1)  # NOTE
        losses.update(loss.item(), bs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        targets.extend(target.cpu().detach().numpy().tolist())
        probas.extend(output.cpu().detach().numpy().tolist())
    probas = np.asarray(probas)
    preds = np.argmax(probas, axis=1)
    acc = accuracy_score(targets, preds)
    return {
        'loss': losses.avg, 'acc': acc,
    }


def validate(model, dataloader, criterion, args=None):
    model.eval()
    losses = AverageMeter()
    targets, probas = [], []
    with torch.no_grad():
        for iter_i, (data, target) in enumerate(dataloader):
            bs = target.size(0)
            output, loss = step(model, data, target, criterion, args)
            output = torch.softmax(output, dim=1)  # NOTE: check
            losses.update(loss.item(), bs)
            targets.extend(target.cpu().numpy().tolist())
            probas.extend(output.cpu().numpy().tolist())
    probas = np.asarray(probas)
    preds = np.argmax(probas, axis=1)
    acc = accuracy_score(targets, preds)
    return {
        'loss': losses.avg, 'acc': acc,
    }
@@ -0,0 +1,48 @@ (new file, presumably utils.py: checkpointing, logging, AverageMeter)
import os
import shutil
import torch


def save(log_dir, state_dict, is_best):
    checkpoint_path = os.path.join(log_dir, 'checkpoint.pt')
    torch.save(state_dict, checkpoint_path)
    if is_best:
        best_model_path = os.path.join(log_dir, 'best_model.pt')
        shutil.copyfile(checkpoint_path, best_model_path)


def get_logger(log_file):
    from logging import getLogger, FileHandler, StreamHandler, Formatter, DEBUG, INFO  # noqa
    fh = FileHandler(log_file)
    fh.setLevel(DEBUG)
    sh = StreamHandler()
    sh.setLevel(INFO)
    for handler in [fh, sh]:
        formatter = Formatter('%(asctime)s - %(message)s')
        handler.setFormatter(formatter)
    logger = getLogger('adda')
    logger.setLevel(INFO)
    logger.addHandler(fh)
    logger.addHandler(sh)
    return logger


class AverageMeter(object):
    """Computes and stores the average and current value

    https://github.com/pytorch/examples/blob/master/imagenet/main.py#L296
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count