Browse Source

ResNet-50 implementation for office dataset

master
Fazil Altinel 4 years ago
parent
commit
8cb1886af9
  1. 7
      README.md
  2. 48
      core/experiment_rn50.py
  3. 56
      core/train_source_rn50.py
  4. 15
      main.py
  5. 67
      models/resnet50off.py
  6. 24
      utils/altutils.py

7
README.md

@ -11,11 +11,18 @@ Before running the training code, make sure that `DATASETDIR` environment variab
- PyTorch 1.6.0 - PyTorch 1.6.0
## Example ## Example
For training on SVHN-MNIST
``` ```
$ python train_source.py --logdir outputs $ python train_source.py --logdir outputs
$ python main.py --logdir outputs --trained outputs/best_model.pt --slope 0.2 $ python main.py --logdir outputs --trained outputs/best_model.pt --slope 0.2
``` ```
For training on Office dataset using ResNet-50
```
$ python core/train_source_rn50.py --n_classes 31 --logdir outputs
$ python main.py --n_classes 31 --trained outputs/best_model.pt --logdir outputs --model resnet50 --src-cat amazon --tgt-cat webcam
```
## Result ## Result
### SVHN -> MNIST ### SVHN -> MNIST
| | Paper | This Repo | | | Paper | This Repo |

48
core/experiment_rn50.py

