Fazil Altinel
4 years ago
6 changed files with 212 additions and 5 deletions
@@ -0,0 +1,48 @@
import os
import sys
sys.path.append(os.path.abspath('.'))

import torch
from torch import nn, optim

from models.resnet50off import CNN, Discriminator
from core.trainer import train_target_cnn
from utils.utils import get_logger
from utils.altutils import get_office


def run(args):
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)
    logger = get_logger(os.path.join(args.logdir, 'main.log'))
    logger.info(args)

    # data loaders for the source and target Office-31 domains
    dataset_root = os.environ["DATASETDIR"]
    source_loader = get_office(dataset_root, args.batch_size, args.src_cat)
    target_loader = get_office(dataset_root, args.batch_size, args.tgt_cat)

    # source CNN: load pre-trained weights if a checkpoint is given
    source_cnn = CNN(in_channels=args.in_channels).to(args.device)
    if os.path.isfile(args.trained):
        c = torch.load(args.trained)
        source_cnn.load_state_dict(c['model'])
        logger.info('Loaded `{}`'.format(args.trained))

    # target CNN: initialized from the source weights, then adapted adversarially;
    # only the target encoder is optimized, against the domain discriminator
    target_cnn = CNN(in_channels=args.in_channels, target=True).to(args.device)
    target_cnn.load_state_dict(source_cnn.state_dict())
    discriminator = Discriminator(args=args).to(args.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        target_cnn.encoder.parameters(),
        lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
    d_optimizer = optim.Adam(
        discriminator.parameters(),
        lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
    train_target_cnn(
        source_cnn, target_cnn, discriminator,
        criterion, optimizer, d_optimizer,
        source_loader, target_loader, target_loader,
        args=args)
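core/trainer.py is not part of this diff, but the call above implies the usual ADDA adversarial loop: a discriminator step that separates source from target features, then an encoder step that tries to fool it. A minimal sketch under that assumption (train_target_cnn_sketch and everything inside it are illustrative, not the repo's actual trainer, and the evaluation loader is dropped for brevity):

import torch


def train_target_cnn_sketch(source_cnn, target_cnn, discriminator,
                            criterion, optimizer, d_optimizer,
                            source_loader, target_loader, args=None):
    # Hypothetical ADDA-style loop; core.trainer.train_target_cnn is not shown.
    source_cnn.eval()        # the source encoder stays fixed
    target_cnn.train()
    discriminator.train()
    for (source_x, _), (target_x, _) in zip(source_loader, target_loader):
        source_x = source_x.to(args.device)
        target_x = target_x.to(args.device)

        # 1) discriminator step: source features labeled 1, target features 0
        with torch.no_grad():
            source_feat = source_cnn.encoder(source_x)
        target_feat = target_cnn.encoder(target_x)
        d_in = torch.cat([source_feat, target_feat.detach()], dim=0)
        d_label = torch.cat([
            torch.ones(source_feat.size(0), dtype=torch.long),
            torch.zeros(target_feat.size(0), dtype=torch.long)]).to(args.device)
        d_loss = criterion(discriminator(d_in), d_label)
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # 2) encoder step: push target features to be classified as "source"
        g_label = torch.ones(target_x.size(0), dtype=torch.long).to(args.device)
        g_loss = criterion(discriminator(target_cnn.encoder(target_x)), g_label)
        optimizer.zero_grad()
        g_loss.backward()
        optimizer.step()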
@@ -0,0 +1,56 @@
import argparse
import os
import sys
sys.path.append(os.path.abspath('.'))

from torch import nn, optim

from models.resnet50off import CNN
from core.trainer import train_source_cnn
from utils.utils import get_logger
from utils.altutils import get_office


def main(args):
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)
    logger = get_logger(os.path.join(args.logdir, 'train_source.log'))
    logger.info(args)

    # data loader for the source domain
    dataset_root = os.environ["DATASETDIR"]
    source_loader = get_office(dataset_root, args.batch_size, args.src_cat)

    # train the source CNN with a plain cross-entropy objective
    source_cnn = CNN(in_channels=args.in_channels).to(args.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        source_cnn.parameters(),
        lr=args.lr, weight_decay=args.weight_decay)
    source_cnn = train_source_cnn(
        source_cnn, source_loader, source_loader,
        criterion, optimizer, args=args)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # NN
    parser.add_argument('--in_channels', type=int, default=3)
    parser.add_argument('--n_classes', type=int, default=10)
    parser.add_argument('--trained', type=str, default='')
    parser.add_argument('--slope', type=float, default=0.2)
    # train
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=2.5e-5)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=32)
    # misc
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--n_workers', type=int, default=0)
    parser.add_argument('--logdir', type=str, default='outputs/garbage')
    parser.add_argument('--message', '-m', type=str, default='')
    # office dataset categories
    parser.add_argument('--src_cat', type=str, default='amazon')
    args, unknown = parser.parse_known_args()
    main(args)
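utils/altutils.py is also outside this diff, so get_office's body is unknown. For Office-31, a plausible implementation wraps torchvision's ImageFolder over one domain directory ('amazon', 'dslr', or 'webcam'). A sketch under that assumption; the directory layout, transform choices, and the name get_office_sketch are guesses, not the repo's actual code:

import os

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder


def get_office_sketch(dataset_root, batch_size, category, n_workers=0):
    # Hypothetical loader for one Office-31 domain; assumes a
    # <dataset_root>/office31/<category>/images/<class>/*.jpg layout.
    transform = transforms.Compose([
        transforms.Resize((227, 227)),  # Encoder.forward expects 227x227 inputs
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    dataset = ImageFolder(
        os.path.join(dataset_root, 'office31', category, 'images'),
        transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      num_workers=n_workers, drop_last=True)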
@@ -0,0 +1,67 @@
from torch import nn
import torch.nn.functional as F
from torchvision import models


class Encoder(nn.Module):
    def __init__(self, in_channels=3, h=256, dropout=0.5):
        super(Encoder, self).__init__()

        # ResNet-50 backbone with the final fc layer removed
        resnetModel = models.resnet50(pretrained=True)
        feature_map = list(resnetModel.children())
        feature_map.pop()
        self.feature_extractor = nn.Sequential(*feature_map)

    def forward(self, x):
        # broadcast grayscale inputs to 3 channels at 227x227
        x = x.expand(x.data.shape[0], 3, 227, 227)
        x = self.feature_extractor(x)
        x = x.view(x.size(0), -1)  # flatten pooled features to (batch, 2048)
        return x


class Classifier(nn.Module):
    def __init__(self, n_classes, dropout=0.5):
        super(Classifier, self).__init__()
        self.l1 = nn.Linear(2048, n_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        x = self.l1(x)
        return x


class CNN(nn.Module):
    def __init__(self, in_channels=3, n_classes=31, target=False):
        super(CNN, self).__init__()
        self.encoder = Encoder(in_channels=in_channels)
        self.classifier = Classifier(n_classes)
        if target:
            # the target model adapts only its encoder; the classifier is
            # shared with the source model and kept frozen
            for param in self.classifier.parameters():
                param.requires_grad = False

    def forward(self, x):
        x = self.encoder(x)
        x = self.classifier(x)
        return x


class Discriminator(nn.Module):
    def __init__(self, h=500, args=None):
        super(Discriminator, self).__init__()
        # 3-layer MLP that classifies a 2048-dim feature as source vs. target
        self.l1 = nn.Linear(2048, h)
        self.l2 = nn.Linear(h, h)
        self.l3 = nn.Linear(h, 2)
        self.slope = args.slope

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        x = F.leaky_relu(self.l1(x), self.slope)
        x = F.leaky_relu(self.l2(x), self.slope)
        x = self.l3(x)
        return x
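As a quick usage note, the pieces above fit together as follows: Encoder pools ResNet-50 features into a 2048-dim vector, Classifier maps that to the 31 Office classes, and Discriminator scores it as source vs. target. A minimal shape check, assuming the classes above are in scope (argparse.Namespace stands in for the scripts' parsed args, and the first run downloads ImageNet weights):

import argparse

import torch

args = argparse.Namespace(slope=0.2)
cnn = CNN(in_channels=3, n_classes=31, target=True)
disc = Discriminator(args=args)

x = torch.randn(4, 3, 227, 227)  # dummy batch of 4 images
feats = cnn.encoder(x)           # (4, 2048) pooled ResNet-50 features
logits = cnn.classifier(feats)   # (4, 31) class scores (frozen when target=True)
domain = disc(feats)             # (4, 2) source-vs-target logits
print(feats.shape, logits.shape, domain.shape)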