From 95a709f395152835a268bbb27b29608bbc7dbb8a Mon Sep 17 00:00:00 2001
From: fazilaltinel
Date: Thu, 3 Dec 2020 14:46:50 +0300
Subject: [PATCH] Logging improvements

---
 core/test.py                |  4 ++--
 core/train.py               | 14 +++++++-------
 experiments/mnist_mnistm.py |  7 ++++++-
 utils/altutils.py           | 33 +++++++++++++++++++++++++++++++++
 4 files changed, 48 insertions(+), 10 deletions(-)
 create mode 100644 utils/altutils.py

diff --git a/core/test.py b/core/test.py
index 799e745..26fcb12 100644
--- a/core/test.py
+++ b/core/test.py
@@ -1,7 +1,7 @@
 import torch.utils.data
 import torch.nn as nn
 
-def test(model, data_loader, device, flag):
+def test(model, data_loader, device, loggi, flag):
     """Evaluate model for dataset."""
     # set eval state for Dropout and BN layers
     model.eval()
@@ -39,6 +39,6 @@ def test(model, data_loader, device, flag):
     acc = acc_ / n_total
     acc_domain = acc_domain_ / n_total
 
-    print("Avg Loss = {:.6f}, Avg Accuracy = {:.2%}, {}/{}, Avg Domain Accuracy = {:2%}".format(loss, acc, acc_, n_total, acc_domain))
+    loggi.info("Avg Loss = {:.6f}, Avg Accuracy = {:.2%}, {}/{}, Avg Domain Accuracy = {:.2%}".format(loss, acc, acc_, n_total, acc_domain))
 
     return loss, acc, acc_domain
diff --git a/core/train.py b/core/train.py
index 27fb4f4..976a191 100644
--- a/core/train.py
+++ b/core/train.py
@@ -1,7 +1,7 @@
 """Train dann."""
 
 import numpy as np
-
+import os
 import torch
 import torch.nn as nn
 import torch.optim as optim
@@ -113,7 +113,7 @@ def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_e
     return model
 
 
-def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger=None):
+def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, loggi, logger=None):
     """Train dann."""
     ####################
     # 1. setup network #
@@ -215,8 +215,8 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
 
         # eval model
         if ((epoch + 1) % params.eval_step == 0):
-            tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, flag='target')
-            src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source')
+            tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, loggi, flag='target')
+            src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, loggi, flag='source')
             if tgt_acc > bestAcc:
                 bestAcc = tgt_acc
                 bestAccS = src_acc
@@ -232,9 +232,9 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
 
     # save final model
     save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")
 
-    print('============ Summary ============= \n')
-    print('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS))
-    print('Accuracy of the %s dataset: %f' % (params.tgtc_dataset, bestAcc))
+    loggi.info('============ Summary ============= \n')
+    loggi.info('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS))
+    loggi.info('Accuracy of the %s dataset: %f' % (params.tgt_dataset, bestAcc))
 
     return model
diff --git a/experiments/mnist_mnistm.py b/experiments/mnist_mnistm.py
index 37985c4..692c414 100644
--- a/experiments/mnist_mnistm.py
+++ b/experiments/mnist_mnistm.py
@@ -5,6 +5,7 @@ import torch
 from models.model import MNISTmodel, MNISTmodel_plain
 from core.train import train_dann
 from utils.utils import get_data_loader, init_model, init_random_seed
+from utils.altutils import setLogger
 
 
 class Config(object):
@@ -63,6 +64,10 @@ class Config(object):
 
 params = Config()
 
+currentDir = os.path.dirname(os.path.realpath(__file__))
+logFile = os.path.join(currentDir+'/../', 'dann-{}-{}.log'.format(params.src_dataset, params.tgt_dataset))
+loggi = setLogger(logFile)
+
 # init random seed
 init_random_seed(params.manual_seed)
 
@@ -81,4 +86,4 @@ dann = init_model(net=MNISTmodel(), restore=None)
 # train dann model
 print("Training dann model")
 if not (dann.restored and params.dann_restore):
-    dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device)
\ No newline at end of file
+    dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, loggi)
\ No newline at end of file
diff --git a/utils/altutils.py b/utils/altutils.py
new file mode 100644
index 0000000..1c22525
--- /dev/null
+++ b/utils/altutils.py
@@ -0,0 +1,33 @@
+import configparser
+import logging
+
+
+def readConfigFile(filePath):
+    """
+    Read config file
+
+    Args:
+        filePath ([str]): path to config file
+
+    Returns:
+        [obj]: config object
+    """
+    config = configparser.ConfigParser()
+    config.read(filePath)
+    return config
+
+
+def setLogger(logFilePath):
+    """
+    Set logger
+
+    Args:
+        logFilePath ([str]): path to log file
+
+    Returns:
+        [obj]: logger object
+    """
+    logHandler = [logging.FileHandler(logFilePath), logging.StreamHandler()]
+    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s", handlers=logHandler)
+    logger = logging.getLogger()
+    return logger
\ No newline at end of file
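
A minimal usage sketch for the new logging helper (not part of the patch itself; the log file name below is a hypothetical placeholder, whereas the experiment script builds it from params.src_dataset and params.tgt_dataset):

    # Sketch only: assumes the patched utils/altutils.py is on the import path.
    from utils.altutils import setLogger

    loggi = setLogger('dann-mnist-mnistm.log')   # hypothetical log file path
    loggi.info('Training dann model')
    # setLogger() configures the root logger with a FileHandler and a
    # StreamHandler at INFO level, so the message above is echoed to the
    # console and appended to the log file with an "%(asctime)s" timestamp.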