@ -0,0 +1,48 @@
import os
import sys
sys.path.append(os.path.abspath('.'))
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import SVHN, MNIST
from torchvision import transforms
from models.resnet50off import CNN, Discriminator
from core.trainer import train_target_cnn
from utils.utils import get_logger
from utils.altutils import get_office
def run(args):
if not os.path.exists(args.logdir):
os.makedirs(args.logdir)
logger = get_logger(os.path.join(args.logdir, 'main.log'))
logger.info(args)
# data loaders
dataset_root = os.environ["DATASETDIR"]
source_loader = get_office(dataset_root, args.batch_size, args.src_cat)
target_loader = get_office(dataset_root, args.batch_size, args.tgt_cat)
# train source CNN
source_cnn = CNN(in_channels=args.in_channels).to(args.device)
if os.path.isfile(args.trained):
c = torch.load(args.trained)
source_cnn.load_state_dict(c['model'])
logger.info('Loaded `{}`'.format(args.trained))
# train target CNN
target_cnn = CNN(in_channels=args.in_channels, target=True).to(args.device)
target_cnn.load_state_dict(source_cnn.state_dict())
discriminator = Discriminator(args=args).to(args.device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
target_cnn.encoder.parameters(),
lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
d_optimizer = optim.Adam(
discriminator.parameters(),
lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
train_target_cnn(
source_cnn, target_cnn, discriminator,
criterion, optimizer, d_optimizer,
source_loader, target_loader, target_loader,
args=args)

56
core/train_source_rn50.py

@ -0,0 +1,56 @@
import argparse
import os
import sys
sys.path.append(os.path.abspath('.'))
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import SVHN
from torchvision import transforms
from models.resnet50off import CNN
from core.trainer import train_source_cnn
from utils.utils import get_logger
from utils.altutils import get_office
def main(args):
if not os.path.exists(args.logdir):
os.makedirs(args.logdir)
logger = get_logger(os.path.join(args.logdir, 'train_source.log'))
logger.info(args)
# data loaders
dataset_root = os.environ["DATASETDIR"]
source_loader = get_office(dataset_root, args.batch_size, args.src_cat)
# train source CNN
source_cnn = CNN(in_channels=args.in_channels).to(args.device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
source_cnn.parameters(),
lr=args.lr, weight_decay=args.weight_decay)
source_cnn = train_source_cnn(
source_cnn, source_loader, source_loader,
criterion, optimizer, args=args)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# NN
parser.add_argument('--in_channels', type=int, default=3)
parser.add_argument('--n_classes', type=int, default=10)
parser.add_argument('--trained', type=str, default='')
parser.add_argument('--slope', type=float, default=0.2)
# train
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--weight_decay', type=float, default=2.5e-5)
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--batch_size', type=int, default=32)
# misc
parser.add_argument('--device', type=str, default='cuda:0')
parser.add_argument('--n_workers', type=int, default=0)
parser.add_argument('--logdir', type=str, default='outputs/garbage')
parser.add_argument('--message', '-m', type=str, default='')
# office dataset categories
parser.add_argument('--src_cat', type=str, default='amazon')
args, unknown = parser.parse_known_args()
main(args)

15
main.py

@ -1,5 +1,4 @@
import argparse import argparse
from core.experiment import run
if __name__ == '__main__': if __name__ == '__main__':
@ -8,17 +7,25 @@ if __name__ == '__main__':
parser.add_argument('--in_channels', type=int, default=3) parser.add_argument('--in_channels', type=int, default=3)
parser.add_argument('--n_classes', type=int, default=10) parser.add_argument('--n_classes', type=int, default=10)
parser.add_argument('--trained', type=str, default='') parser.add_argument('--trained', type=str, default='')
parser.add_argument('--slope', type=float, default=0.1)
parser.add_argument('--slope', type=float, default=0.2)
parser.add_argument('--model', type=str, default='default')
# train # train
parser.add_argument('--lr', type=float, default=2e-4) parser.add_argument('--lr', type=float, default=2e-4)
parser.add_argument('--weight_decay', type=float, default=2.5e-5) parser.add_argument('--weight_decay', type=float, default=2.5e-5)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--betas', type=float, nargs='+', default=(.5, .999)) parser.add_argument('--betas', type=float, nargs='+', default=(.5, .999))
# misc # misc
parser.add_argument('--device', type=str, default='cuda:0') parser.add_argument('--device', type=str, default='cuda:0')
parser.add_argument('--n_workers', type=int, default=0) parser.add_argument('--n_workers', type=int, default=0)
parser.add_argument('--logdir', type=str, default='outputs/garbage') parser.add_argument('--logdir', type=str, default='outputs/garbage')
# office dataset categories
parser.add_argument('--src_cat', type=str, default='amazon')
parser.add_argument('--tgt_cat', type=str, default='webcam')
parser.add_argument('--message', '-m', type=str, default='') parser.add_argument('--message', '-m', type=str, default='')
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
if args.model == 'default':
from core.experiment import run
elif args.model == 'resnet50':
from core.experiment_rn50 import run
run(args) run(args)

67
models/resnet50off.py

@ -0,0 +1,67 @@
from torch import nn
import torch.nn.functional as F
from torchvision import models
class Encoder(nn.Module):
def __init__(self, in_channels=3, h=256, dropout=0.5):
super(Encoder, self).__init__()
resnetModel = models.resnet50(pretrained=True)
feature_map = list(resnetModel.children())
feature_map.pop()
self.feature_extractor = nn.Sequential(*feature_map)
def forward(self, x):
x = x.expand(x.data.shape[0], 3, 227, 227)
x = self.feature_extractor(x)
x = x.view(x.size(0), -1)
return x
class Classifier(nn.Module):
def __init__(self, n_classes, dropout=0.5):
super(Classifier, self).__init__()
self.l1 = nn.Linear(2048, n_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight)
def forward(self, x):
x = self.l1(x)
return x
class CNN(nn.Module):
def __init__(self, in_channels=3, n_classes=31, target=False):
super(CNN, self).__init__()
self.encoder = Encoder(in_channels=in_channels)
self.classifier = Classifier(n_classes)
if target:
for param in self.classifier.parameters():
param.requires_grad = False
def forward(self, x):
x = self.encoder(x)
x = self.classifier(x)
return x
class Discriminator(nn.Module):
def __init__(self, h=500, args=None):
super(Discriminator, self).__init__()
self.l1 = nn.Linear(2048, h)
self.l2 = nn.Linear(h, h)
self.l3 = nn.Linear(h, 2)
self.slope = args.slope
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight)
def forward(self, x):
x = F.leaky_relu(self.l1(x), self.slope)
x = F.leaky_relu(self.l2(x), self.slope)
x = self.l3(x)
return x

24
utils/altutils.py

@ -1,3 +1,7 @@
import os
import torch
from torchvision import datasets, transforms
import torch.utils.data as data
import configparser import configparser
import logging import logging
@ -30,4 +34,22 @@ def setLogger(logFilePath):
logHandler = [logging.FileHandler(logFilePath), logging.StreamHandler()] logHandler = [logging.FileHandler(logFilePath), logging.StreamHandler()]
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s", handlers=logHandler) logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s", handlers=logHandler)
logger = logging.getLogger() logger = logging.getLogger()
return logger
return logger
def get_office(dataset_root, batch_size, category):
"""Get Office datasets loader."""
# image pre-processing
pre_process = transforms.Compose([
transforms.Resize(227),
transforms.ToTensor(),
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])
# datasets and data_loader
office_dataset = datasets.ImageFolder(
os.path.join(dataset_root, 'office31', category, 'images'), transform=pre_process)
office_dataloader = torch.utils.data.DataLoader(
dataset=office_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
return office_dataloader
Loading…
Cancel
Save