Browse Source

minor update

master
wogong 5 years ago
parent
commit
702bea05d5
  1. 4
      README.md
  2. 13
      experiments/mnist_mnistm.py
  3. 3
      experiments/svhn_mnist.py

4
README.md

@ -14,12 +14,12 @@ A PyTorch implementation for paper *[Unsupervised Domain Adaptation by Backpropa
## Note ## Note
- `Config()` 为针对特定任务的配置参数
- `Config()` 为针对特定任务的配置参数
- `MNISTmodel()` 完全按照论文中的结构,但是 feature 部分添加了 `Dropout2d()`,实验发现是否添加 `Dropout2d()` 对于最后的性能影响很大。最后实验重现结果高于论文,因为使用了额外的技巧,这里还有值得探究的地方。 - `MNISTmodel()` 完全按照论文中的结构,但是 feature 部分添加了 `Dropout2d()`,实验发现是否添加 `Dropout2d()` 对于最后的性能影响很大。最后实验重现结果高于论文,因为使用了额外的技巧,这里还有值得探究的地方。
- `SVHNmodel()` 无法理解论文中提出的结构,为自定义结构。最后实验重现结果完美。 - `SVHNmodel()` 无法理解论文中提出的结构,为自定义结构。最后实验重现结果完美。
- MNIST-MNISTM: `python mnist_mnistm.py` - MNIST-MNISTM: `python mnist_mnistm.py`
- SVHN-MNIST: `python svhn_mnist.py` - SVHN-MNIST: `python svhn_mnist.py`
- Amazon-Webcam: `python office.py` 没有复现成功
- Amazon-Webcam: `python office.py` 由于预训练网络的问题,无法复现
## Result ## Result

13
experiments/mnist_mnistm.py

@ -1,8 +1,9 @@
import os import os
import sys import sys
import torch
sys.path.append('../') sys.path.append('../')
from models.model import MNISTmodel
from models.model import MNISTmodel, MNISTmodel_plain
from core.dann import train_dann from core.dann import train_dann
from utils.utils import get_data_loader, init_model, init_random_seed from utils.utils import get_data_loader, init_model, init_random_seed
@ -13,7 +14,7 @@ class Config(object):
model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN')) model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN'))
# params for datasets and data loader # params for datasets and data loader
batch_size = 128
batch_size = 64
# params for source dataset # params for source dataset
src_dataset = "mnist" src_dataset = "mnist"
@ -33,6 +34,7 @@ class Config(object):
eval_step_src = 20 eval_step_src = 20
# params for training dann # params for training dann
gpu_id = '0'
## for digit ## for digit
num_epochs = 100 num_epochs = 100
@ -58,6 +60,9 @@ params = Config()
# init random seed # init random seed
init_random_seed(params.manual_seed) init_random_seed(params.manual_seed)
# init device
device = torch.device("cuda:" + params.gpu_id if torch.cuda.is_available() else "cpu")
# load dataset # load dataset
src_data_loader = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=True) src_data_loader = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=True)
src_data_loader_eval = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=False) src_data_loader_eval = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=False)
@ -65,9 +70,9 @@ tgt_data_loader = get_data_loader(params.tgt_dataset, params.dataset_root, param
tgt_data_loader_eval = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size, train=False) tgt_data_loader_eval = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size, train=False)
# load dann model # load dann model
dann = init_model(net=MNISTmodel(), restore=None)
dann = init_model(net=MNISTmodel_plain(), restore=None)
# train dann model # train dann model
print("Training dann model") print("Training dann model")
if not (dann.restored and params.dann_restore): if not (dann.restored and params.dann_restore):
dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval)
dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, device)

3
experiments/svhn_mnist.py

@ -6,8 +6,6 @@ import torch
sys.path.append('../') sys.path.append('../')
from models.model import SVHNmodel from models.model import SVHNmodel
from core.dann import train_dann from core.dann import train_dann
from core.pretrain import train_src
from core.test import test
from utils.utils import get_data_loader, init_model, init_random_seed from utils.utils import get_data_loader, init_model, init_random_seed
@ -57,7 +55,6 @@ class Config(object):
# params for optimizing models # params for optimizing models
lr = 2e-4 lr = 2e-4
params = Config() params = Config()
# init random seed # init random seed

Loading…
Cancel
Save