|
|
@ -32,10 +32,6 @@ class Encoder(nn.Module): |
|
|
|
def __init__(self, in_channels=3, h=256, dropout=0.5, srcTrain=False): |
|
|
|
super(Encoder, self).__init__() |
|
|
|
|
|
|
|
# resnetModel = models.resnet50(pretrained=True) |
|
|
|
# feature_map = list(resnetModel.children()) |
|
|
|
# feature_map.pop() |
|
|
|
# self.feature_extractor = nn.Sequential(*feature_map) |
|
|
|
rnMod = ResNet50Mod() |
|
|
|
self.feature_extractor = rnMod.freezed_rn50 |
|
|
|
self.layer4 = rnMod.layer4 |
|
|
@ -47,10 +43,8 @@ class Encoder(nn.Module): |
|
|
|
def forward(self, x): |
|
|
|
x = x.expand(x.data.shape[0], 3, 224, 224) |
|
|
|
x = self.feature_extractor(x) |
|
|
|
### |
|
|
|
x = self.layer4(x) |
|
|
|
x = self.avgpool(x) |
|
|
|
### |
|
|
|
x = x.view(x.size(0), -1) |
|
|
|
return x |
|
|
|
|
|
|
@ -60,10 +54,6 @@ class Classifier(nn.Module): |
|
|
|
super(Classifier, self).__init__() |
|
|
|
self.l1 = nn.Linear(2048, n_classes) |
|
|
|
|
|
|
|
# for m in self.modules(): |
|
|
|
# if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): |
|
|
|
# nn.init.kaiming_normal_(m.weight) |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
x = self.l1(x) |
|
|
|
return x |
|
|
@ -90,20 +80,12 @@ class Discriminator(nn.Module): |
|
|
|
self.l1 = nn.Linear(2048, h) |
|
|
|
self.l2 = nn.Linear(h, h) |
|
|
|
self.l3 = nn.Linear(h, 2) |
|
|
|
### |
|
|
|
self.l4 = nn.LogSoftmax(dim=1) |
|
|
|
### |
|
|
|
self.slope = args.slope |
|
|
|
|
|
|
|
# for m in self.modules(): |
|
|
|
# if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): |
|
|
|
# nn.init.kaiming_normal_(m.weight) |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
x = F.leaky_relu(self.l1(x), self.slope) |
|
|
|
x = F.leaky_relu(self.l2(x), self.slope) |
|
|
|
x = self.l3(x) |
|
|
|
### |
|
|
|
x = self.l4(x) |
|
|
|
### |
|
|
|
return x |
|
|
|