Browse Source

update test function, using sum batch size as total num instead of len(dataloader)

master
wogong 5 years ago
parent
commit
dc8c2b831a
  1. 55
      core/test.py

55
core/test.py

@ -1,48 +1,16 @@
import torch.utils.data
import torch.nn as nn
def test_from_save(model, saved_model, data_loader, device):
"""Evaluate classifier for source domain."""
# set eval state for Dropout and BN layers
classifier = model.load_state_dict(torch.load(saved_model))
classifier.eval()
# init loss and accuracy
loss = 0.0
acc = 0.0
# set loss function
criterion = nn.NLLLoss()
# evaluate network
for (images, labels) in data_loader:
images = images.to(device)
labels = labels.to(device) #labels = labels.squeeze(1)
preds = classifier(images)
criterion(preds, labels)
loss += criterion(preds, labels).data.item()
pred_cls = preds.data.max(1)[1]
acc += pred_cls.eq(labels.data).cpu().sum()
loss /= len(data_loader)
acc /= len(data_loader.dataset)
print("Avg Loss = {}, Avg Accuracy = {:.2%}".format(loss, acc))
def test(model, data_loader, device, flag):
"""Evaluate model for dataset."""
# set eval state for Dropout and BN layers
model.eval()
# init loss and accuracy
loss = 0.0
acc = 0.0
acc_domain = 0.0
loss_ = 0.0
acc_ = 0.0
acc_domain_ = 0.0
n_total = 0
# set loss function
criterion = nn.CrossEntropyLoss()
@ -59,17 +27,18 @@ def test(model, data_loader, device, flag):
preds, domain = model(images, alpha=0)
loss += criterion(preds, labels).data.item()
loss_ += criterion(preds, labels).item()
pred_cls = preds.data.max(1)[1]
pred_domain = domain.data.max(1)[1]
acc += pred_cls.eq(labels.data).sum().item()
acc_domain += pred_domain.eq(labels_domain.data).sum().item()
acc_ += pred_cls.eq(labels.data).sum().item()
acc_domain_ += pred_domain.eq(labels_domain.data).sum().item()
n_total += size
loss /= len(data_loader)
acc /= len(data_loader.dataset)
acc_domain /= len(data_loader.dataset)
loss = loss_ / n_total
acc = acc_ / n_total
acc_domain = acc_domain_ / n_total
#print("Avg Loss = {:.6f}, Avg Accuracy = {:.2%}, Avg Domain Accuracy = {:2%}".format(loss, acc, acc_domain))
print("Avg Loss = {:.6f}, Avg Accuracy = {:.2%}, {}/{}, Avg Domain Accuracy = {:2%}".format(loss, acc, acc_, n_total, acc_domain))
return loss, acc, acc_domain

Loading…
Cancel
Save