diff --git a/trainer.py b/trainer.py index 47b37ec..309de24 100644 --- a/trainer.py +++ b/trainer.py @@ -3,6 +3,7 @@ from time import time import numpy as np from sklearn.metrics import accuracy_score +from tensorboardX import SummaryWriter import torch from utils import AverageMeter, save @@ -23,9 +24,9 @@ def train_source_cnn( validation = validate( source_cnn, test_loader, criterion, args=args) log = 'Epoch {}/{} '.format(epoch_i, args.epochs) - log += 'Train/Loss {:.3f} Train/Acc {:.3f} '.format( + log += '| Train/Loss {:.3f} Acc {:.3f} '.format( training['loss'], training['acc']) - log += 'Val/Loss {:.3f} Val/Acc {:.3f} '.format( + log += '| Val/Loss {:.3f} Acc {:.3f} '.format( validation['loss'], validation['acc']) log += 'Time {:.2f}s'.format(time() - start_time) logger.info(log) @@ -51,9 +52,10 @@ def train_target_cnn( args=None ): validation = validate(source_cnn, target_test_loader, criterion, args=args) - log_source = 'Source/Val/Acc {:.3f} '.format(validation['acc']) + log_source = 'Source/Acc {:.3f} '.format(validation['acc']) - # best_score = None + writer = SummaryWriter(args.logdir) + best_score = None for epoch_i in range(1, 1 + args.epochs): start_time = time() training = adversarial( @@ -65,15 +67,37 @@ def train_target_cnn( ) validation = validate( target_cnn, target_test_loader, criterion, args=args) + validation2 = validate( + target_cnn, target_train_loader, criterion, args=args) log = 'Epoch {}/{} '.format(epoch_i, args.epochs) log += 'D/Loss {:.3f} Target/Loss {:.3f} '.format( training['d/loss'], training['target/loss']) - log += 'Target/Val/Loss {:.3f} Target/Val/Acc {:.3f} '.format( + log += '[Val] Target/Loss {:.3f} Target/Acc {:.3f} '.format( validation['loss'], validation['acc']) log += log_source + log += '[Train] Target/Loss {:.3f} Target/Acc {:.3f} '.format( + validation2['loss'], validation2['acc']) log += 'Time {:.2f}s'.format(time() - start_time) logger.info(log) + # save + is_best = (best_score is None or validation['acc'] > best_score) + best_score = validation['acc'] if is_best else best_score + state_dict = { + 'model': target_cnn.state_dict(), + 'optimizer': optimizer.state_dict(), + 'epoch': epoch_i, + 'val/acc': best_score, + } + save(args.logdir, state_dict, is_best) + + # tensorboard + writer.add_scalar('Adv/D/Loss', training['d/loss'], epoch_i) + writer.add_scalar('Adv/Target/Loss', training['target/loss'], epoch_i) + writer.add_scalar('Val/Target/Loss', validation['loss'], epoch_i) + writer.add_scalar('Val/Target/Acc', validation['acc'], epoch_i) + writer.add_scalar('Train/Target/Acc', validation2['acc'], epoch_i) + def adversarial( source_cnn, target_cnn, discriminator,