Browse Source

[misc] refactor logs, save target model, tensorboard

master
fnakamura 6 years ago
parent
commit
8ed8034cc7
  1. 34
      trainer.py

34
trainer.py

@ -3,6 +3,7 @@ from time import time
import numpy as np
from sklearn.metrics import accuracy_score
from tensorboardX import SummaryWriter
import torch
from utils import AverageMeter, save
@ -23,9 +24,9 @@ def train_source_cnn(
validation = validate(
source_cnn, test_loader, criterion, args=args)
log = 'Epoch {}/{} '.format(epoch_i, args.epochs)
log += 'Train/Loss {:.3f} Train/Acc {:.3f} '.format(
log += '| Train/Loss {:.3f} Acc {:.3f} '.format(
training['loss'], training['acc'])
log += 'Val/Loss {:.3f} Val/Acc {:.3f} '.format(
log += '| Val/Loss {:.3f} Acc {:.3f} '.format(
validation['loss'], validation['acc'])
log += 'Time {:.2f}s'.format(time() - start_time)
logger.info(log)
@ -51,9 +52,10 @@ def train_target_cnn(
args=None
):
validation = validate(source_cnn, target_test_loader, criterion, args=args)
log_source = 'Source/Val/Acc {:.3f} '.format(validation['acc'])
log_source = 'Source/Acc {:.3f} '.format(validation['acc'])
# best_score = None
writer = SummaryWriter(args.logdir)
best_score = None
for epoch_i in range(1, 1 + args.epochs):
start_time = time()
training = adversarial(
@ -65,15 +67,37 @@ def train_target_cnn(
)
validation = validate(
target_cnn, target_test_loader, criterion, args=args)
validation2 = validate(
target_cnn, target_train_loader, criterion, args=args)
log = 'Epoch {}/{} '.format(epoch_i, args.epochs)
log += 'D/Loss {:.3f} Target/Loss {:.3f} '.format(
training['d/loss'], training['target/loss'])
log += 'Target/Val/Loss {:.3f} Target/Val/Acc {:.3f} '.format(
log += '[Val] Target/Loss {:.3f} Target/Acc {:.3f} '.format(
validation['loss'], validation['acc'])
log += log_source
log += '[Train] Target/Loss {:.3f} Target/Acc {:.3f} '.format(
validation2['loss'], validation2['acc'])
log += 'Time {:.2f}s'.format(time() - start_time)
logger.info(log)
# save
is_best = (best_score is None or validation['acc'] > best_score)
best_score = validation['acc'] if is_best else best_score
state_dict = {
'model': target_cnn.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch_i,
'val/acc': best_score,
}
save(args.logdir, state_dict, is_best)
# tensorboard
writer.add_scalar('Adv/D/Loss', training['d/loss'], epoch_i)
writer.add_scalar('Adv/Target/Loss', training['target/loss'], epoch_i)
writer.add_scalar('Val/Target/Loss', validation['loss'], epoch_i)
writer.add_scalar('Val/Target/Acc', validation['acc'], epoch_i)
writer.add_scalar('Train/Target/Acc', validation2['acc'], epoch_i)
def adversarial(
source_cnn, target_cnn, discriminator,

Loading…
Cancel
Save