diff --git a/README.md b/README.md index 605710b..aeee8da 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ $ python main.py --logdir outputs --trained outputs/best_model.pt --slope 0.2 For training on Office dataset using ResNet-50 ``` -$ python core/train_source_rn50.py --n_classes 31 --lr 1e-4 --src_cat amazon --tgt_cat webcam +$ python core/train_source_rn50.py --n_classes 31 --lr 1e-5 --src_cat amazon --tgt_cat webcam $ python main.py --n_classes 31 --trained outputs/garbage/best_model.pt --lr 1e-5 --d_lr 1e-4 --logdir outputs --model resnet50 --src-cat amazon --tgt-cat webcam ``` diff --git a/models/resnet50off.py b/models/resnet50off.py index a8a51bb..ffeca90 100644 --- a/models/resnet50off.py +++ b/models/resnet50off.py @@ -45,7 +45,7 @@ class Encoder(nn.Module): param.requires_grad = False def forward(self, x): - x = x.expand(x.data.shape[0], 3, 227, 227) + x = x.expand(x.data.shape[0], 3, 224, 224) x = self.feature_extractor(x) ### x = self.layer4(x) @@ -60,9 +60,9 @@ class Classifier(nn.Module): super(Classifier, self).__init__() self.l1 = nn.Linear(2048, n_classes) - for m in self.modules(): - if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): - nn.init.kaiming_normal_(m.weight) + # for m in self.modules(): + # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + # nn.init.kaiming_normal_(m.weight) def forward(self, x): x = self.l1(x) @@ -90,14 +90,20 @@ class Discriminator(nn.Module): self.l1 = nn.Linear(2048, h) self.l2 = nn.Linear(h, h) self.l3 = nn.Linear(h, 2) + ### + self.l4 = nn.LogSoftmax(dim=1) + ### self.slope = args.slope - for m in self.modules(): - if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): - nn.init.kaiming_normal_(m.weight) + # for m in self.modules(): + # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + # nn.init.kaiming_normal_(m.weight) def forward(self, x): x = F.leaky_relu(self.l1(x), self.slope) x = F.leaky_relu(self.l2(x), self.slope) x = self.l3(x) + ### + x = self.l4(x) + ### return x diff --git a/utils/altutils.py b/utils/altutils.py index ded6777..1bd05b1 100644 --- a/utils/altutils.py +++ b/utils/altutils.py @@ -40,7 +40,7 @@ def get_office(dataset_root, batch_size, category): """Get Office datasets loader.""" # image pre-processing pre_process = transforms.Compose([ - transforms.Resize(227), + transforms.Resize(224), transforms.ToTensor(), transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) ])