fnakamura
6 years ago
committed by
GitHub
9 changed files with 508 additions and 0 deletions
@ -1,2 +1,23 @@ |
|||
# ADDA.PyTorch |
|||
implement Adversarial Discriminative Domain Adapation in PyTorch |
|||
|
|||
|
|||
## Example |
|||
``` |
|||
$ python train_source.py --logdir outputs |
|||
$ python main.py --logdir outputs --trained outputs/best_model.pt --slope 0.2 |
|||
``` |
|||
|
|||
|
|||
## Result |
|||
### SVHN -> MNIST |
|||
| | Paper | This Repro | |
|||
| --- | --- | --- | |
|||
| Source only | 0.601 | 0.659 | |
|||
| ADDA | 0.760 | ~0.83 | |
|||
|
|||
![adversarial](adversarial.png) |
|||
![target_domain](target_domain.png) |
|||
|
|||
## Resource |
|||
- https://arxiv.org/pdf/1702.05464.pdf |
|||
|
After Width: | Height: | Size: 100 KiB |
@ -0,0 +1,70 @@ |
|||
import os |
|||
|
|||
import torch |
|||
from torch import nn, optim |
|||
from torch.utils.data import DataLoader |
|||
from torchvision.datasets import SVHN, MNIST |
|||
from torchvision import transforms |
|||
|
|||
from models import CNN, Discriminator |
|||
from trainer import train_target_cnn |
|||
from utils import get_logger |
|||
|
|||
|
|||
def run(args): |
|||
if not os.path.exists(args.logdir): |
|||
os.makedirs(args.logdir) |
|||
logger = get_logger(os.path.join(args.logdir, 'main.log')) |
|||
logger.info(args) |
|||
|
|||
# data |
|||
source_transform = transforms.Compose([ |
|||
# transforms.Grayscale(), |
|||
transforms.ToTensor()] |
|||
) |
|||
target_transform = transforms.Compose([ |
|||
transforms.Resize(32), |
|||
transforms.ToTensor(), |
|||
transforms.Lambda(lambda x: x.repeat(3, 1, 1)) |
|||
]) |
|||
source_dataset_train = SVHN( |
|||
'./input', 'train', transform=source_transform, download=True) |
|||
target_dataset_train = MNIST( |
|||
'./input', 'train', transform=target_transform, download=True) |
|||
target_dataset_test = MNIST( |
|||
'./input', 'test', transform=target_transform, download=True) |
|||
source_train_loader = DataLoader( |
|||
source_dataset_train, args.batch_size, shuffle=True, |
|||
drop_last=True, |
|||
num_workers=args.n_workers) |
|||
target_train_loader = DataLoader( |
|||
target_dataset_train, args.batch_size, shuffle=True, |
|||
drop_last=True, |
|||
num_workers=args.n_workers) |
|||
target_test_loader = DataLoader( |
|||
target_dataset_test, args.batch_size, shuffle=False, |
|||
num_workers=args.n_workers) |
|||
|
|||
# train source CNN |
|||
source_cnn = CNN(in_channels=args.in_channels).to(args.device) |
|||
if os.path.isfile(args.trained): |
|||
c = torch.load(args.trained) |
|||
source_cnn.load_state_dict(c['model']) |
|||
logger.info('Loaded `{}`'.format(args.trained)) |
|||
|
|||
# train target CNN |
|||
target_cnn = CNN(in_channels=args.in_channels, target=True).to(args.device) |
|||
target_cnn.load_state_dict(source_cnn.state_dict()) |
|||
discriminator = Discriminator(args=args).to(args.device) |
|||
criterion = nn.CrossEntropyLoss() |
|||
optimizer = optim.Adam( |
|||
target_cnn.encoder.parameters(), |
|||
lr=args.lr, betas=args.betas, weight_decay=args.weight_decay) |
|||
d_optimizer = optim.Adam( |
|||
discriminator.parameters(), |
|||
lr=args.lr, betas=args.betas, weight_decay=args.weight_decay) |
|||
train_target_cnn( |
|||
source_cnn, target_cnn, discriminator, |
|||
criterion, optimizer, d_optimizer, |
|||
source_train_loader, target_train_loader, target_test_loader, |
|||
args=args) |
@ -0,0 +1,24 @@ |
|||
import argparse |
|||
import experiment |
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
parser = argparse.ArgumentParser() |
|||
# NN |
|||
parser.add_argument('--in_channels', type=int, default=3) |
|||
parser.add_argument('--n_classes', type=int, default=10) |
|||
parser.add_argument('--trained', type=str, default='') |
|||
parser.add_argument('--slope', type=float, default=0.1) |
|||
# train |
|||
parser.add_argument('--lr', type=float, default=2e-4) |
|||
parser.add_argument('--weight_decay', type=float, default=2.5e-5) |
|||
parser.add_argument('--epochs', type=int, default=512) |
|||
parser.add_argument('--batch_size', type=int, default=128) |
|||
parser.add_argument('--betas', type=float, nargs='+', default=(.5, .999)) |
|||
# misc |
|||
parser.add_argument('--device', type=str, default='cuda:0') |
|||
parser.add_argument('--n_workers', type=int, default=0) |
|||
parser.add_argument('--logdir', type=str, default='outputs/garbage') |
|||
parser.add_argument('--message', '-m', type=str, default='') |
|||
args, unknown = parser.parse_known_args() |
|||
experiment.run(args) |
@ -0,0 +1,80 @@ |
|||
from torch import nn |
|||
import torch.nn.functional as F |
|||
|
|||
|
|||
class Encoder(nn.Module): |
|||
def __init__(self, in_channels=1, h=256, dropout=0.5): |
|||
super(Encoder, self).__init__() |
|||
self.conv1 = nn.Conv2d(in_channels, 20, kernel_size=5, stride=1) |
|||
self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1) |
|||
self.bn1 = nn.BatchNorm2d(20) |
|||
self.bn2 = nn.BatchNorm2d(50) |
|||
# self.conv3 = nn.Conv2d(16, 120, kernel_size=4, stride=1) |
|||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2) |
|||
self.relu = nn.ReLU() |
|||
# self.dropout1 = nn.Dropout2d(dropout) |
|||
self.dropout = nn.Dropout(dropout) |
|||
self.fc = nn.Linear(1250, 500) |
|||
|
|||
for m in self.modules(): |
|||
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): |
|||
nn.init.kaiming_normal_(m.weight) |
|||
|
|||
def forward(self, x): |
|||
bs = x.size(0) |
|||
x = self.pool(self.relu(self.bn1(self.conv1(x)))) |
|||
x = self.pool(self.relu(self.bn2(self.conv2(x)))) |
|||
# x = self.dropout1(self.relu(self.conv3(x))) |
|||
# x = self.relu(self.conv3(x)) |
|||
x = x.view(bs, -1) |
|||
x = self.dropout(x) |
|||
x = self.fc(x) |
|||
return x |
|||
|
|||
|
|||
class Classifier(nn.Module): |
|||
def __init__(self, n_classes, dropout=0.5): |
|||
super(Classifier, self).__init__() |
|||
self.l1 = nn.Linear(500, n_classes) |
|||
|
|||
for m in self.modules(): |
|||
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): |
|||
nn.init.kaiming_normal_(m.weight) |
|||
|
|||
def forward(self, x): |
|||
x = self.l1(x) |
|||
return x |
|||
|
|||
|
|||
class CNN(nn.Module): |
|||
def __init__(self, in_channels=1, n_classes=10, target=False): |
|||
super(CNN, self).__init__() |
|||
self.encoder = Encoder(in_channels=in_channels) |
|||
self.classifier = Classifier(n_classes) |
|||
if target: |
|||
for param in self.classifier.parameters(): |
|||
param.requires_grad = False |
|||
|
|||
def forward(self, x): |
|||
x = self.encoder(x) |
|||
x = self.classifier(x) |
|||
return x |
|||
|
|||
|
|||
class Discriminator(nn.Module): |
|||
def __init__(self, h=500, args=None): |
|||
super(Discriminator, self).__init__() |
|||
self.l1 = nn.Linear(500, h) |
|||
self.l2 = nn.Linear(h, h) |
|||
self.l3 = nn.Linear(h, 2) |
|||
self.slope = args.slope |
|||
|
|||
for m in self.modules(): |
|||
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): |
|||
nn.init.kaiming_normal_(m.weight) |
|||
|
|||
def forward(self, x): |
|||
x = F.leaky_relu(self.l1(x), self.slope) |
|||
x = F.leaky_relu(self.l2(x), self.slope) |
|||
x = self.l3(x) |
|||
return x |
After Width: | Height: | Size: 145 KiB |
@ -0,0 +1,65 @@ |
|||
import argparse |
|||
import os |
|||
|
|||
from torch import nn, optim |
|||
from torch.utils.data import DataLoader |
|||
from torchvision.datasets import SVHN |
|||
from torchvision import transforms |
|||
|
|||
from models import CNN |
|||
from trainer import train_source_cnn |
|||
from utils import get_logger |
|||
|
|||
|
|||
def main(args): |
|||
if not os.path.exists(args.logdir): |
|||
os.makedirs(args.logdir) |
|||
logger = get_logger(os.path.join(args.logdir, 'train_source.log')) |
|||
logger.info(args) |
|||
|
|||
# data |
|||
source_transform = transforms.Compose([ |
|||
transforms.ToTensor()] |
|||
) |
|||
source_dataset_train = SVHN( |
|||
'./input', 'train', transform=source_transform, download=True) |
|||
source_dataset_test = SVHN( |
|||
'./input', 'test', transform=source_transform, download=True) |
|||
source_train_loader = DataLoader( |
|||
source_dataset_train, args.batch_size, shuffle=True, |
|||
drop_last=True, |
|||
num_workers=args.n_workers) |
|||
source_test_loader = DataLoader( |
|||
source_dataset_test, args.batch_size, shuffle=False, |
|||
num_workers=args.n_workers) |
|||
|
|||
# train source CNN |
|||
source_cnn = CNN(in_channels=args.in_channels).to(args.device) |
|||
criterion = nn.CrossEntropyLoss() |
|||
optimizer = optim.Adam( |
|||
source_cnn.parameters(), |
|||
lr=args.lr, weight_decay=args.weight_decay) |
|||
source_cnn = train_source_cnn( |
|||
source_cnn, source_train_loader, source_test_loader, |
|||
criterion, optimizer, args=args) |
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
parser = argparse.ArgumentParser() |
|||
# NN |
|||
parser.add_argument('--in_channels', type=int, default=3) |
|||
parser.add_argument('--n_classes', type=int, default=10) |
|||
parser.add_argument('--trained', type=str, default='') |
|||
parser.add_argument('--slope', type=float, default=0.2) |
|||
# train |
|||
parser.add_argument('--lr', type=float, default=1e-3) |
|||
parser.add_argument('--weight_decay', type=float, default=2.5e-5) |
|||
parser.add_argument('--epochs', type=int, default=50) |
|||
parser.add_argument('--batch_size', type=int, default=128) |
|||
# misc |
|||
parser.add_argument('--device', type=str, default='cuda:0') |
|||
parser.add_argument('--n_workers', type=int, default=0) |
|||
parser.add_argument('--logdir', type=str, default='outputs/garbage') |
|||
parser.add_argument('--message', '-m', type=str, default='') |
|||
args, unknown = parser.parse_known_args() |
|||
main(args) |
@ -0,0 +1,200 @@ |
|||
from logging import getLogger |
|||
from time import time |
|||
|
|||
import numpy as np |
|||
from sklearn.metrics import accuracy_score |
|||
from tensorboardX import SummaryWriter |
|||
import torch |
|||
|
|||
from utils import AverageMeter, save |
|||
|
|||
|
|||
logger = getLogger('adda.trainer') |
|||
|
|||
|
|||
def train_source_cnn( |
|||
source_cnn, train_loader, test_loader, criterion, optimizer, |
|||
args=None |
|||
): |
|||
best_score = None |
|||
for epoch_i in range(1, 1 + args.epochs): |
|||
start_time = time() |
|||
training = train( |
|||
source_cnn, train_loader, criterion, optimizer, args=args) |
|||
validation = validate( |
|||
source_cnn, test_loader, criterion, args=args) |
|||
log = 'Epoch {}/{} '.format(epoch_i, args.epochs) |
|||
log += '| Train/Loss {:.3f} Acc {:.3f} '.format( |
|||
training['loss'], training['acc']) |
|||
log += '| Val/Loss {:.3f} Acc {:.3f} '.format( |
|||
validation['loss'], validation['acc']) |
|||
log += 'Time {:.2f}s'.format(time() - start_time) |
|||
logger.info(log) |
|||
|
|||
# save |
|||
is_best = (best_score is None or validation['acc'] > best_score) |
|||
best_score = validation['acc'] if is_best else best_score |
|||
state_dict = { |
|||
'model': source_cnn.state_dict(), |
|||
'optimizer': optimizer.state_dict(), |
|||
'epoch': epoch_i, |
|||
'val/acc': best_score, |
|||
} |
|||
save(args.logdir, state_dict, is_best) |
|||
|
|||
return source_cnn |
|||
|
|||
|
|||
def train_target_cnn( |
|||
source_cnn, target_cnn, discriminator, |
|||
criterion, optimizer, d_optimizer, |
|||
source_train_loader, target_train_loader, target_test_loader, |
|||
args=None |
|||
): |
|||
validation = validate(source_cnn, target_test_loader, criterion, args=args) |
|||
log_source = 'Source/Acc {:.3f} '.format(validation['acc']) |
|||
|
|||
writer = SummaryWriter(args.logdir) |
|||
best_score = None |
|||
for epoch_i in range(1, 1 + args.epochs): |
|||
start_time = time() |
|||
training = adversarial( |
|||
source_cnn, target_cnn, discriminator, |
|||
source_train_loader, target_train_loader, |
|||
criterion, criterion, |
|||
optimizer, d_optimizer, |
|||
args=args |
|||
) |
|||
validation = validate( |
|||
target_cnn, target_test_loader, criterion, args=args) |
|||
validation2 = validate( |
|||
target_cnn, target_train_loader, criterion, args=args) |
|||
log = 'Epoch {}/{} '.format(epoch_i, args.epochs) |
|||
log += 'D/Loss {:.3f} Target/Loss {:.3f} '.format( |
|||
training['d/loss'], training['target/loss']) |
|||
log += '[Val] Target/Loss {:.3f} Target/Acc {:.3f} '.format( |
|||
validation['loss'], validation['acc']) |
|||
log += log_source |
|||
log += '[Train] Target/Loss {:.3f} Target/Acc {:.3f} '.format( |
|||
validation2['loss'], validation2['acc']) |
|||
log += 'Time {:.2f}s'.format(time() - start_time) |
|||
logger.info(log) |
|||
|
|||
# save |
|||
is_best = (best_score is None or validation['acc'] > best_score) |
|||
best_score = validation['acc'] if is_best else best_score |
|||
state_dict = { |
|||
'model': target_cnn.state_dict(), |
|||
'optimizer': optimizer.state_dict(), |
|||
'epoch': epoch_i, |
|||
'val/acc': best_score, |
|||
} |
|||
save(args.logdir, state_dict, is_best) |
|||
|
|||
# tensorboard |
|||
writer.add_scalar('Adv/D/Loss', training['d/loss'], epoch_i) |
|||
writer.add_scalar('Adv/Target/Loss', training['target/loss'], epoch_i) |
|||
writer.add_scalar('Val/Target/Loss', validation['loss'], epoch_i) |
|||
writer.add_scalar('Val/Target/Acc', validation['acc'], epoch_i) |
|||
writer.add_scalar('Train/Target/Acc', validation2['acc'], epoch_i) |
|||
|
|||
|
|||
def adversarial( |
|||
source_cnn, target_cnn, discriminator, |
|||
source_loader, target_loader, |
|||
criterion, d_criterion, |
|||
optimizer, d_optimizer, |
|||
args=None |
|||
): |
|||
source_cnn.eval() |
|||
target_cnn.encoder.train() |
|||
discriminator.train() |
|||
|
|||
losses, d_losses = AverageMeter(), AverageMeter() |
|||
n_iters = min(len(source_loader), len(target_loader)) |
|||
source_iter, target_iter = iter(source_loader), iter(target_loader) |
|||
for iter_i in range(n_iters): |
|||
source_data, source_target = source_iter.next() |
|||
target_data, target_target = target_iter.next() |
|||
source_data = source_data.to(args.device) |
|||
target_data = target_data.to(args.device) |
|||
bs = source_data.size(0) |
|||
|
|||
D_input_source = source_cnn.encoder(source_data) |
|||
D_input_target = target_cnn.encoder(target_data) |
|||
D_target_source = torch.tensor( |
|||
[0] * bs, dtype=torch.long).to(args.device) |
|||
D_target_target = torch.tensor( |
|||
[1] * bs, dtype=torch.long).to(args.device) |
|||
|
|||
# train Discriminator |
|||
D_output_source = discriminator(D_input_source) |
|||
D_output_target = discriminator(D_input_target) |
|||
D_output = torch.cat([D_output_source, D_output_target], dim=0) |
|||
D_target = torch.cat([D_target_source, D_target_target], dim=0) |
|||
d_loss = criterion(D_output, D_target) |
|||
d_optimizer.zero_grad() |
|||
d_loss.backward() |
|||
d_optimizer.step() |
|||
d_losses.update(d_loss.item(), bs) |
|||
|
|||
# train Target |
|||
D_input_target = target_cnn.encoder(target_data) |
|||
D_output_target = discriminator(D_input_target) |
|||
loss = criterion(D_output_target, D_target_source) |
|||
optimizer.zero_grad() |
|||
loss.backward() |
|||
optimizer.step() |
|||
losses.update(loss.item(), bs) |
|||
return {'d/loss': d_losses.avg, 'target/loss': losses.avg} |
|||
|
|||
|
|||
def step(model, data, target, criterion, args): |
|||
data, target = data.to(args.device), target.to(args.device) |
|||
output = model(data) |
|||
loss = criterion(output, target) |
|||
return output, loss |
|||
|
|||
|
|||
def train(model, dataloader, criterion, optimizer, args=None): |
|||
model.train() |
|||
losses = AverageMeter() |
|||
targets, probas = [], [] |
|||
for i, (data, target) in enumerate(dataloader): |
|||
bs = target.size(0) |
|||
output, loss = step(model, data, target, criterion, args) |
|||
output = torch.softmax(output, dim=1) # NOTE |
|||
losses.update(loss.item(), bs) |
|||
|
|||
optimizer.zero_grad() |
|||
loss.backward() |
|||
optimizer.step() |
|||
|
|||
targets.extend(target.cpu().detach().numpy().tolist()) |
|||
probas.extend(output.cpu().detach().numpy().tolist()) |
|||
probas = np.asarray(probas) |
|||
preds = np.argmax(probas, axis=1) |
|||
acc = accuracy_score(targets, preds) |
|||
return { |
|||
'loss': losses.avg, 'acc': acc, |
|||
} |
|||
|
|||
|
|||
def validate(model, dataloader, criterion, args=None): |
|||
model.eval() |
|||
losses = AverageMeter() |
|||
targets, probas = [], [] |
|||
with torch.no_grad(): |
|||
for iter_i, (data, target) in enumerate(dataloader): |
|||
bs = target.size(0) |
|||
output, loss = step(model, data, target, criterion, args) |
|||
output = torch.softmax(output, dim=1) # NOTE: check |
|||
losses.update(loss.item(), bs) |
|||
targets.extend(target.cpu().numpy().tolist()) |
|||
probas.extend(output.cpu().numpy().tolist()) |
|||
probas = np.asarray(probas) |
|||
preds = np.argmax(probas, axis=1) |
|||
acc = accuracy_score(targets, preds) |
|||
return { |
|||
'loss': losses.avg, 'acc': acc, |
|||
} |
@ -0,0 +1,48 @@ |
|||
import os |
|||
import shutil |
|||
import torch |
|||
|
|||
|
|||
def save(log_dir, state_dict, is_best): |
|||
checkpoint_path = os.path.join(log_dir, 'checkpoint.pt') |
|||
torch.save(state_dict, checkpoint_path) |
|||
if is_best: |
|||
best_model_path = os.path.join(log_dir, 'best_model.pt') |
|||
shutil.copyfile(checkpoint_path, best_model_path) |
|||
|
|||
|
|||
def get_logger(log_file): |
|||
from logging import getLogger, FileHandler, StreamHandler, Formatter, DEBUG, INFO # noqa |
|||
fh = FileHandler(log_file) |
|||
fh.setLevel(DEBUG) |
|||
sh = StreamHandler() |
|||
sh.setLevel(INFO) |
|||
for handler in [fh, sh]: |
|||
formatter = Formatter('%(asctime)s - %(message)s') |
|||
handler.setFormatter(formatter) |
|||
logger = getLogger('adda') |
|||
logger.setLevel(INFO) |
|||
logger.addHandler(fh) |
|||
logger.addHandler(sh) |
|||
return logger |
|||
|
|||
|
|||
class AverageMeter(object): |
|||
"""Computes and stores the average and current value |
|||
https://github.com/pytorch/examples/blob/master/imagenet/main.py#L296 |
|||
""" |
|||
|
|||
def __init__(self): |
|||
self.reset() |
|||
|
|||
def reset(self): |
|||
self.val = 0 |
|||
self.avg = 0 |
|||
self.sum = 0 |
|||
self.count = 0 |
|||
|
|||
def update(self, val, n=1): |
|||
self.val = val |
|||
self.sum += val * n |
|||
self.count += n |
|||
self.avg = self.sum / self.count |
Loading…
Reference in new issue