From dc8c2b831ac78a7c79894949322c6df1dfaac14f Mon Sep 17 00:00:00 2001 From: wogong Date: Wed, 6 Nov 2019 22:19:38 +0800 Subject: [PATCH] update test function, using sum batch size as total num instead of len(dataloader) --- core/test.py | 55 ++++++++++++---------------------------------------- 1 file changed, 12 insertions(+), 43 deletions(-) diff --git a/core/test.py b/core/test.py index 2935762..799e745 100644 --- a/core/test.py +++ b/core/test.py @@ -1,48 +1,16 @@ import torch.utils.data import torch.nn as nn - -def test_from_save(model, saved_model, data_loader, device): - """Evaluate classifier for source domain.""" - # set eval state for Dropout and BN layers - classifier = model.load_state_dict(torch.load(saved_model)) - classifier.eval() - - # init loss and accuracy - loss = 0.0 - acc = 0.0 - - # set loss function - criterion = nn.NLLLoss() - - # evaluate network - for (images, labels) in data_loader: - images = images.to(device) - labels = labels.to(device) #labels = labels.squeeze(1) - preds = classifier(images) - - criterion(preds, labels) - - loss += criterion(preds, labels).data.item() - - pred_cls = preds.data.max(1)[1] - acc += pred_cls.eq(labels.data).cpu().sum() - - loss /= len(data_loader) - acc /= len(data_loader.dataset) - - print("Avg Loss = {}, Avg Accuracy = {:.2%}".format(loss, acc)) - - def test(model, data_loader, device, flag): """Evaluate model for dataset.""" # set eval state for Dropout and BN layers model.eval() # init loss and accuracy - loss = 0.0 - acc = 0.0 - acc_domain = 0.0 + loss_ = 0.0 + acc_ = 0.0 + acc_domain_ = 0.0 + n_total = 0 # set loss function criterion = nn.CrossEntropyLoss() @@ -59,17 +27,18 @@ def test(model, data_loader, device, flag): preds, domain = model(images, alpha=0) - loss += criterion(preds, labels).data.item() + loss_ += criterion(preds, labels).item() pred_cls = preds.data.max(1)[1] pred_domain = domain.data.max(1)[1] - acc += pred_cls.eq(labels.data).sum().item() - acc_domain += pred_domain.eq(labels_domain.data).sum().item() + acc_ += pred_cls.eq(labels.data).sum().item() + acc_domain_ += pred_domain.eq(labels_domain.data).sum().item() + n_total += size - loss /= len(data_loader) - acc /= len(data_loader.dataset) - acc_domain /= len(data_loader.dataset) + loss = loss_ / n_total + acc = acc_ / n_total + acc_domain = acc_domain_ / n_total - #print("Avg Loss = {:.6f}, Avg Accuracy = {:.2%}, Avg Domain Accuracy = {:2%}".format(loss, acc, acc_domain)) + print("Avg Loss = {:.6f}, Avg Accuracy = {:.2%}, {}/{}, Avg Domain Accuracy = {:2%}".format(loss, acc, acc_, n_total, acc_domain)) return loss, acc, acc_domain