Browse Source

logger

master
fnakamura 6 years ago
parent
commit
5c365cf990
  1. 5
      experiment.py
  2. 8
      trainer.py
  3. 16
      utils.py

5
experiment.py

@ -7,12 +7,15 @@ from torchvision.datasets import SVHN, MNIST
from torchvision import transforms from torchvision import transforms
from models import CNN, Discriminator from models import CNN, Discriminator
from trainer import train_source_cnn, train_target_cnn
from trainer import train_target_cnn
from utils import get_logger
def run(args): def run(args):
if not os.path.exists(args.logdir): if not os.path.exists(args.logdir):
os.makedirs(args.logdir) os.makedirs(args.logdir)
logger = get_logger(os.path.join(args.logdir, 'main.log'))
logger.info(args)
# data # data
source_transform = transforms.Compose([ source_transform = transforms.Compose([

8
trainer.py

@ -1,3 +1,4 @@
from logging import getLogger
from time import time from time import time
import numpy as np import numpy as np
@ -7,6 +8,9 @@ import torch
from utils import AverageMeter, save from utils import AverageMeter, save
logger = getLogger('adda.trainer')
def train_source_cnn( def train_source_cnn(
source_cnn, train_loader, test_loader, criterion, optimizer, source_cnn, train_loader, test_loader, criterion, optimizer,
args=None args=None
@ -24,7 +28,7 @@ def train_source_cnn(
log += 'Val/Loss {:.3f} Val/Acc {:.3f} '.format( log += 'Val/Loss {:.3f} Val/Acc {:.3f} '.format(
validation['loss'], validation['acc']) validation['loss'], validation['acc'])
log += 'Time {:.2f}s'.format(time() - start_time) log += 'Time {:.2f}s'.format(time() - start_time)
print(log)
logger.info(log)
# save # save
is_best = (best_score is None or validation['acc'] > best_score) is_best = (best_score is None or validation['acc'] > best_score)
@ -68,7 +72,7 @@ def train_target_cnn(
validation['loss'], validation['acc']) validation['loss'], validation['acc'])
log += log_source log += log_source
log += 'Time {:.2f}s'.format(time() - start_time) log += 'Time {:.2f}s'.format(time() - start_time)
print(log)
logger.info(log)
def adversarial( def adversarial(

16
utils.py

@ -11,6 +11,22 @@ def save(log_dir, state_dict, is_best):
shutil.copyfile(checkpoint_path, best_model_path) shutil.copyfile(checkpoint_path, best_model_path)
def get_logger(log_file):
from logging import getLogger, FileHandler, StreamHandler, Formatter, DEBUG, INFO # noqa
fh = FileHandler(log_file)
fh.setLevel(DEBUG)
sh = StreamHandler()
sh.setLevel(INFO)
for handler in [fh, sh]:
formatter = Formatter('%(asctime)s - %(message)s')
handler.setFormatter(formatter)
logger = getLogger('adda')
logger.setLevel(INFO)
logger.addHandler(fh)
logger.addHandler(sh)
return logger
class AverageMeter(object): class AverageMeter(object):
"""Computes and stores the average and current value """Computes and stores the average and current value
https://github.com/pytorch/examples/blob/master/imagenet/main.py#L296 https://github.com/pytorch/examples/blob/master/imagenet/main.py#L296

Loading…
Cancel
Save