From 8cb1886af9efdcb88324b8f236f9f17d8f2e66a3 Mon Sep 17 00:00:00 2001 From: fazilaltinel Date: Fri, 11 Dec 2020 15:53:36 +0300 Subject: [PATCH] ResNet-50 implementation for office dataset --- README.md | 7 ++++ core/experiment_rn50.py | 48 ++++++++++++++++++++++++++++ core/train_source_rn50.py | 56 ++++++++++++++++++++++++++++++++ main.py | 15 ++++++--- models/resnet50off.py | 67 +++++++++++++++++++++++++++++++++++++++ utils/altutils.py | 24 +++++++++++++- 6 files changed, 212 insertions(+), 5 deletions(-) create mode 100644 core/experiment_rn50.py create mode 100644 core/train_source_rn50.py create mode 100644 models/resnet50off.py diff --git a/README.md b/README.md index 00ace4e..49d682c 100644 --- a/README.md +++ b/README.md @@ -11,11 +11,18 @@ Before running the training code, make sure that `DATASETDIR` environment variab - PyTorch 1.6.0 ## Example +For training on SVHN-MNIST ``` $ python train_source.py --logdir outputs $ python main.py --logdir outputs --trained outputs/best_model.pt --slope 0.2 ``` +For training on Office dataset using ResNet-50 +``` +$ python core/train_source_rn50.py --n_classes 31 --src_cat amazon --logdir outputs +$ python main.py --n_classes 31 --trained outputs/best_model.pt --logdir outputs --model resnet50 --src_cat amazon --tgt_cat webcam +``` + ## Result ### SVHN -> MNIST | | Paper | This Repo | diff --git a/core/experiment_rn50.py b/core/experiment_rn50.py new file mode 100644 index 0000000..c4a57cd --- /dev/null +++ b/core/experiment_rn50.py @@ -0,0 +1,48 @@ +import os +import sys +sys.path.append(os.path.abspath('.')) +import torch +from torch import nn, optim +from torch.utils.data import DataLoader +from torchvision.datasets import SVHN, MNIST +from torchvision import transforms +from models.resnet50off import CNN, Discriminator +from core.trainer import train_target_cnn +from utils.utils import get_logger +from utils.altutils import get_office + + +def run(args): + if not os.path.exists(args.logdir): + os.makedirs(args.logdir) + logger = 
get_logger(os.path.join(args.logdir, 'main.log')) + logger.info(args) + + # data loaders + dataset_root = os.environ["DATASETDIR"] + source_loader = get_office(dataset_root, args.batch_size, args.src_cat) + target_loader = get_office(dataset_root, args.batch_size, args.tgt_cat) + + # train source CNN + source_cnn = CNN(in_channels=args.in_channels).to(args.device) + if os.path.isfile(args.trained): + c = torch.load(args.trained) + source_cnn.load_state_dict(c['model']) + logger.info('Loaded `{}`'.format(args.trained)) + + # train target CNN + target_cnn = CNN(in_channels=args.in_channels, target=True).to(args.device) + target_cnn.load_state_dict(source_cnn.state_dict()) + discriminator = Discriminator(args=args).to(args.device) + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam( + target_cnn.encoder.parameters(), + lr=args.lr, betas=args.betas, weight_decay=args.weight_decay) + d_optimizer = optim.Adam( + discriminator.parameters(), + lr=args.lr, betas=args.betas, weight_decay=args.weight_decay) + train_target_cnn( + source_cnn, target_cnn, discriminator, + criterion, optimizer, d_optimizer, + source_loader, target_loader, target_loader, + args=args) diff --git a/core/train_source_rn50.py b/core/train_source_rn50.py new file mode 100644 index 0000000..4a21077 --- /dev/null +++ b/core/train_source_rn50.py @@ -0,0 +1,56 @@ +import argparse +import os +import sys +sys.path.append(os.path.abspath('.')) +from torch import nn, optim +from torch.utils.data import DataLoader +from torchvision.datasets import SVHN +from torchvision import transforms +from models.resnet50off import CNN +from core.trainer import train_source_cnn +from utils.utils import get_logger +from utils.altutils import get_office + + +def main(args): + if not os.path.exists(args.logdir): + os.makedirs(args.logdir) + logger = get_logger(os.path.join(args.logdir, 'train_source.log')) + logger.info(args) + + # data loaders + dataset_root = os.environ["DATASETDIR"] + source_loader = 
get_office(dataset_root, args.batch_size, args.src_cat) + + # train source CNN + source_cnn = CNN(in_channels=args.in_channels).to(args.device) + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam( + source_cnn.parameters(), + lr=args.lr, weight_decay=args.weight_decay) + source_cnn = train_source_cnn( + source_cnn, source_loader, source_loader, + criterion, optimizer, args=args) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + # NN + parser.add_argument('--in_channels', type=int, default=3) + parser.add_argument('--n_classes', type=int, default=10) + parser.add_argument('--trained', type=str, default='') + parser.add_argument('--slope', type=float, default=0.2) + # train + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--weight_decay', type=float, default=2.5e-5) + parser.add_argument('--epochs', type=int, default=50) + parser.add_argument('--batch_size', type=int, default=32) + # misc + parser.add_argument('--device', type=str, default='cuda:0') + parser.add_argument('--n_workers', type=int, default=0) + parser.add_argument('--logdir', type=str, default='outputs/garbage') + parser.add_argument('--message', '-m', type=str, default='') + # office dataset categories + parser.add_argument('--src_cat', type=str, default='amazon') + args, unknown = parser.parse_known_args() + main(args) diff --git a/main.py b/main.py index 19c7715..f2ec273 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,4 @@ import argparse -from core.experiment import run if __name__ == '__main__': @@ -8,17 +7,25 @@ if __name__ == '__main__': parser.add_argument('--in_channels', type=int, default=3) parser.add_argument('--n_classes', type=int, default=10) parser.add_argument('--trained', type=str, default='') - parser.add_argument('--slope', type=float, default=0.1) + parser.add_argument('--slope', type=float, default=0.2) + parser.add_argument('--model', type=str, default='default') # train parser.add_argument('--lr', type=float, default=2e-4) 
parser.add_argument('--weight_decay', type=float, default=2.5e-5) - parser.add_argument('--epochs', type=int, default=100) - parser.add_argument('--batch_size', type=int, default=128) + parser.add_argument('--epochs', type=int, default=500) + parser.add_argument('--batch_size', type=int, default=32) parser.add_argument('--betas', type=float, nargs='+', default=(.5, .999)) # misc parser.add_argument('--device', type=str, default='cuda:0') parser.add_argument('--n_workers', type=int, default=0) parser.add_argument('--logdir', type=str, default='outputs/garbage') + # office dataset categories + parser.add_argument('--src_cat', type=str, default='amazon') + parser.add_argument('--tgt_cat', type=str, default='webcam') parser.add_argument('--message', '-m', type=str, default='') args, unknown = parser.parse_known_args() + if args.model == 'default': + from core.experiment import run + elif args.model == 'resnet50': + from core.experiment_rn50 import run run(args) diff --git a/models/resnet50off.py b/models/resnet50off.py new file mode 100644 index 0000000..aac8d84 --- /dev/null +++ b/models/resnet50off.py @@ -0,0 +1,67 @@ +from torch import nn +import torch.nn.functional as F +from torchvision import models + + +class Encoder(nn.Module): + def __init__(self, in_channels=3, h=256, dropout=0.5): + super(Encoder, self).__init__() + + resnetModel = models.resnet50(pretrained=True) + feature_map = list(resnetModel.children()) + feature_map.pop() + self.feature_extractor = nn.Sequential(*feature_map) + + def forward(self, x): + x = x.expand(x.data.shape[0], 3, 227, 227) + x = self.feature_extractor(x) + x = x.view(x.size(0), -1) + return x + + +class Classifier(nn.Module): + def __init__(self, n_classes, dropout=0.5): + super(Classifier, self).__init__() + self.l1 = nn.Linear(2048, n_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight) + + def forward(self, x): + x = self.l1(x) + return x + + 
+class CNN(nn.Module): + def __init__(self, in_channels=3, n_classes=31, target=False): + super(CNN, self).__init__() + self.encoder = Encoder(in_channels=in_channels) + self.classifier = Classifier(n_classes) + if target: + for param in self.classifier.parameters(): + param.requires_grad = False + + def forward(self, x): + x = self.encoder(x) + x = self.classifier(x) + return x + + +class Discriminator(nn.Module): + def __init__(self, h=500, args=None): + super(Discriminator, self).__init__() + self.l1 = nn.Linear(2048, h) + self.l2 = nn.Linear(h, h) + self.l3 = nn.Linear(h, 2) + self.slope = args.slope + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight) + + def forward(self, x): + x = F.leaky_relu(self.l1(x), self.slope) + x = F.leaky_relu(self.l2(x), self.slope) + x = self.l3(x) + return x diff --git a/utils/altutils.py b/utils/altutils.py index 1c22525..ded6777 100644 --- a/utils/altutils.py +++ b/utils/altutils.py @@ -1,3 +1,7 @@ +import os +import torch +from torchvision import datasets, transforms +import torch.utils.data as data import configparser import logging @@ -30,4 +34,22 @@ def setLogger(logFilePath): logHandler = [logging.FileHandler(logFilePath), logging.StreamHandler()] logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s", handlers=logHandler) logger = logging.getLogger() - return logger \ No newline at end of file + return logger + +def get_office(dataset_root, batch_size, category): + """Get Office datasets loader.""" + # image pre-processing + pre_process = transforms.Compose([ + transforms.Resize(227), + transforms.ToTensor(), + transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) + ]) + + # datasets and data_loader + office_dataset = datasets.ImageFolder( + os.path.join(dataset_root, 'office31', category, 'images'), transform=pre_process) + + office_dataloader = torch.utils.data.DataLoader( + dataset=office_dataset, 
batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True) + + return office_dataloader \ No newline at end of file