Implementation of "Adversarial Discriminative Domain Adaptation" in PyTorch
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

52 lines
1.9 KiB

import os
import sys
sys.path.append(os.path.abspath('.'))
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import SVHN, MNIST
from torchvision import transforms
from models.resnet50off import CNN, Discriminator
from core.trainer import train_target_cnn
from utils.utils import get_logger
from utils.altutils import get_office
def run(args):
if not os.path.exists(args.logdir):
os.makedirs(args.logdir)
logger = get_logger(os.path.join(args.logdir, 'main.log'))
logger.info(args)
# data loaders
dataset_root = os.environ["DATASETDIR"]
source_loader = get_office(dataset_root, args.batch_size, args.src_cat)
target_loader = get_office(dataset_root, args.batch_size, args.tgt_cat)
# train source CNN
source_cnn = CNN(in_channels=args.in_channels).to(args.device)
if os.path.isfile(args.trained):
c = torch.load(args.trained)
source_cnn.load_state_dict(c['model'])
logger.info('Loaded `{}`'.format(args.trained))
# train target CNN
target_cnn = CNN(in_channels=args.in_channels, target=True).to(args.device)
target_cnn.load_state_dict(source_cnn.state_dict())
for param in source_cnn.parameters():
param.requires_grad = False
for param in target_cnn.classifier.parameters():
param.requires_grad = False
discriminator = Discriminator(args=args).to(args.device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
target_cnn.encoder.parameters(),
lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
d_optimizer = optim.Adam(
discriminator.parameters(),
lr=args.d_lr, betas=args.betas, weight_decay=args.weight_decay)
train_target_cnn(
source_cnn, target_cnn, discriminator,
criterion, optimizer, d_optimizer,
source_loader, target_loader, target_loader,
args=args)