From 5c365cf990185af8b5639a0815cfa50e62e38b4b Mon Sep 17 00:00:00 2001 From: fnakamura Date: Wed, 20 Feb 2019 19:45:48 +0900 Subject: [PATCH] logger --- experiment.py | 5 ++++- trainer.py | 8 ++++++-- utils.py | 16 ++++++++++++++++ 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/experiment.py b/experiment.py index 785100b..a4ecff9 100644 --- a/experiment.py +++ b/experiment.py @@ -7,12 +7,15 @@ from torchvision.datasets import SVHN, MNIST from torchvision import transforms from models import CNN, Discriminator -from trainer import train_source_cnn, train_target_cnn +from trainer import train_target_cnn +from utils import get_logger def run(args): if not os.path.exists(args.logdir): os.makedirs(args.logdir) + logger = get_logger(os.path.join(args.logdir, 'main.log')) + logger.info(args) # data source_transform = transforms.Compose([ diff --git a/trainer.py b/trainer.py index 44c88d5..c350fcc 100644 --- a/trainer.py +++ b/trainer.py @@ -1,3 +1,4 @@ +from logging import getLogger from time import time import numpy as np @@ -7,6 +8,9 @@ import torch from utils import AverageMeter, save +logger = getLogger('adda.trainer') + + def train_source_cnn( source_cnn, train_loader, test_loader, criterion, optimizer, args=None @@ -24,7 +28,7 @@ def train_source_cnn( log += 'Val/Loss {:.3f} Val/Acc {:.3f} '.format( validation['loss'], validation['acc']) log += 'Time {:.2f}s'.format(time() - start_time) - print(log) + logger.info(log) # save is_best = (best_score is None or validation['acc'] > best_score) @@ -68,7 +72,7 @@ def train_target_cnn( validation['loss'], validation['acc']) log += log_source log += 'Time {:.2f}s'.format(time() - start_time) - print(log) + logger.info(log) def adversarial( diff --git a/utils.py b/utils.py index 6b95707..ae51fc9 100644 --- a/utils.py +++ b/utils.py @@ -11,6 +11,22 @@ def save(log_dir, state_dict, is_best): shutil.copyfile(checkpoint_path, best_model_path) +def get_logger(log_file): + from logging import getLogger, FileHandler, StreamHandler, Formatter, DEBUG, INFO # noqa + fh = FileHandler(log_file) + fh.setLevel(DEBUG) + sh = StreamHandler() + sh.setLevel(INFO) + for handler in [fh, sh]: + formatter = Formatter('%(asctime)s - %(message)s') + handler.setFormatter(formatter) + logger = getLogger('adda') + logger.setLevel(INFO) + logger.addHandler(fh) + logger.addHandler(sh) + return logger + + class AverageMeter(object): """Computes and stores the average and current value https://github.com/pytorch/examples/blob/master/imagenet/main.py#L296