Browse Source

Looging improvements

master
Fazil Altinel 4 years ago
parent
commit
95a709f395
  1. 4
      core/test.py
  2. 14
      core/train.py
  3. 7
      experiments/mnist_mnistm.py
  4. 33
      utils/altutils.py

4
core/test.py

@ -1,7 +1,7 @@
import torch.utils.data
import torch.nn as nn
def test(model, data_loader, device, flag):
def test(model, data_loader, device, loggi, flag):
"""Evaluate model for dataset."""
# set eval state for Dropout and BN layers
model.eval()
@ -39,6 +39,6 @@ def test(model, data_loader, device, flag):
acc = acc_ / n_total
acc_domain = acc_domain_ / n_total
print("Avg Loss = {:.6f}, Avg Accuracy = {:.2%}, {}/{}, Avg Domain Accuracy = {:2%}".format(loss, acc, acc_, n_total, acc_domain))
loggi.info("Avg Loss = {:.6f}, Avg Accuracy = {:.2%}, {}/{}, Avg Domain Accuracy = {:2%}".format(loss, acc, acc_, n_total, acc_domain))
return loss, acc, acc_domain

14
core/train.py

@ -1,7 +1,7 @@
"""Train dann."""
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
@ -113,7 +113,7 @@ def train_src(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_e
return model
def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, logger=None):
def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, loggi, logger=None):
"""Train dann."""
####################
# 1. setup network #
@ -215,8 +215,8 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
# eval model
if ((epoch + 1) % params.eval_step == 0):
tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, flag='target')
src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, flag='source')
tgt_test_loss, tgt_acc, tgt_acc_domain = test(model, tgt_data_loader_eval, device, loggi, flag='target')
src_test_loss, src_acc, src_acc_domain = test(model, src_data_loader, device, loggi, flag='source')
if tgt_acc > bestAcc:
bestAcc = tgt_acc
bestAccS = src_acc
@ -232,9 +232,9 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
# save final model
save_model(model, params.model_root, params.src_dataset + '-' + params.tgt_dataset + "-dann-final.pt")
print('============ Summary ============= \n')
print('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS))
print('Accuracy of the %s dataset: %f' % (params.tgtc_dataset, bestAcc))
loggi.info('============ Summary ============= \n')
loggi.info('Accuracy of the %s dataset: %f' % (params.src_dataset, bestAccS))
loggi.info('Accuracy of the %s dataset: %f' % (params.tgtc_dataset, bestAcc))
return model

7
experiments/mnist_mnistm.py

@ -5,6 +5,7 @@ import torch
from models.model import MNISTmodel, MNISTmodel_plain
from core.train import train_dann
from utils.utils import get_data_loader, init_model, init_random_seed
from utils.altutils import setLogger
class Config(object):
@ -63,6 +64,10 @@ class Config(object):
params = Config()
currentDir = os.path.dirname(os.path.realpath(__file__))
logFile = os.path.join(currentDir+'/../', 'dann-{}-{}.log'.format(params.src_dataset, params.tgt_dataset))
loggi = setLogger(logFile)
# init random seed
init_random_seed(params.manual_seed)
@ -81,4 +86,4 @@ dann = init_model(net=MNISTmodel(), restore=None)
# train dann model
print("Training dann model")
if not (dann.restored and params.dann_restore):
dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device)
dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device, loggi)

33
utils/altutils.py

@ -0,0 +1,33 @@
import configparser
import logging
def readConfigFile(filePath):
"""
Read config file
Args:
filePath ([str]): path to config file
Returns:
[Obj]: config object
"""
config = configparser.ConfigParser()
config.read(filePath)
return config
def setLogger(logFilePath):
"""
Set logger
Args:
logFilePath ([str]): path to log file
Returns:
[obj]: logger object
"""
logHandler = [logging.FileHandler(logFilePath), logging.StreamHandler()]
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s", handlers=logHandler)
logger = logging.getLogger()
return logger
Loading…
Cancel
Save