|
@ -3,6 +3,7 @@ from time import time |
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import numpy as np |
|
|
from sklearn.metrics import accuracy_score |
|
|
from sklearn.metrics import accuracy_score |
|
|
|
|
|
from tensorboardX import SummaryWriter |
|
|
import torch |
|
|
import torch |
|
|
|
|
|
|
|
|
from utils import AverageMeter, save |
|
|
from utils import AverageMeter, save |
|
@ -23,9 +24,9 @@ def train_source_cnn( |
|
|
validation = validate( |
|
|
validation = validate( |
|
|
source_cnn, test_loader, criterion, args=args) |
|
|
source_cnn, test_loader, criterion, args=args) |
|
|
log = 'Epoch {}/{} '.format(epoch_i, args.epochs) |
|
|
log = 'Epoch {}/{} '.format(epoch_i, args.epochs) |
|
|
log += 'Train/Loss {:.3f} Train/Acc {:.3f} '.format( |
|
|
|
|
|
|
|
|
log += '| Train/Loss {:.3f} Acc {:.3f} '.format( |
|
|
training['loss'], training['acc']) |
|
|
training['loss'], training['acc']) |
|
|
log += 'Val/Loss {:.3f} Val/Acc {:.3f} '.format( |
|
|
|
|
|
|
|
|
log += '| Val/Loss {:.3f} Acc {:.3f} '.format( |
|
|
validation['loss'], validation['acc']) |
|
|
validation['loss'], validation['acc']) |
|
|
log += 'Time {:.2f}s'.format(time() - start_time) |
|
|
log += 'Time {:.2f}s'.format(time() - start_time) |
|
|
logger.info(log) |
|
|
logger.info(log) |
|
@ -51,9 +52,10 @@ def train_target_cnn( |
|
|
args=None |
|
|
args=None |
|
|
): |
|
|
): |
|
|
validation = validate(source_cnn, target_test_loader, criterion, args=args) |
|
|
validation = validate(source_cnn, target_test_loader, criterion, args=args) |
|
|
log_source = 'Source/Val/Acc {:.3f} '.format(validation['acc']) |
|
|
|
|
|
|
|
|
log_source = 'Source/Acc {:.3f} '.format(validation['acc']) |
|
|
|
|
|
|
|
|
# best_score = None |
|
|
|
|
|
|
|
|
writer = SummaryWriter(args.logdir) |
|
|
|
|
|
best_score = None |
|
|
for epoch_i in range(1, 1 + args.epochs): |
|
|
for epoch_i in range(1, 1 + args.epochs): |
|
|
start_time = time() |
|
|
start_time = time() |
|
|
training = adversarial( |
|
|
training = adversarial( |
|
@ -65,15 +67,37 @@ def train_target_cnn( |
|
|
) |
|
|
) |
|
|
validation = validate( |
|
|
validation = validate( |
|
|
target_cnn, target_test_loader, criterion, args=args) |
|
|
target_cnn, target_test_loader, criterion, args=args) |
|
|
|
|
|
validation2 = validate( |
|
|
|
|
|
target_cnn, target_train_loader, criterion, args=args) |
|
|
log = 'Epoch {}/{} '.format(epoch_i, args.epochs) |
|
|
log = 'Epoch {}/{} '.format(epoch_i, args.epochs) |
|
|
log += 'D/Loss {:.3f} Target/Loss {:.3f} '.format( |
|
|
log += 'D/Loss {:.3f} Target/Loss {:.3f} '.format( |
|
|
training['d/loss'], training['target/loss']) |
|
|
training['d/loss'], training['target/loss']) |
|
|
log += 'Target/Val/Loss {:.3f} Target/Val/Acc {:.3f} '.format( |
|
|
|
|
|
|
|
|
log += '[Val] Target/Loss {:.3f} Target/Acc {:.3f} '.format( |
|
|
validation['loss'], validation['acc']) |
|
|
validation['loss'], validation['acc']) |
|
|
log += log_source |
|
|
log += log_source |
|
|
|
|
|
log += '[Train] Target/Loss {:.3f} Target/Acc {:.3f} '.format( |
|
|
|
|
|
validation2['loss'], validation2['acc']) |
|
|
log += 'Time {:.2f}s'.format(time() - start_time) |
|
|
log += 'Time {:.2f}s'.format(time() - start_time) |
|
|
logger.info(log) |
|
|
logger.info(log) |
|
|
|
|
|
|
|
|
|
|
|
# save |
|
|
|
|
|
is_best = (best_score is None or validation['acc'] > best_score) |
|
|
|
|
|
best_score = validation['acc'] if is_best else best_score |
|
|
|
|
|
state_dict = { |
|
|
|
|
|
'model': target_cnn.state_dict(), |
|
|
|
|
|
'optimizer': optimizer.state_dict(), |
|
|
|
|
|
'epoch': epoch_i, |
|
|
|
|
|
'val/acc': best_score, |
|
|
|
|
|
} |
|
|
|
|
|
save(args.logdir, state_dict, is_best) |
|
|
|
|
|
|
|
|
|
|
|
# tensorboard |
|
|
|
|
|
writer.add_scalar('Adv/D/Loss', training['d/loss'], epoch_i) |
|
|
|
|
|
writer.add_scalar('Adv/Target/Loss', training['target/loss'], epoch_i) |
|
|
|
|
|
writer.add_scalar('Val/Target/Loss', validation['loss'], epoch_i) |
|
|
|
|
|
writer.add_scalar('Val/Target/Acc', validation['acc'], epoch_i) |
|
|
|
|
|
writer.add_scalar('Train/Target/Acc', validation2['acc'], epoch_i) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def adversarial( |
|
|
def adversarial( |
|
|
source_cnn, target_cnn, discriminator, |
|
|
source_cnn, target_cnn, discriminator, |
|
|