|
|
@ -1,48 +1,16 @@ |
|
|
|
import torch.utils.data |
|
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
|
|
|
def test_from_save(model, saved_model, data_loader, device): |
|
|
|
"""Evaluate classifier for source domain.""" |
|
|
|
# set eval state for Dropout and BN layers |
|
|
|
classifier = model.load_state_dict(torch.load(saved_model)) |
|
|
|
classifier.eval() |
|
|
|
|
|
|
|
# init loss and accuracy |
|
|
|
loss = 0.0 |
|
|
|
acc = 0.0 |
|
|
|
|
|
|
|
# set loss function |
|
|
|
criterion = nn.NLLLoss() |
|
|
|
|
|
|
|
# evaluate network |
|
|
|
for (images, labels) in data_loader: |
|
|
|
images = images.to(device) |
|
|
|
labels = labels.to(device) #labels = labels.squeeze(1) |
|
|
|
preds = classifier(images) |
|
|
|
|
|
|
|
criterion(preds, labels) |
|
|
|
|
|
|
|
loss += criterion(preds, labels).data.item() |
|
|
|
|
|
|
|
pred_cls = preds.data.max(1)[1] |
|
|
|
acc += pred_cls.eq(labels.data).cpu().sum() |
|
|
|
|
|
|
|
loss /= len(data_loader) |
|
|
|
acc /= len(data_loader.dataset) |
|
|
|
|
|
|
|
print("Avg Loss = {}, Avg Accuracy = {:.2%}".format(loss, acc)) |
|
|
|
|
|
|
|
|
|
|
|
def test(model, data_loader, device, flag): |
|
|
|
"""Evaluate model for dataset.""" |
|
|
|
# set eval state for Dropout and BN layers |
|
|
|
model.eval() |
|
|
|
|
|
|
|
# init loss and accuracy |
|
|
|
loss = 0.0 |
|
|
|
acc = 0.0 |
|
|
|
acc_domain = 0.0 |
|
|
|
loss_ = 0.0 |
|
|
|
acc_ = 0.0 |
|
|
|
acc_domain_ = 0.0 |
|
|
|
n_total = 0 |
|
|
|
|
|
|
|
# set loss function |
|
|
|
criterion = nn.CrossEntropyLoss() |
|
|
@ -59,17 +27,18 @@ def test(model, data_loader, device, flag): |
|
|
|
|
|
|
|
preds, domain = model(images, alpha=0) |
|
|
|
|
|
|
|
loss += criterion(preds, labels).data.item() |
|
|
|
loss_ += criterion(preds, labels).item() |
|
|
|
|
|
|
|
pred_cls = preds.data.max(1)[1] |
|
|
|
pred_domain = domain.data.max(1)[1] |
|
|
|
acc += pred_cls.eq(labels.data).sum().item() |
|
|
|
acc_domain += pred_domain.eq(labels_domain.data).sum().item() |
|
|
|
acc_ += pred_cls.eq(labels.data).sum().item() |
|
|
|
acc_domain_ += pred_domain.eq(labels_domain.data).sum().item() |
|
|
|
n_total += size |
|
|
|
|
|
|
|
loss /= len(data_loader) |
|
|
|
acc /= len(data_loader.dataset) |
|
|
|
acc_domain /= len(data_loader.dataset) |
|
|
|
loss = loss_ / n_total |
|
|
|
acc = acc_ / n_total |
|
|
|
acc_domain = acc_domain_ / n_total |
|
|
|
|
|
|
|
#print("Avg Loss = {:.6f}, Avg Accuracy = {:.2%}, Avg Domain Accuracy = {:2%}".format(loss, acc, acc_domain)) |
|
|
|
print("Avg Loss = {:.6f}, Avg Accuracy = {:.2%}, {}/{}, Avg Domain Accuracy = {:2%}".format(loss, acc, acc_, n_total, acc_domain)) |
|
|
|
|
|
|
|
return loss, acc, acc_domain |
|
|
|