
update SYNSIGNS-GTSRB source-only experiment result, finally succeeded.

master · wogong, 5 years ago · parent commit f20a043bf2
1. README.md (2 changed lines)
2. experiments/synsigns_gtsrb_src_only.py (6 changed lines)
3. models/model.py (27 changed lines)
4. utils/utils.py (1 changed line)

README.md (2 changed lines)

@@ -32,7 +32,7 @@ A PyTorch implementation for paper *[Unsupervised Domain Adaptation by Backpropa
 | :------------------: | :------------: | :--------: | :-----------: |:-------------: |:-------------: |
 | Source Only | 0.5225 | 0.5490 | 0.6420 | 0. | 0. |
 | DANN(paper) | 0.7666 | 0.7385 | 0.7300 | 0.9109 | 0.7900 |
-| This Repo Source Only| - | - | - | 0. | 0. |
+| This Repo Source Only| - | - | - | 0. | 0.7650 |
 | This Repo | 0.8400 | 0.7339 | 0.6528 | 0.8200 | 0.6200 |
 ## Credit

experiments/synsigns_gtsrb_src_only.py (6 changed lines)

@@ -7,13 +7,13 @@ import torch
 sys.path.append('../')
 from models.model import GTSRBmodel
 from core.train import train_dann
-from utils.utils import get_data_loader, init_model, init_random_seed
+from utils.utils import get_data_loader, init_model, init_random_seed, init_weights
 class Config(object):
     # params for path
     model_name = "synsigns-gtsrb"
     model_base = '/home/wogong/models/pytorch-dann'
-    note = 'src-only-40-bn'
+    note = 'src-only-40-bn-init'
     model_root = os.path.join(model_base, model_name, note + '_' + datetime.datetime.now().strftime('%m%d_%H%M%S'))
     os.makedirs(model_root)
     config = os.path.join(model_root, 'config.txt')
@@ -75,6 +75,8 @@ tgt_data_loader_eval = get_data_loader(params.tgt_dataset, params.tgt_image_root
 # load dann model
 dann = init_model(net=GTSRBmodel(), restore=None)
+init_weights(dann)
+
 # train dann model
 print("Training dann model")

models/model.py (27 changed lines)

@@ -210,42 +210,43 @@ class GTSRBmodel(nn.Module):
         self.restored = False
         self.feature = nn.Sequential(
-            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=(5, 5)), # 36 ; 44
+            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=(5, 5), stride=1, padding=2), # 36 ; 44
             nn.BatchNorm2d(96),
-            nn.ReLU(inplace=True),
+            nn.ReLU(),
             nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)), # 18 ; 22
-            nn.Conv2d(in_channels=96, out_channels=144, kernel_size=(3, 3)), # 16 ; 20
+            nn.Conv2d(in_channels=96, out_channels=144, kernel_size=(3, 3), stride=1, padding=1), # 16 ; 20
             nn.BatchNorm2d(144),
-            nn.ReLU(inplace=True),
+            nn.ReLU(),
             nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)), # 8 ; 10
-            nn.Conv2d(in_channels=144, out_channels=256, kernel_size=(5, 5)), # 4 ; 6
+            nn.Conv2d(in_channels=144, out_channels=256, kernel_size=(5, 5), stride=1, padding=2), # 4 ; 6
             nn.BatchNorm2d(256),
             nn.Dropout2d(),
-            nn.ReLU(inplace=True),
+            nn.ReLU(),
             nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)), # 2 ; 3
         )
         self.classifier = nn.Sequential(
-            nn.Linear(256 * 2 * 2, 512),
+            nn.Linear(256 * 5 * 5, 512),
             nn.BatchNorm1d(512),
-            nn.ReLU(inplace=True),
+            nn.ReLU(),
             nn.Dropout(),
             nn.Linear(512, 43),
         )
         self.discriminator = nn.Sequential(
-            nn.Linear(256 * 2 * 2, 1024),
+            nn.Linear(256 * 5 * 5, 1024),
             nn.BatchNorm1d(1024),
-            nn.ReLU(inplace=True),
+            nn.ReLU(),
             nn.Linear(1024, 1024),
             nn.BatchNorm1d(1024),
-            nn.ReLU(inplace=True),
+            nn.ReLU(),
             nn.Dropout(),
             nn.Linear(1024, 2),
         )

     def forward(self, input_data, alpha=1.0):
         input_data = input_data.expand(input_data.data.shape[0], 3, 40, 40)
         feature = self.feature(input_data)
-        feature = feature.view(-1, 256 * 2 * 2)
+        feature = feature.view(-1, 256 * 5 * 5)
         reverse_feature = ReverseLayerF.apply(feature, alpha)
         class_output = self.classifier(feature)
         domain_output = self.discriminator(reverse_feature)
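The switch from `256 * 2 * 2` to `256 * 5 * 5` follows directly from the added padding: with 40x40 inputs, each padded convolution preserves the spatial size and each 2x2 max-pool halves it (40 -> 20 -> 10 -> 5), so the flattened feature is 256 * 5 * 5 = 6400. A minimal sketch to confirm the shapes (BatchNorm, ReLU, and Dropout are omitted because they do not change spatial dimensions):

```python
import torch
import torch.nn as nn

# Convolution/pooling trunk only, with the padding added in this commit.
feature = nn.Sequential(
    nn.Conv2d(3, 96, kernel_size=(5, 5), stride=1, padding=2),     # 40 -> 40
    nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),               # 40 -> 20
    nn.Conv2d(96, 144, kernel_size=(3, 3), stride=1, padding=1),   # 20 -> 20
    nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),               # 20 -> 10
    nn.Conv2d(144, 256, kernel_size=(5, 5), stride=1, padding=2),  # 10 -> 10
    nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),               # 10 -> 5
)

x = torch.zeros(1, 3, 40, 40)
print(feature(x).shape)                        # torch.Size([1, 256, 5, 5])
print(feature(x).view(-1, 256 * 5 * 5).shape)  # torch.Size([1, 6400]), matching the new Linear input size
```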

utils/utils.py (1 changed line)

@@ -29,6 +29,7 @@ def init_weights(layer):
     layer_name = layer.__class__.__name__
     if layer_name.find("Conv") != -1:
         layer.weight.data.normal_(0.0, 0.02)
+        layer.bias.data.fill_(0.0)
     elif layer_name.find("BatchNorm") != -1:
         layer.weight.data.normal_(1.0, 0.02)
         layer.bias.data.fill_(0)
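For context, the surrounding function is a DCGAN-style per-layer initializer: convolution weights drawn from N(0, 0.02), BatchNorm scale from N(1, 0.02), and biases zeroed. Below is a hedged sketch of that pattern; the bias-presence guard is an assumption (conv layers built with bias=False have no bias tensor), and the `nn.Module.apply` call shows the usual way such a per-layer initializer is run over a whole network, not necessarily how this repository invokes it.

```python
import torch.nn as nn

def init_weights(layer):
    # DCGAN-style initialization, mirroring the function this commit extends.
    layer_name = layer.__class__.__name__
    if layer_name.find("Conv") != -1:
        layer.weight.data.normal_(0.0, 0.02)
        if layer.bias is not None:        # assumption: guard against bias=False conv layers
            layer.bias.data.fill_(0.0)
    elif layer_name.find("BatchNorm") != -1:
        layer.weight.data.normal_(1.0, 0.02)
        layer.bias.data.fill_(0)

# Example usage on a small stand-in network (not the repository's GTSRBmodel):
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
net.apply(init_weights)  # visits every sub-module, so each Conv/BatchNorm layer gets re-initialized
```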
