Browse Source

Logging improvements

master
Fazil Altinel 4 years ago
parent
commit
95a709f395
  1. 4
      core/test.py
  2. 14
      core/train.py
  3. 7
      experiments/mnist_mnistm.py
  4. 33
      utils/altutils.py

4
core/test.py

@ -1,7 +1,7 @@
import torch.utils.data import torch.utils.data
import torch.nn as nn import torch.nn as nn
def test(model, data_loader, device, flag):
def test(model, data_loader, device, loggi, flag):
"""Evaluate model for dataset.""" """Evaluate model for dataset."""
# set eval state for Dropout and BN layers # set eval state for Dropout and BN layers
model.eval() model.eval()
@ -39,6 +39,6 @@ def test(model, data_loader, device, flag):
acc = acc_ / n_total acc = acc_ / n_total
acc_domain = acc_domain_ / n_total acc_domain = acc_domain_ / n_total
print("Avg Loss = {:.6f}, Avg Accuracy = {:.2%}, {}/{}, Avg Domain Accuracy = {:2%}".format(loss, acc, acc_, n_total, acc_domain))
loggi.info("Avg Loss = {:.6f}, Avg Accuracy = {:.2%}, {}/{}, Avg Domain Accuracy = {:.2%}".format(loss, acc, acc_, n_total, acc_domain))
return loss, acc, acc_domain return loss, acc, acc_domain

14
core/train.py

@ -1,7 +1,7 @@
"""Train dann.""" """Train dann."""
import numpy as np import numpy as np
import os
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
@ -113,7 +113,7 @@ def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_e
return model return model
def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger=None):
def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, loggi, logger=None):
"""Train dann.""" """Train dann."""
#################### ####################
# 1. setup network # # 1. setup network #
@ -215,8 +215,8 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
# eval model # eval model
if ((epoch + 1) % params.eval_step == 0): if ((epoch + 1) % params.eval_step == 0):
tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, flag='target')
src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source')
tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, loggi, flag='target')
src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, loggi, flag='source')
if tgt_acc > bestAcc: if tgt_acc > bestAcc:
bestAcc = tgt_acc bestAcc = tgt_acc
bestAccS = src_acc bestAccS = src_acc
@ -232,9 +232,9 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
# save final model # save final model
save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt") save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")
print('============ Summary ============= \n')
print('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS))
print('Accuracy of the %s dataset: %f' % (params.tgtc_dataset, bestAcc))
loggi.info('============ Summary ============= \n')
loggi.info('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS))
loggi.info('Accuracy of the %s dataset: %f' % (params.tgt_dataset, bestAcc))
return model return model

7
experiments/mnist_mnistm.py

@ -5,6 +5,7 @@ import torch
from models.model import MNISTmodel, MNISTmodel_plain from models.model import MNISTmodel, MNISTmodel_plain
from core.train import train_dann from core.train import train_dann
from utils.utils import get_data_loader, init_model, init_random_seed from utils.utils import get_data_loader, init_model, init_random_seed
from utils.altutils import setLogger
class Config(object): class Config(object):
@ -63,6 +64,10 @@ class Config(object):
params = Config() params = Config()
currentDir = os.path.dirname(os.path.realpath(__file__))
logFile = os.path.join(currentDir+'/../', 'dann-{}-{}.log'.format(params.src_dataset, params.tgt_dataset))
loggi = setLogger(logFile)
# init random seed # init random seed
init_random_seed(params.manual_seed) init_random_seed(params.manual_seed)
@ -81,4 +86,4 @@ dann = init_model(net=MNISTmodel(), restore=None)
# train dann model # train dann model
print("Training dann model") print("Training dann model")
if not (dann.restored and params.dann_restore): if not (dann.restored and params.dann_restore):
dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device)
dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, loggi)

33
utils/altutils.py

@ -0,0 +1,33 @@
import configparser
import logging
def readConfigFile(filePath):
    """
    Read an INI-style config file and return its parsed contents.

    Args:
        filePath ([str]): path to config file

    Returns:
        [Obj]: config object
    """
    parser = configparser.ConfigParser()
    # ConfigParser.read silently skips files that do not exist; the
    # returned parser is simply empty in that case.
    parser.read(filePath)
    return parser
def setLogger(logFilePath):
"""
Set logger
Args:
logFilePath ([str]): path to log file
Returns:
[obj]: logger object
"""
logHandler = [logging.FileHandler(logFilePath), logging.StreamHandler()]
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s", handlers=logHandler)
logger = logging.getLogger()
return logger
Loading…
Cancel
Save