Browse Source

rename test function

master
wogong 6 years ago
parent
commit
ab4a2bf5dc
  1. 6
      core/dann.py
  2. 8
      core/pretrain.py
  3. 2
      core/test.py
  4. 9
      experiments/mnist_mnistm.py
  5. 10
      experiments/office.py
  6. 26
      experiments/office31_10.py
  7. 26
      experiments/svhn_mnist.py

6
core/dann.py

@ -6,7 +6,7 @@ import torch
import torch.nn as nn
import torch.optim as optim
from core.test import eval
from core.test import test
from utils.utils import save_model
import torch.backends.cudnn as cudnn
@ -103,9 +103,9 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
# eval model
if ((epoch + 1) % params.eval_step == 0):
print("eval on target domain")
eval(model, tgt_data_loader, device, flag='target')
test(model, tgt_data_loader, device, flag='target')
print("eval on source domain")
eval(model, src_data_loader, device, flag='source')
test(model, src_data_loader, device, flag='source')
# save model parameters
if ((epoch + 1) % params.save_step == 0):

8
core/pretrain.py

@ -4,7 +4,7 @@ import torch.nn as nn
import torch.optim as optim
from utils.utils import save_model
from core.test import eval
from core.test import test
def train_src(model, params, data_loader, device):
@ -34,8 +34,8 @@ def train_src(model, params, data_loader, device):
optimizer.zero_grad()
# compute loss for critic
preds = model(images)
loss = loss_class(preds, labels)
preds_class, _ = model(images)
loss = loss_class(preds_class, labels)
# optimize source classifier
loss.backward()
@ -48,7 +48,7 @@ def train_src(model, params, data_loader, device):
# eval model on test set
if ((epoch + 1) % params.eval_step_src == 0):
eval(model, data_loader, flag='source')
test(model, data_loader, flag='source')
model.train()
# save model parameters

2
core/test.py

@ -34,7 +34,7 @@ def test_from_save(model, saved_model, data_loader, device):
print("Avg Loss = {}, Avg Accuracy = {:.2%}".format(loss, acc))
def eval(model, data_loader, device, flag):
def test(model, data_loader, device, flag):
"""Evaluate model for dataset."""
# set eval state for Dropout and BN layers
model.eval()

9
experiments/mnist_mnistm.py

@ -52,6 +52,7 @@ class Config(object):
# params for optimizing models
lr = 2e-4
params = Config()
# init random seed
@ -69,10 +70,4 @@ dann = init_model(net=MNISTmodel(), restore=None)
# train dann model
print("Training dann model")
if not (dann.restored and params.dann_restore):
dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval)
# eval dann model
print("Evaluating dann for source domain {}".format(params.src_dataset))
eval(dann, src_data_loader_eval)
print("Evaluating dann for target domain {}".format(params.tgt_dataset))
eval(dann, tgt_data_loader_eval)
dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval)

10
experiments/office.py

@ -5,7 +5,7 @@ import torch
sys.path.append('../')
from core.dann import train_dann
from core.test import eval
from core.test import test
from models.model import AlexModel
from utils.utils import get_data_loader, init_model, init_random_seed
@ -41,7 +41,7 @@ class Config(object):
num_epochs = 1000
log_step = 10 # iters
save_step = 500
eval_step = 5 # epochs
eval_step = 10 # epochs
manual_seed = 8888
alpha = 0
@ -71,10 +71,4 @@ print("Start training dann model.")
if not (dann.restored and params.dann_restore):
dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader, device)
# eval dann model
print("Evaluating dann for source domain")
eval(dann, src_data_loader)
print("Evaluating dann for target domain")
eval(dann, tgt_data_loader)
print('done')

26
experiments/office31_10.py

@ -3,7 +3,7 @@ import sys
sys.path.append('../')
from core.dann import train_dann
from core.test import eval
from core.test import test
from models.model import AlexModel
from utils.utils import get_data_loader, init_model, init_random_seed
@ -12,8 +12,7 @@ from utils.utils import get_data_loader, init_model, init_random_seed
class Config(object):
# params for path
dataset_root = os.path.expanduser(os.path.join('~', 'Datasets'))
model_root = os.path.expanduser(
os.path.join('~', 'Models', 'pytorch-DANN'))
model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN'))
# params for datasets and data loader
batch_size = 32
@ -21,14 +20,12 @@ class Config(object):
# params for source dataset
src_dataset = "amazon31"
src_model_trained = True
src_classifier_restore = os.path.join(
model_root, src_dataset + '-source-classifier-final.pt')
src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt')
# params for target dataset
tgt_dataset = "webcam10"
tgt_model_trained = True
dann_restore = os.path.join(
model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')
dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')
# params for pretrain
num_epochs_src = 100
@ -57,10 +54,8 @@ params = Config()
init_random_seed(params.manual_seed)
# load dataset
src_data_loader = get_data_loader(
params.src_dataset, params.dataset_root, params.batch_size)
tgt_data_loader = get_data_loader(
params.tgt_dataset, params.dataset_root, params.batch_size)
src_data_loader = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size)
tgt_data_loader = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size)
# load dann model
dann = init_model(net=AlexModel(), restore=None)
@ -69,13 +64,6 @@ dann = init_model(net=AlexModel(), restore=None)
print("Start training dann model.")
if not (dann.restored and params.dann_restore):
dann = train_dann(dann, params, src_data_loader,
tgt_data_loader, tgt_data_loader)
# eval dann model
print("Evaluating dann for source domain")
eval(dann, src_data_loader)
print("Evaluating dann for target domain")
eval(dann, tgt_data_loader)
dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader)
print('done')

26
experiments/svhn_mnist.py

@ -1,16 +1,21 @@
import os
import sys
import torch
sys.path.append('../')
from models.model import SVHNmodel
from core.dann import train_dann
from core.pretrain import train_src
from core.test import test
from utils.utils import get_data_loader, init_model, init_random_seed
class Config(object):
# params for path
dataset_root = os.path.expanduser(os.path.join('~', 'Datasets'))
model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN'))
model_name = "svhn-mnist"
model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN', model_name))
# params for datasets and data loader
batch_size = 128
@ -32,11 +37,12 @@ class Config(object):
eval_step_src = 20
# params for training dann
gpu_id = '0'
## for digit
num_epochs = 200
log_step = 20
save_step = 50
log_step = 50
save_step = 100
eval_step = 5
## for office
@ -45,17 +51,21 @@ class Config(object):
# save_step = 500
# eval_step = 5 # epochs
manual_seed = 8888
manual_seed = None
alpha = 0
# params for optimizing models
lr = 2e-4
params = Config()
# init random seed
init_random_seed(params.manual_seed)
# init device
device = torch.device("cuda:" + params.gpu_id if torch.cuda.is_available() else "cpu")
# load dataset
src_data_loader = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=True)
src_data_loader_eval = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=False)
@ -68,10 +78,4 @@ dann = init_model(net=SVHNmodel(), restore=None)
# train dann model
print("Training dann model")
if not (dann.restored and params.dann_restore):
dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval)
# eval dann model
print("Evaluating dann for source domain {}".format(params.src_dataset))
eval(dann, src_data_loader_eval)
print("Evaluating dann for target domain {}".format(params.tgt_dataset))
eval(dann, tgt_data_loader_eval)
dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device)

Loading…
Cancel
Save