diff --git a/README.md b/README.md
index f6ea012..b6aa876 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,23 @@
 # ADDA.PyTorch
 An implementation of Adversarial Discriminative Domain Adaptation (ADDA) in PyTorch.
+
+
+## Example
+```
+$ python train_source.py --logdir outputs
+$ python main.py --logdir outputs --trained outputs/best_model.pt --slope 0.2
+```
+
+
+## Result
+### SVHN -> MNIST
+|             | Paper | This Repo |
+| ----------- | ----- | --------- |
+| Source only | 0.601 | 0.659     |
+| ADDA        | 0.760 | ~0.83     |
+
+![adversarial](adversarial.png)
+![target_domain](target_domain.png)
+
+## Resources
+- https://arxiv.org/pdf/1702.05464.pdf
diff --git a/adversarial.png b/adversarial.png
new file mode 100644
index 0000000..279c0f8
Binary files /dev/null and b/adversarial.png differ
diff --git a/experiment.py b/experiment.py
new file mode 100644
index 0000000..a4ecff9
--- /dev/null
+++ b/experiment.py
@@ -0,0 +1,70 @@
+import os
+
+import torch
+from torch import nn, optim
+from torch.utils.data import DataLoader
+from torchvision.datasets import SVHN, MNIST
+from torchvision import transforms
+
+from models import CNN, Discriminator
+from trainer import train_target_cnn
+from utils import get_logger
+
+
+def run(args):
+    if not os.path.exists(args.logdir):
+        os.makedirs(args.logdir)
+    logger = get_logger(os.path.join(args.logdir, 'main.log'))
+    logger.info(args)
+
+    # data
+    source_transform = transforms.Compose([
+        # transforms.Grayscale(),
+        transforms.ToTensor()]
+    )
+    target_transform = transforms.Compose([
+        transforms.Resize(32),
+        transforms.ToTensor(),
+        transforms.Lambda(lambda x: x.repeat(3, 1, 1))
+    ])
+    source_dataset_train = SVHN(
+        './input', 'train', transform=source_transform, download=True)
+    target_dataset_train = MNIST(
+        './input', train=True, transform=target_transform, download=True)
+    target_dataset_test = MNIST(
+        './input', train=False, transform=target_transform, download=True)
+    source_train_loader = DataLoader(
+        source_dataset_train, args.batch_size, shuffle=True,
+        drop_last=True,
+        num_workers=args.n_workers)
+    target_train_loader = DataLoader(
+        target_dataset_train, args.batch_size, shuffle=True,
+        drop_last=True,
+        num_workers=args.n_workers)
+    target_test_loader = DataLoader(
+        target_dataset_test, args.batch_size, shuffle=False,
+        num_workers=args.n_workers)
+
+    # load the pre-trained source CNN
+    source_cnn = CNN(in_channels=args.in_channels).to(args.device)
+    if os.path.isfile(args.trained):
+        c = torch.load(args.trained)
+        source_cnn.load_state_dict(c['model'])
+        logger.info('Loaded `{}`'.format(args.trained))
+
+    # train target CNN
+    target_cnn = CNN(in_channels=args.in_channels, target=True).to(args.device)
+    target_cnn.load_state_dict(source_cnn.state_dict())
+    discriminator = Discriminator(args=args).to(args.device)
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.Adam(
+        target_cnn.encoder.parameters(),
+        lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
+    d_optimizer = optim.Adam(
+        discriminator.parameters(),
+        lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
+    train_target_cnn(
+        source_cnn, target_cnn, discriminator,
+        criterion, optimizer, d_optimizer,
+        source_train_loader, target_train_loader, target_test_loader,
+        args=args)
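A note on the target transform in `experiment.py`: MNIST digits are 1x28x28 while the SVHN-trained encoder expects 3x32x32 inputs, so the transform resizes to 32 and repeats the gray channel three times. A minimal shape check (a sketch, not repository code; the blank dummy image stands in for an MNIST sample):

```python
from PIL import Image
import numpy as np
from torchvision import transforms

# Same composition as target_transform in experiment.py.
target_transform = transforms.Compose([
    transforms.Resize(32),                           # 28x28 -> 32x32
    transforms.ToTensor(),                           # PIL 'L' -> 1x32x32 float tensor
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # gray -> 3 channels
])

dummy_digit = Image.fromarray(np.zeros((28, 28), dtype=np.uint8))  # stand-in for an MNIST sample
x = target_transform(dummy_digit)
print(x.shape)  # torch.Size([3, 32, 32]) -- matches the SVHN input shape
```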
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..c47af4f
--- /dev/null
+++ b/main.py
@@ -0,0 +1,24 @@
+import argparse
+import experiment
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    # NN
+    parser.add_argument('--in_channels', type=int, default=3)
+    parser.add_argument('--n_classes', type=int, default=10)
+    parser.add_argument('--trained', type=str, default='')
+    parser.add_argument('--slope', type=float, default=0.1)
+    # train
+    parser.add_argument('--lr', type=float, default=2e-4)
+    parser.add_argument('--weight_decay', type=float, default=2.5e-5)
+    parser.add_argument('--epochs', type=int, default=512)
+    parser.add_argument('--batch_size', type=int, default=128)
+    parser.add_argument('--betas', type=float, nargs='+', default=(.5, .999))
+    # misc
+    parser.add_argument('--device', type=str, default='cuda:0')
+    parser.add_argument('--n_workers', type=int, default=0)
+    parser.add_argument('--logdir', type=str, default='outputs/garbage')
+    parser.add_argument('--message', '-m', type=str, default='')
+    args, unknown = parser.parse_known_args()
+    experiment.run(args)
diff --git a/models.py b/models.py
new file mode 100644
index 0000000..e6a8fa5
--- /dev/null
+++ b/models.py
@@ -0,0 +1,80 @@
+from torch import nn
+import torch.nn.functional as F
+
+
+class Encoder(nn.Module):
+    def __init__(self, in_channels=1, h=256, dropout=0.5):
+        super(Encoder, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels, 20, kernel_size=5, stride=1)
+        self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1)
+        self.bn1 = nn.BatchNorm2d(20)
+        self.bn2 = nn.BatchNorm2d(50)
+        # self.conv3 = nn.Conv2d(16, 120, kernel_size=4, stride=1)
+        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.relu = nn.ReLU()
+        # self.dropout1 = nn.Dropout2d(dropout)
+        self.dropout = nn.Dropout(dropout)
+        self.fc = nn.Linear(1250, 500)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+                nn.init.kaiming_normal_(m.weight)
+
+    def forward(self, x):
+        bs = x.size(0)
+        x = self.pool(self.relu(self.bn1(self.conv1(x))))
+        x = self.pool(self.relu(self.bn2(self.conv2(x))))
+        # x = self.dropout1(self.relu(self.conv3(x)))
+        # x = self.relu(self.conv3(x))
+        x = x.view(bs, -1)
+        x = self.dropout(x)
+        x = self.fc(x)
+        return x
+
+
+class Classifier(nn.Module):
+    def __init__(self, n_classes, dropout=0.5):
+        super(Classifier, self).__init__()
+        self.l1 = nn.Linear(500, n_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+                nn.init.kaiming_normal_(m.weight)
+
+    def forward(self, x):
+        x = self.l1(x)
+        return x
+
+
+class CNN(nn.Module):
+    def __init__(self, in_channels=1, n_classes=10, target=False):
+        super(CNN, self).__init__()
+        self.encoder = Encoder(in_channels=in_channels)
+        self.classifier = Classifier(n_classes)
+        if target:
+            for param in self.classifier.parameters():
+                param.requires_grad = False
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.classifier(x)
+        return x
+
+
+class Discriminator(nn.Module):
+    def __init__(self, h=500, args=None):
+        super(Discriminator, self).__init__()
+        self.l1 = nn.Linear(500, h)
+        self.l2 = nn.Linear(h, h)
+        self.l3 = nn.Linear(h, 2)
+        self.slope = args.slope
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+                nn.init.kaiming_normal_(m.weight)
+
+    def forward(self, x):
+        x = F.leaky_relu(self.l1(x), self.slope)
+        x = F.leaky_relu(self.l2(x), self.slope)
+        x = self.l3(x)
+        return x
diff --git a/target_domain.png b/target_domain.png
new file mode 100644
index 0000000..fb7705d
Binary files /dev/null and b/target_domain.png differ
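The magic number in `nn.Linear(1250, 500)` follows from the 32x32 input: conv1 (k=5) gives 28x28, pooling 14x14, conv2 (k=5) 10x10, pooling 5x5, and 50 * 5 * 5 = 1250. A quick shape check (a sketch, not repository code; `Args` is a hypothetical stand-in for the argparse namespace):

```python
import torch
from models import CNN, Discriminator

class Args:
    slope = 0.2  # hypothetical stand-in for the parsed CLI arguments

cnn = CNN(in_channels=3, n_classes=10)
disc = Discriminator(args=Args())
x = torch.randn(4, 3, 32, 32)   # dummy SVHN-shaped batch
feats = cnn.encoder(x)
print(feats.shape)              # torch.Size([4, 500]) -- 500-d embedding
print(cnn(x).shape)             # torch.Size([4, 10])  -- class logits
print(disc(feats).shape)        # torch.Size([4, 2])   -- source/target logits
```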
diff --git a/train_source.py b/train_source.py
new file mode 100644
index 0000000..0481722
--- /dev/null
+++ b/train_source.py
@@ -0,0 +1,65 @@
+import argparse
+import os
+
+from torch import nn, optim
+from torch.utils.data import DataLoader
+from torchvision.datasets import SVHN
+from torchvision import transforms
+
+from models import CNN
+from trainer import train_source_cnn
+from utils import get_logger
+
+
+def main(args):
+    if not os.path.exists(args.logdir):
+        os.makedirs(args.logdir)
+    logger = get_logger(os.path.join(args.logdir, 'train_source.log'))
+    logger.info(args)
+
+    # data
+    source_transform = transforms.Compose([
+        transforms.ToTensor()]
+    )
+    source_dataset_train = SVHN(
+        './input', 'train', transform=source_transform, download=True)
+    source_dataset_test = SVHN(
+        './input', 'test', transform=source_transform, download=True)
+    source_train_loader = DataLoader(
+        source_dataset_train, args.batch_size, shuffle=True,
+        drop_last=True,
+        num_workers=args.n_workers)
+    source_test_loader = DataLoader(
+        source_dataset_test, args.batch_size, shuffle=False,
+        num_workers=args.n_workers)
+
+    # train source CNN
+    source_cnn = CNN(in_channels=args.in_channels).to(args.device)
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.Adam(
+        source_cnn.parameters(),
+        lr=args.lr, weight_decay=args.weight_decay)
+    source_cnn = train_source_cnn(
+        source_cnn, source_train_loader, source_test_loader,
+        criterion, optimizer, args=args)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    # NN
+    parser.add_argument('--in_channels', type=int, default=3)
+    parser.add_argument('--n_classes', type=int, default=10)
+    parser.add_argument('--trained', type=str, default='')
+    parser.add_argument('--slope', type=float, default=0.2)
+    # train
+    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--weight_decay', type=float, default=2.5e-5)
+    parser.add_argument('--epochs', type=int, default=50)
+    parser.add_argument('--batch_size', type=int, default=128)
+    # misc
+    parser.add_argument('--device', type=str, default='cuda:0')
+    parser.add_argument('--n_workers', type=int, default=0)
+    parser.add_argument('--logdir', type=str, default='outputs/garbage')
+    parser.add_argument('--message', '-m', type=str, default='')
+    args, unknown = parser.parse_known_args()
+    main(args)
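`train_source.py` writes checkpoints through `utils.save()`, and `main.py`'s `--trained` flag points at the resulting `best_model.pt`; its `'model'` entry seeds both the source CNN and, by copy, the target encoder. A sketch of loading it back for inspection, assuming the `outputs` logdir from the README example:

```python
import torch
from models import CNN

# Checkpoint keys follow the state_dict built in trainer.train_source_cnn.
checkpoint = torch.load('outputs/best_model.pt', map_location='cpu')
source_cnn = CNN(in_channels=3)
source_cnn.load_state_dict(checkpoint['model'])
print('epoch:', checkpoint['epoch'], 'val/acc:', checkpoint['val/acc'])
```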
diff --git a/trainer.py b/trainer.py
new file mode 100644
index 0000000..309de24
--- /dev/null
+++ b/trainer.py
@@ -0,0 +1,200 @@
+from logging import getLogger
+from time import time
+
+import numpy as np
+from sklearn.metrics import accuracy_score
+from tensorboardX import SummaryWriter
+import torch
+
+from utils import AverageMeter, save
+
+
+logger = getLogger('adda.trainer')
+
+
+def train_source_cnn(
+    source_cnn, train_loader, test_loader, criterion, optimizer,
+    args=None
+):
+    best_score = None
+    for epoch_i in range(1, 1 + args.epochs):
+        start_time = time()
+        training = train(
+            source_cnn, train_loader, criterion, optimizer, args=args)
+        validation = validate(
+            source_cnn, test_loader, criterion, args=args)
+        log = 'Epoch {}/{} '.format(epoch_i, args.epochs)
+        log += '| Train/Loss {:.3f} Acc {:.3f} '.format(
+            training['loss'], training['acc'])
+        log += '| Val/Loss {:.3f} Acc {:.3f} '.format(
+            validation['loss'], validation['acc'])
+        log += 'Time {:.2f}s'.format(time() - start_time)
+        logger.info(log)
+
+        # save
+        is_best = (best_score is None or validation['acc'] > best_score)
+        best_score = validation['acc'] if is_best else best_score
+        state_dict = {
+            'model': source_cnn.state_dict(),
+            'optimizer': optimizer.state_dict(),
+            'epoch': epoch_i,
+            'val/acc': best_score,
+        }
+        save(args.logdir, state_dict, is_best)
+
+    return source_cnn
+
+
+def train_target_cnn(
+    source_cnn, target_cnn, discriminator,
+    criterion, optimizer, d_optimizer,
+    source_train_loader, target_train_loader, target_test_loader,
+    args=None
+):
+    validation = validate(source_cnn, target_test_loader, criterion, args=args)
+    log_source = 'Source/Acc {:.3f} '.format(validation['acc'])
+
+    writer = SummaryWriter(args.logdir)
+    best_score = None
+    for epoch_i in range(1, 1 + args.epochs):
+        start_time = time()
+        training = adversarial(
+            source_cnn, target_cnn, discriminator,
+            source_train_loader, target_train_loader,
+            criterion, criterion,
+            optimizer, d_optimizer,
+            args=args
+        )
+        validation = validate(
+            target_cnn, target_test_loader, criterion, args=args)
+        validation2 = validate(
+            target_cnn, target_train_loader, criterion, args=args)
+        log = 'Epoch {}/{} '.format(epoch_i, args.epochs)
+        log += 'D/Loss {:.3f} Target/Loss {:.3f} '.format(
+            training['d/loss'], training['target/loss'])
+        log += '[Val] Target/Loss {:.3f} Target/Acc {:.3f} '.format(
+            validation['loss'], validation['acc'])
+        log += log_source
+        log += '[Train] Target/Loss {:.3f} Target/Acc {:.3f} '.format(
+            validation2['loss'], validation2['acc'])
+        log += 'Time {:.2f}s'.format(time() - start_time)
+        logger.info(log)
+
+        # save
+        is_best = (best_score is None or validation['acc'] > best_score)
+        best_score = validation['acc'] if is_best else best_score
+        state_dict = {
+            'model': target_cnn.state_dict(),
+            'optimizer': optimizer.state_dict(),
+            'epoch': epoch_i,
+            'val/acc': best_score,
+        }
+        save(args.logdir, state_dict, is_best)
+
+        # tensorboard
+        writer.add_scalar('Adv/D/Loss', training['d/loss'], epoch_i)
+        writer.add_scalar('Adv/Target/Loss', training['target/loss'], epoch_i)
+        writer.add_scalar('Val/Target/Loss', validation['loss'], epoch_i)
+        writer.add_scalar('Val/Target/Acc', validation['acc'], epoch_i)
+        writer.add_scalar('Train/Target/Acc', validation2['acc'], epoch_i)
+
+
+def adversarial(
+    source_cnn, target_cnn, discriminator,
+    source_loader, target_loader,
+    criterion, d_criterion,
+    optimizer, d_optimizer,
+    args=None
+):
+    source_cnn.eval()
+    target_cnn.encoder.train()
+    discriminator.train()
+
+    losses, d_losses = AverageMeter(), AverageMeter()
+    n_iters = min(len(source_loader), len(target_loader))
+    source_iter, target_iter = iter(source_loader), iter(target_loader)
+    for iter_i in range(n_iters):
+        source_data, source_target = next(source_iter)
+        target_data, target_target = next(target_iter)
+        source_data = source_data.to(args.device)
+        target_data = target_data.to(args.device)
+        bs = source_data.size(0)
+
+        D_input_source = source_cnn.encoder(source_data)
+        D_input_target = target_cnn.encoder(target_data)
+        D_target_source = torch.tensor(
+            [0] * bs, dtype=torch.long).to(args.device)
+        D_target_target = torch.tensor(
+            [1] * bs, dtype=torch.long).to(args.device)
+
+        # train Discriminator (target features detached so only D is updated)
+        D_output_source = discriminator(D_input_source)
+        D_output_target = discriminator(D_input_target.detach())
+        D_output = torch.cat([D_output_source, D_output_target], dim=0)
+        D_target = torch.cat([D_target_source, D_target_target], dim=0)
+        d_loss = d_criterion(D_output, D_target)
+        d_optimizer.zero_grad()
+        d_loss.backward()
+        d_optimizer.step()
+        d_losses.update(d_loss.item(), bs)
+
+        # train Target encoder against the flipped domain label (source=0)
+        D_input_target = target_cnn.encoder(target_data)
+        D_output_target = discriminator(D_input_target)
+        loss = criterion(D_output_target, D_target_source)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        losses.update(loss.item(), bs)
+    return {'d/loss': d_losses.avg, 'target/loss': losses.avg}
+
+
+def step(model, data, target, criterion, args):
+    data, target = data.to(args.device), target.to(args.device)
+    output = model(data)
+    loss = criterion(output, target)
+    return output, loss
+
+
+def train(model, dataloader, criterion, optimizer, args=None):
+    model.train()
+    losses = AverageMeter()
+    targets, probas = [], []
+    for i, (data, target) in enumerate(dataloader):
+        bs = target.size(0)
+        output, loss = step(model, data, target, criterion, args)
+        output = torch.softmax(output, dim=1)  # logits -> probabilities for accuracy
+        losses.update(loss.item(), bs)
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        targets.extend(target.cpu().detach().numpy().tolist())
+        probas.extend(output.cpu().detach().numpy().tolist())
+    probas = np.asarray(probas)
+    preds = np.argmax(probas, axis=1)
+    acc = accuracy_score(targets, preds)
+    return {
+        'loss': losses.avg, 'acc': acc,
+    }
+
+
+def validate(model, dataloader, criterion, args=None):
+    model.eval()
+    losses = AverageMeter()
+    targets, probas = [], []
+    with torch.no_grad():
+        for iter_i, (data, target) in enumerate(dataloader):
+            bs = target.size(0)
+            output, loss = step(model, data, target, criterion, args)
+            output = torch.softmax(output, dim=1)  # logits -> probabilities for accuracy
+            losses.update(loss.item(), bs)
+            targets.extend(target.cpu().numpy().tolist())
+            probas.extend(output.cpu().numpy().tolist())
+    probas = np.asarray(probas)
+    preds = np.argmax(probas, axis=1)
+    acc = accuracy_score(targets, preds)
+    return {
+        'loss': losses.avg, 'acc': acc,
+    }
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..ae51fc9
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,48 @@
+import os
+import shutil
+import torch
+
+
+def save(log_dir, state_dict, is_best):
+    checkpoint_path = os.path.join(log_dir, 'checkpoint.pt')
+    torch.save(state_dict, checkpoint_path)
+    if is_best:
+        best_model_path = os.path.join(log_dir, 'best_model.pt')
+        shutil.copyfile(checkpoint_path, best_model_path)
+
+
+def get_logger(log_file):
+    from logging import getLogger, FileHandler, StreamHandler, Formatter, DEBUG, INFO  # noqa
+    fh = FileHandler(log_file)
+    fh.setLevel(DEBUG)
+    sh = StreamHandler()
+    sh.setLevel(INFO)
+    for handler in [fh, sh]:
+        formatter = Formatter('%(asctime)s - %(message)s')
+        handler.setFormatter(formatter)
+    logger = getLogger('adda')
+    logger.setLevel(INFO)
+    logger.addHandler(fh)
+    logger.addHandler(sh)
+    return logger
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value
+    https://github.com/pytorch/examples/blob/master/imagenet/main.py#L296
+    """
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
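For reference, the adversarial loop in `trainer.adversarial()` reduces to the standard ADDA two-step: the discriminator learns source = 0 vs. target = 1, then the target encoder is optimized against the flipped label so its features become indistinguishable from source features. A distilled sketch with random stand-in features (assumptions: 500-d features as produced by `Encoder`, and a generic two-layer net in place of `models.Discriminator`):

```python
import torch
from torch import nn

bs, feat_dim = 8, 500  # feature size matches Encoder's 500-d output
disc = nn.Sequential(nn.Linear(feat_dim, 500), nn.LeakyReLU(0.2), nn.Linear(500, 2))
criterion = nn.CrossEntropyLoss()

src_feats = torch.randn(bs, feat_dim)                        # stands in for source_cnn.encoder output
tgt_feats = torch.randn(bs, feat_dim, requires_grad=True)    # stands in for target_cnn.encoder output
zeros, ones = torch.zeros(bs).long(), torch.ones(bs).long()  # domain labels: source=0, target=1

# 1) discriminator step: classify the true domain labels (target side detached)
d_loss = criterion(torch.cat([disc(src_feats), disc(tgt_feats.detach())]),
                   torch.cat([zeros, ones]))

# 2) encoder step: the flipped label (0) pushes target features toward the source domain
g_loss = criterion(disc(tgt_feats), zeros)
print(d_loss.item(), g_loss.item())
```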