Implementation of "Adversarial Discriminative Domain Adaptation" in PyTorch
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

73 lines
2.6 KiB

6 years ago
import os
import sys
sys.path.append(os.path.abspath('.'))
6 years ago
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import SVHN, MNIST
from torchvision import transforms
from models.models import CNN, Discriminator
from core.trainer import train_target_cnn
from utils.utils import get_logger
6 years ago
def _gray_to_rgb(x):
    """Replicate a single-channel image tensor into three channels.

    Expects a ``(1, H, W)`` tensor, as ``transforms.ToTensor`` produces for a
    grayscale PIL image, and returns the ``(3, H, W)`` result of
    ``x.repeat(3, 1, 1)``.

    Defined at module level rather than as a ``lambda`` inside ``run`` so the
    transform can be pickled when DataLoader workers are started with the
    ``spawn`` start method (lambdas are not picklable).
    """
    return x.repeat(3, 1, 1)


def run(args):
    """Adapt a pre-trained SVHN (source) CNN to MNIST (target) with ADDA.

    Builds the SVHN train loader and the MNIST train/test loaders. It then
    initialises the target CNN from the source CNN's weights and hands
    everything to ``train_target_cnn`` for adversarial training against a
    discriminator.

    Args:
        args: Experiment configuration. This function reads ``logdir``,
            ``batch_size``, ``n_workers``, ``in_channels``, ``device``,
            ``trained`` (path to a source-CNN checkpoint whose ``'model'``
            entry is a state dict), ``lr``, ``betas`` and ``weight_decay``.
            The whole object is also forwarded to ``Discriminator`` and
            ``train_target_cnn``.
    """
    os.makedirs(args.logdir, exist_ok=True)

    logger = get_logger(os.path.join(args.logdir, 'main.log'))
    logger.info(args)

    # Directory the datasets are downloaded to / read from.
    data_root = 'input/'

    # data
    source_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    # Resize MNIST to 32 and replicate its single channel to three,
    # presumably so target inputs match the SVHN input shape.
    target_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Lambda(_gray_to_rgb),
    ])
    source_dataset_train = SVHN(
        data_root, split='train', transform=source_transform, download=True)
    target_dataset_train = MNIST(
        data_root, train=True, transform=target_transform, download=True)
    # MNIST selects its split with the boolean ``train`` flag. Passing the
    # string 'test' (truthy) would silently load the training set instead.
    target_dataset_test = MNIST(
        data_root, train=False, transform=target_transform, download=True)

    source_train_loader = DataLoader(
        source_dataset_train, args.batch_size, shuffle=True,
        drop_last=True, num_workers=args.n_workers)
    target_train_loader = DataLoader(
        target_dataset_train, args.batch_size, shuffle=True,
        drop_last=True, num_workers=args.n_workers)
    target_test_loader = DataLoader(
        target_dataset_test, args.batch_size, shuffle=False,
        num_workers=args.n_workers)

    # Source CNN: loaded from a checkpoint here, never trained by this function.
    source_cnn = CNN(in_channels=args.in_channels).to(args.device)
    if os.path.isfile(args.trained):
        # map_location lets a checkpoint saved on another device (e.g. GPU)
        # be loaded onto args.device.
        checkpoint = torch.load(args.trained, map_location=args.device)
        source_cnn.load_state_dict(checkpoint['model'])
        logger.info('Loaded `{}`'.format(args.trained))
    else:
        logger.warning(
            '`{}` is not a file; source CNN was not loaded from a '
            'checkpoint'.format(args.trained))

    # Target CNN starts as a copy of the source CNN's weights.
    target_cnn = CNN(in_channels=args.in_channels, target=True).to(args.device)
    target_cnn.load_state_dict(source_cnn.state_dict())
    discriminator = Discriminator(args=args).to(args.device)
    criterion = nn.CrossEntropyLoss()

    # Only the target encoder's parameters are given to this optimizer.
    optimizer = optim.Adam(
        target_cnn.encoder.parameters(),
        lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
    d_optimizer = optim.Adam(
        discriminator.parameters(),
        lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)

    train_target_cnn(
        source_cnn, target_cnn, discriminator,
        criterion, optimizer, d_optimizer,
        source_train_loader, target_train_loader, target_test_loader,
        args=args)