Implementation of "Adversarial Discriminative Domain Adaptation" in PyTorch
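"""Training loops for ADDA (Adversarial Discriminative Domain Adaptation):
supervised pre-training of the source CNN, followed by adversarial
adaptation of the target encoder against a domain discriminator."""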

import os
import sys
sys.path.append(os.path.abspath('.'))
from logging import getLogger
from time import time

import numpy as np
from sklearn.metrics import accuracy_score
import torch

from utils.utils import AverageMeter, save


logger = getLogger('adda.trainer')

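# ADDA step 1: supervised pre-training of the source CNN (encoder +
# classifier) on the source domain, checkpointing the best validation
# accuracy.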
def train_source_cnn(
    source_cnn, train_loader, test_loader, criterion, optimizer,
    args=None
):
    best_score = None
    for epoch_i in range(1, 1 + args.epochs):
        start_time = time()
        training = train(
            source_cnn, train_loader, criterion, optimizer, args=args)
        validation = validate(
            source_cnn, test_loader, criterion, args=args)
        log = 'Epoch {}/{} '.format(epoch_i, args.epochs)
        log += '| Train/Loss {:.3f} Acc {:.3f} '.format(
            training['loss'], training['acc'])
        log += '| Val/Loss {:.3f} Acc {:.3f} '.format(
            validation['loss'], validation['acc'])
        log += 'Time {:.2f}s'.format(time() - start_time)
        logger.info(log)

        # save
        is_best = (best_score is None or validation['acc'] > best_score)
        best_score = validation['acc'] if is_best else best_score
        state_dict = {
            'model': source_cnn.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch_i,
            'val/acc': best_score,
        }
        save(args.logdir, state_dict, is_best)
    return source_cnn

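# ADDA step 2: adversarial adaptation. The source CNN stays fixed and
# provides the reference features; the target encoder and the domain
# discriminator are trained against each other, and the adapted target
# CNN is evaluated on the target test and train sets after every epoch.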
def train_target_cnn(
    source_cnn, target_cnn, discriminator,
    criterion, optimizer, d_optimizer,
    source_train_loader, target_train_loader, target_test_loader,
    args=None
):
    validation = validate(source_cnn, target_test_loader, criterion, args=args)
    log_source = 'Source/Acc {:.3f} '.format(validation['acc'])

    best_score = None
    for epoch_i in range(1, 1 + args.epochs):
        start_time = time()
        training = adversarial(
            source_cnn, target_cnn, discriminator,
            source_train_loader, target_train_loader,
            criterion, criterion,
            optimizer, d_optimizer,
            args=args
        )
        validation = validate(
            target_cnn, target_test_loader, criterion, args=args)
        validation2 = validate(
            target_cnn, target_train_loader, criterion, args=args)
        log = 'Epoch {}/{} '.format(epoch_i, args.epochs)
        log += 'D/Loss {:.3f} Target/Loss {:.3f} '.format(
            training['d/loss'], training['target/loss'])
        log += '[Val] Target/Loss {:.3f} Target/Acc {:.3f} '.format(
            validation['loss'], validation['acc'])
        log += log_source
        log += '[Train] Target/Loss {:.3f} Target/Acc {:.3f} '.format(
            validation2['loss'], validation2['acc'])
        log += 'Time {:.2f}s'.format(time() - start_time)
        logger.info(log)

        # save
        is_best = (best_score is None or validation['acc'] > best_score)
        best_score = validation['acc'] if is_best else best_score
        state_dict = {
            'model': target_cnn.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch_i,
            'val/acc': best_score,
        }
        save(args.logdir, state_dict, is_best)

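# One epoch of the adversarial game: the discriminator is trained to
# separate source features (label 0) from target features (label 1),
# then the target encoder is updated to fool the discriminator by
# pushing its features toward the source label.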
def adversarial(
    source_cnn, target_cnn, discriminator,
    source_loader, target_loader,
    criterion, d_criterion,
    optimizer, d_optimizer,
    args=None
):
    source_cnn.eval()
    target_cnn.encoder.train()
    discriminator.train()
    losses, d_losses = AverageMeter(), AverageMeter()
    n_iters = min(len(source_loader), len(target_loader))
    source_iter, target_iter = iter(source_loader), iter(target_loader)
    for iter_i in range(n_iters):
        source_data, source_target = next(source_iter)
        target_data, target_target = next(target_iter)
        source_data = source_data.to(args.device)
        target_data = target_data.to(args.device)
        bs = source_data.size(0)
        D_input_source = source_cnn.encoder(source_data)
        D_input_target = target_cnn.encoder(target_data)
        # domain labels: 0 for source features, 1 for target features
        D_target_source = torch.tensor(
            [0] * bs, dtype=torch.long).to(args.device)
        D_target_target = torch.tensor(
            [1] * bs, dtype=torch.long).to(args.device)

        # train Discriminator
        D_output_source = discriminator(D_input_source)
        D_output_target = discriminator(D_input_target)
        D_output = torch.cat([D_output_source, D_output_target], dim=0)
        D_target = torch.cat([D_target_source, D_target_target], dim=0)
        d_loss = d_criterion(D_output, D_target)
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        d_losses.update(d_loss.item(), bs)

        # train Target encoder: fool the discriminator by driving target
        # features toward the source label
        D_input_target = target_cnn.encoder(target_data)
        D_output_target = discriminator(D_input_target)
        loss = criterion(D_output_target, D_target_source)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.update(loss.item(), bs)
    return {'d/loss': d_losses.avg, 'target/loss': losses.avg}

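# Shared forward pass: move a batch to the device, run the model, and
# return the raw outputs together with the loss.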
def step(model, data, target, criterion, args):
    data, target = data.to(args.device), target.to(args.device)
    output = model(data)
    loss = criterion(output, target)
    return output, loss

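# One supervised training epoch; returns the average loss and the
# epoch-level accuracy.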
def train(model, dataloader, criterion, optimizer, args=None):
    model.train()
    losses = AverageMeter()
    targets, probas = [], []
    for i, (data, target) in enumerate(dataloader):
        bs = target.size(0)
        output, loss = step(model, data, target, criterion, args)
        output = torch.softmax(output, dim=1)  # softmax only for predictions/accuracy; loss uses raw logits
        losses.update(loss.item(), bs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        targets.extend(target.cpu().detach().numpy().tolist())
        probas.extend(output.cpu().detach().numpy().tolist())
    probas = np.asarray(probas)
    preds = np.argmax(probas, axis=1)
    acc = accuracy_score(targets, preds)
    return {
        'loss': losses.avg, 'acc': acc,
    }

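# Evaluation loop with gradients disabled; returns the average loss and
# accuracy over the whole dataloader.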
def validate(model, dataloader, criterion, args=None):
    model.eval()
    losses = AverageMeter()
    targets, probas = [], []
    with torch.no_grad():
        for iter_i, (data, target) in enumerate(dataloader):
            bs = target.size(0)
            output, loss = step(model, data, target, criterion, args)
            output = torch.softmax(output, dim=1)  # softmax only for predictions/accuracy
            losses.update(loss.item(), bs)
            targets.extend(target.cpu().numpy().tolist())
            probas.extend(output.cpu().numpy().tolist())
    probas = np.asarray(probas)
    preds = np.argmax(probas, axis=1)
    acc = accuracy_score(targets, preds)
    return {
        'loss': losses.avg, 'acc': acc,
    }