Browse Source

Performance improvements on ResNet

master
Fazil Altinel 4 years ago
parent
commit
276be62452
  1. 2
      README.md
  2. 20
      models/resnet50off.py
  3. 2
      utils/altutils.py

2
README.md

@ -19,7 +19,7 @@ $ python main.py --logdir outputs --trained outputs/best_model.pt --slope 0.2
For training on Office dataset using ResNet-50
```
$ python core/train_source_rn50.py --n_classes 31 --lr 1e-4 --src_cat amazon --tgt_cat webcam
$ python core/train_source_rn50.py --n_classes 31 --lr 1e-5 --src_cat amazon --tgt_cat webcam
$ python main.py --n_classes 31 --trained outputs/garbage/best_model.pt --lr 1e-5 --d_lr 1e-4 --logdir outputs --model resnet50 --src-cat amazon --tgt-cat webcam
```

20
models/resnet50off.py

@ -45,7 +45,7 @@ class Encoder(nn.Module):
param.requires_grad = False
def forward(self, x):
x = x.expand(x.data.shape[0], 3, 227, 227)
x = x.expand(x.data.shape[0], 3, 224, 224)
x = self.feature_extractor(x)
###
x = self.layer4(x)
@ -60,9 +60,9 @@ class Classifier(nn.Module):
super(Classifier, self).__init__()
self.l1 = nn.Linear(2048, n_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight)
# for m in self.modules():
# if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
# nn.init.kaiming_normal_(m.weight)
def forward(self, x):
x = self.l1(x)
@ -90,14 +90,20 @@ class Discriminator(nn.Module):
self.l1 = nn.Linear(2048, h)
self.l2 = nn.Linear(h, h)
self.l3 = nn.Linear(h, 2)
###
self.l4 = nn.LogSoftmax(dim=1)
###
self.slope = args.slope
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight)
# for m in self.modules():
# if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
# nn.init.kaiming_normal_(m.weight)
def forward(self, x):
x = F.leaky_relu(self.l1(x), self.slope)
x = F.leaky_relu(self.l2(x), self.slope)
x = self.l3(x)
###
x = self.l4(x)
###
return x

2
utils/altutils.py

@ -40,7 +40,7 @@ def get_office(dataset_root, batch_size, category):
"""Get Office datasets loader."""
# image pre-processing
pre_process = transforms.Compose([
transforms.Resize(227),
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])

Loading…
Cancel
Save