From f20a043bf2ef496e9d6c7fdde06b721e31790d12 Mon Sep 17 00:00:00 2001 From: wogong Date: Wed, 30 Oct 2019 10:08:23 +0800 Subject: [PATCH] update SYNSIGNS-GTSRB source only expe result, finally succeed. --- README.md | 2 +- experiments/synsigns_gtsrb_src_only.py | 6 ++++-- models/model.py | 27 +++++++++++++------------- utils/utils.py | 1 + 4 files changed, 20 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index e4c65f8..def1508 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ A PyTorch implementation for paper *[Unsupervised Domain Adaptation by Backpropa | :------------------: | :------------: | :--------: | :-----------: |:-------------: |:-------------: | | Source Only | 0.5225 | 0.5490 | 0.6420 | 0. | 0. | | DANN(paper) | 0.7666 | 0.7385 | 0.7300 | 0.9109 | 0.7900 | -| This Repo Source Only| - | - | - | 0. | 0. | +| This Repo Source Only| - | - | - | 0. | 0.7650 | | This Repo | 0.8400 | 0.7339 | 0.6528 | 0.8200 | 0.6200 | ## Credit diff --git a/experiments/synsigns_gtsrb_src_only.py b/experiments/synsigns_gtsrb_src_only.py index d16d7e7..c296b57 100644 --- a/experiments/synsigns_gtsrb_src_only.py +++ b/experiments/synsigns_gtsrb_src_only.py @@ -7,13 +7,13 @@ import torch sys.path.append('../') from models.model import GTSRBmodel from core.train import train_dann -from utils.utils import get_data_loader, init_model, init_random_seed +from utils.utils import get_data_loader, init_model, init_random_seed, init_weights class Config(object): # params for path model_name = "synsigns-gtsrb" model_base = '/home/wogong/models/pytorch-dann' - note = 'src-only-40-bn' + note = 'src-only-40-bn-init' model_root = os.path.join(model_base, model_name, note + '_' + datetime.datetime.now().strftime('%m%d_%H%M%S')) os.makedirs(model_root) config = os.path.join(model_root, 'config.txt') @@ -75,6 +75,8 @@ tgt_data_loader_eval = get_data_loader(params.tgt_dataset, params.tgt_image_root # load dann model dann = init_model(net=GTSRBmodel(), restore=None) +init_weights(dann) + # train dann model print("Training dann model") diff --git a/models/model.py b/models/model.py index 818819f..9f3fc6a 100644 --- a/models/model.py +++ b/models/model.py @@ -210,42 +210,43 @@ class GTSRBmodel(nn.Module): self.restored = False self.feature = nn.Sequential( - nn.Conv2d(in_channels=3, out_channels=96, kernel_size=(5, 5)), # 36 ; 44 + nn.Conv2d(in_channels=3, out_channels=96, kernel_size=(5, 5), stride=1, padding=2), # 36 ; 44 nn.BatchNorm2d(96), - nn.ReLU(inplace=True), + nn.ReLU(), nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)), # 18 ; 22 - nn.Conv2d(in_channels=96, out_channels=144, kernel_size=(3, 3)), # 16 ; 20 + nn.Conv2d(in_channels=96, out_channels=144, kernel_size=(3, 3), stride=1, padding=1), # 16 ; 20 nn.BatchNorm2d(144), - nn.ReLU(inplace=True), + nn.ReLU(), nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)), # 8 ; 10 - nn.Conv2d(in_channels=144, out_channels=256, kernel_size=(5, 5)), # 4 ; 6 + nn.Conv2d(in_channels=144, out_channels=256, kernel_size=(5, 5), stride=1, padding=2), # 4 ; 6 nn.BatchNorm2d(256), - nn.Dropout2d(), - nn.ReLU(inplace=True), + nn.ReLU(), nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)), # 2 ; 3 ) self.classifier = nn.Sequential( - nn.Linear(256 * 2 * 2, 512), + nn.Linear(256 * 5 * 5, 512), nn.BatchNorm1d(512), - nn.ReLU(inplace=True), + nn.ReLU(), + nn.Dropout(), nn.Linear(512, 43), ) self.discriminator = nn.Sequential( - nn.Linear(256 * 2 * 2, 1024), + nn.Linear(256 * 5 * 5, 1024), nn.BatchNorm1d(1024), - nn.ReLU(inplace=True), + nn.ReLU(), nn.Linear(1024, 1024), nn.BatchNorm1d(1024), - nn.ReLU(inplace=True), + nn.ReLU(), + nn.Dropout(), nn.Linear(1024, 2), ) def forward(self, input_data, alpha = 1.0): input_data = input_data.expand(input_data.data.shape[0], 3, 40, 40) feature = self.feature(input_data) - feature = feature.view(-1, 256 * 2 * 2) + feature = feature.view(-1, 256 * 5 * 5) reverse_feature = ReverseLayerF.apply(feature, alpha) class_output = self.classifier(feature) domain_output = self.discriminator(reverse_feature) diff --git a/utils/utils.py b/utils/utils.py index d30ad91..324d2a9 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -29,6 +29,7 @@ def init_weights(layer): layer_name = layer.__class__.__name__ if layer_name.find("Conv") != -1: layer.weight.data.normal_(0.0, 0.02) + layer.bias.data.fill(0.0) elif layer_name.find("BatchNorm") != -1: layer.weight.data.normal_(1.0, 0.02) layer.bias.data.fill_(0)