diff --git a/experiment.py b/experiment.py
new file mode 100644
index 0000000..6e532d8
--- /dev/null
+++ b/experiment.py
@@ -0,0 +1,85 @@
+import os
+
+import torch
+from torch import nn, optim
+from torch.utils.data import DataLoader
+from torchvision.datasets import SVHN, MNIST
+from torchvision import transforms
+
+from models import CNN, Discriminator
+from trainer import train_source_cnn, train_target_cnn
+
+
+def run(args):
+    if not os.path.exists(args.logdir):
+        os.makedirs(args.logdir)
+
+    # data
+    source_transform = transforms.Compose([
+        # transforms.Grayscale(),
+        transforms.ToTensor(),
+    ])
+    target_transform = transforms.Compose([
+        transforms.Resize(32),
+        transforms.ToTensor(),
+        transforms.Lambda(lambda x: x.repeat(3, 1, 1))
+    ])
+    source_dataset_train = SVHN(
+        './input', 'train', transform=source_transform, download=True)
+    source_dataset_test = SVHN(
+        './input', 'test', transform=source_transform, download=True)
+    # NOTE: MNIST takes a boolean `train` flag, unlike SVHN's split string;
+    # a truthy string such as 'test' would silently load the training set.
+    target_dataset_train = MNIST(
+        './input', train=True, transform=target_transform, download=True)
+    target_dataset_test = MNIST(
+        './input', train=False, transform=target_transform, download=True)
+    source_train_loader = DataLoader(
+        source_dataset_train, args.batch_size, shuffle=True,
+        drop_last=True,
+        num_workers=args.n_workers)
+    source_test_loader = DataLoader(
+        source_dataset_test, args.batch_size, shuffle=False,
+        num_workers=args.n_workers)
+    target_train_loader = DataLoader(
+        target_dataset_train, args.batch_size, shuffle=True,
+        drop_last=True,
+        num_workers=args.n_workers)
+    target_test_loader = DataLoader(
+        target_dataset_test, args.batch_size, shuffle=False,
+        num_workers=args.n_workers)
+
+    # train source CNN
+    source_cnn = CNN(in_channels=args.in_channels).to(args.device)
+    if os.path.isfile(args.trained):
+        c = torch.load(args.trained)
+        source_cnn.load_state_dict(c['model'])
+        print('Loaded `{}`'.format(args.trained))
+    else:
+        criterion = nn.CrossEntropyLoss()
+        optimizer = optim.Adam(
+            source_cnn.parameters(),
+            lr=args.lr, weight_decay=args.weight_decay)
+        source_cnn = train_source_cnn(
+            source_cnn, source_train_loader, source_test_loader,
+            criterion, optimizer, args=args)
+
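+    # ADDA-style adaptation (assumption: this mirrors Tzeng et al., 2017):
+    # the target encoder starts from the source weights, the classifier is
+    # frozen, and only the encoder is trained against the discriminator.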
+    # train target CNN
+    target_cnn = CNN(in_channels=args.in_channels, target=True).to(args.device)
+    target_cnn.load_state_dict(source_cnn.state_dict())
+    discriminator = Discriminator(args=args).to(args.device)
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.RMSprop(  # optim.Adam(
+        target_cnn.encoder.parameters(),
+        lr=args.lr, weight_decay=args.weight_decay)
+    d_optimizer = optim.RMSprop(  # optim.Adam(
+        discriminator.parameters(),
+        lr=args.lr, weight_decay=args.weight_decay)
+    train_target_cnn(
+        source_cnn, target_cnn, discriminator,
+        criterion, optimizer, d_optimizer,
+        source_train_loader, target_train_loader, target_test_loader,
+        args=args)
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..5a7a38f
--- /dev/null
+++ b/main.py
@@ -0,0 +1,24 @@
+import argparse
+import experiment
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    # NN
+    # default 3: SVHN is RGB, and MNIST is repeated to 3 channels (experiment.py)
+    parser.add_argument('--in_channels', type=int, default=3)
+    parser.add_argument('--n_classes', type=int, default=10)
+    parser.add_argument('--trained', type=str, default='')
+    parser.add_argument('--slope', type=float, default=0.2)
+    # train
+    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--weight_decay', type=float, default=0.)
+    parser.add_argument('--epochs', type=int, default=512)
+    parser.add_argument('--batch_size', type=int, default=256)
+    # misc
+    parser.add_argument('--device', type=str, default='cuda:0')
+    parser.add_argument('--n_workers', type=int, default=0)
+    parser.add_argument('--logdir', type=str, default='outputs/garbage')
+    parser.add_argument('--message', '-m', type=str, default='')
+    args, unknown = parser.parse_known_args()
+    experiment.run(args)
diff --git a/models.py b/models.py
new file mode 100644
index 0000000..4defec2
--- /dev/null
+++ b/models.py
@@ -0,0 +1,79 @@
+from torch import nn
+import torch.nn.functional as F
+
+
+class Encoder(nn.Module):
+    def __init__(self, in_channels=1, dropout=0.5):
+        super(Encoder, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels, 8, kernel_size=5, stride=1)
+        self.conv2 = nn.Conv2d(8, 16, kernel_size=5, stride=1)
+        self.conv3 = nn.Conv2d(16, 120, kernel_size=4, stride=1)
+        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.relu = nn.ReLU()
+        self.dropout1 = nn.Dropout2d(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.fc = nn.Linear(480, 500)  # 480 = 120 * 2 * 2 for 32x32 input
+
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.Linear)):
+                nn.init.kaiming_normal_(m.weight)
+
+    def forward(self, x):
+        bs = x.size(0)
+        x = self.pool(self.relu(self.conv1(x)))
+        x = self.pool(self.relu(self.conv2(x)))
+        # x = self.dropout1(self.relu(self.conv3(x)))
+        x = self.relu(self.conv3(x))
+        x = x.view(bs, -1)
+        x = self.dropout2(self.fc(x))
+        return x
+
+
+class Classifier(nn.Module):
+    def __init__(self, n_classes):
+        super(Classifier, self).__init__()
+        self.l1 = nn.Linear(500, n_classes)
+
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.Linear)):
+                nn.init.kaiming_normal_(m.weight)
+
+    def forward(self, x):
+        x = self.l1(x)
+        return x
+
+
+class CNN(nn.Module):
+    def __init__(self, in_channels=1, n_classes=10, target=False):
+        super(CNN, self).__init__()
+        self.encoder = Encoder(in_channels=in_channels)
+        self.classifier = Classifier(n_classes)
+        if target:
+            # the target model only adapts its encoder; the classifier
+            # keeps the source weights
+            for param in self.classifier.parameters():
+                param.requires_grad = False
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.classifier(x)
+        return x
+
+
+class Discriminator(nn.Module):
+    def __init__(self, h=500, args=None):
+        super(Discriminator, self).__init__()
+        self.l1 = nn.Linear(500, h)  # input dim = encoder output dim
+        self.l2 = nn.Linear(h, h)
+        self.l3 = nn.Linear(h, 2)
+        self.slope = args.slope
+
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.Linear)):
+                nn.init.kaiming_normal_(m.weight)
+
+    def forward(self, x):
+        x = F.leaky_relu(self.l1(x), self.slope)
+        x = F.leaky_relu(self.l2(x), self.slope)
+        x = self.l3(x)
+        return x
diff --git a/trainer.py b/trainer.py
new file mode 100644
index 0000000..44c88d5
--- /dev/null
+++ b/trainer.py
@@ -0,0 +1,181 @@
+from time import time
+
+import numpy as np
+from sklearn.metrics import accuracy_score
+import torch
+
+from utils import AverageMeter, save
+
+
+def train_source_cnn(
+        source_cnn, train_loader, test_loader, criterion, optimizer,
+        args=None
+):
+    best_score = None
+    for epoch_i in range(1, 1 + args.epochs):
+        start_time = time()
+        training = train(
+            source_cnn, train_loader, criterion, optimizer, args=args)
+        validation = validate(
+            source_cnn, test_loader, criterion, args=args)
+        log = 'Epoch {}/{} '.format(epoch_i, args.epochs)
+        log += 'Train/Loss {:.3f} Train/Acc {:.3f} '.format(
+            training['loss'], training['acc'])
+        log += 'Val/Loss {:.3f} Val/Acc {:.3f} '.format(
+            validation['loss'], validation['acc'])
+        log += 'Time {:.2f}s'.format(time() - start_time)
+        print(log)
+
+        # save
+        is_best = (best_score is None or validation['acc'] > best_score)
+        best_score = validation['acc'] if is_best else best_score
+        state_dict = {
+            'model': source_cnn.state_dict(),
+            'optimizer': optimizer.state_dict(),
+            'epoch': epoch_i,
+            'val/acc': best_score,
+        }
+        save(args.logdir, state_dict, is_best)
+
+    return source_cnn
+
+
+def train_target_cnn(
+        source_cnn, target_cnn, discriminator,
+        criterion, optimizer, d_optimizer,
+        source_train_loader, target_train_loader, target_test_loader,
+        args=None
+):
+    # baseline: source-only accuracy on the target test set
+    validation = validate(source_cnn, target_test_loader, criterion, args=args)
+    log_source = 'Source/Val/Acc {:.3f} '.format(validation['acc'])
+
+    # best_score = None
+    for epoch_i in range(1, 1 + args.epochs):
+        start_time = time()
+        training = adversarial(
+            source_cnn, target_cnn, discriminator,
+            source_train_loader, target_train_loader,
+            criterion, criterion,
+            optimizer, d_optimizer,
+            args=args
+        )
+        validation = validate(
+            target_cnn, target_test_loader, criterion, args=args)
+        log = 'Epoch {}/{} '.format(epoch_i, args.epochs)
+        log += 'D/Loss {:.3f} Target/Loss {:.3f} '.format(
+            training['d/loss'], training['target/loss'])
+        log += 'Target/Val/Loss {:.3f} Target/Val/Acc {:.3f} '.format(
+            validation['loss'], validation['acc'])
+        log += log_source
+        log += 'Time {:.2f}s'.format(time() - start_time)
+        print(log)
+
+
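+# One adversarial epoch: for each minibatch, update the discriminator to
+# separate source from target features, then update the target encoder to
+# fool it (GAN-style alternation); the source encoder stays frozen.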
+def adversarial(
+        source_cnn, target_cnn, discriminator,
+        source_loader, target_loader,
+        criterion, d_criterion,
+        optimizer, d_optimizer,
+        args=None
+):
+    source_cnn.eval()
+    target_cnn.train()
+    discriminator.train()
+
+    losses, d_losses = AverageMeter(), AverageMeter()
+    n_iters = min(len(source_loader), len(target_loader))
+    source_iter, target_iter = iter(source_loader), iter(target_loader)
+    for iter_i in range(n_iters):
+        # labels are not needed for adversarial adaptation
+        source_data, _ = next(source_iter)
+        target_data, _ = next(target_iter)
+        source_data = source_data.to(args.device)
+        target_data = target_data.to(args.device)
+        bs = source_data.size(0)
+
+        D_input_source = source_cnn.encoder(source_data)
+        D_input_target = target_cnn.encoder(target_data)
+        D_target_source = torch.tensor(
+            [0] * bs, dtype=torch.long).to(args.device)
+        D_target_target = torch.tensor(
+            [1] * bs, dtype=torch.long).to(args.device)
+
+        # train Discriminator
+        D_output_source = discriminator(D_input_source)
+        D_output_target = discriminator(D_input_target)
+        d_loss_source = d_criterion(D_output_source, D_target_source)
+        d_loss_target = d_criterion(D_output_target, D_target_target)
+        d_loss = 0.5 * (d_loss_source + d_loss_target)
+        d_optimizer.zero_grad()
+        d_loss.backward(retain_graph=True)
+        d_optimizer.step()
+        d_losses.update(d_loss.item(), bs)
+
+        # train Target
+        '''
+        D_input_target = target_cnn.encoder(target_data)
+        D_output_target = discriminator(D_input_target)
+        '''
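+        # Inverted-label objective: optimise the encoder so the discriminator
+        # assigns the 'source' label (0) to target features. D_output_target
+        # is reused from the discriminator step, hence retain_graph=True above.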
+        loss = criterion(D_output_target, D_target_source)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        losses.update(loss.item(), bs)
+    return {'d/loss': d_losses.avg, 'target/loss': losses.avg}
+
+
+def step(model, data, target, criterion, args):
+    data, target = data.to(args.device), target.to(args.device)
+    output = model(data)
+    loss = criterion(output, target)
+    return output, loss
+
+
+def train(model, dataloader, criterion, optimizer, args=None):
+    model.train()
+    losses = AverageMeter()
+    targets, probas = [], []
+    for i, (data, target) in enumerate(dataloader):
+        bs = target.size(0)
+        output, loss = step(model, data, target, criterion, args)
+        output = torch.softmax(output, dim=1)  # probabilities for accuracy
+        losses.update(loss.item(), bs)
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        targets.extend(target.cpu().detach().numpy().tolist())
+        probas.extend(output.cpu().detach().numpy().tolist())
+    probas = np.asarray(probas)
+    preds = np.argmax(probas, axis=1)
+    acc = accuracy_score(targets, preds)
+    return {
+        'loss': losses.avg, 'acc': acc,
+    }
+
+
+def validate(model, dataloader, criterion, args=None):
+    model.eval()
+    losses = AverageMeter()
+    targets, probas = [], []
+    with torch.no_grad():
+        for iter_i, (data, target) in enumerate(dataloader):
+            bs = target.size(0)
+            output, loss = step(model, data, target, criterion, args)
+            output = torch.softmax(output, dim=1)  # probabilities for accuracy
+            losses.update(loss.item(), bs)
+            targets.extend(target.cpu().numpy().tolist())
+            probas.extend(output.cpu().numpy().tolist())
+    probas = np.asarray(probas)
+    preds = np.argmax(probas, axis=1)
+    acc = accuracy_score(targets, preds)
+    return {
+        'loss': losses.avg, 'acc': acc,
+    }
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..6b95707
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,36 @@
+import os
+import shutil
+import torch
+
+
+def save(log_dir, state_dict, is_best):
+    checkpoint_path = os.path.join(log_dir, 'checkpoint.pt')
+    torch.save(state_dict, checkpoint_path)
+    if is_best:
+        best_model_path = os.path.join(log_dir, 'best_model.pt')
+        shutil.copyfile(checkpoint_path, best_model_path)
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value
+    https://github.com/pytorch/examples/blob/master/imagenet/main.py#L296
+    """
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
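+
+
+# Usage sketch: meter = AverageMeter(); meter.update(loss.item(), n=bs) after
+# each batch; meter.avg then holds the sample-weighted running mean loss.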