|
@ -45,7 +45,7 @@ class Encoder(nn.Module): |
|
|
param.requires_grad = False |
|
|
param.requires_grad = False |
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
def forward(self, x): |
|
|
x = x.expand(x.data.shape[0], 3, 227, 227) |
|
|
|
|
|
|
|
|
x = x.expand(x.data.shape[0], 3, 224, 224) |
|
|
x = self.feature_extractor(x) |
|
|
x = self.feature_extractor(x) |
|
|
### |
|
|
### |
|
|
x = self.layer4(x) |
|
|
x = self.layer4(x) |
|
@ -60,9 +60,9 @@ class Classifier(nn.Module): |
|
|
super(Classifier, self).__init__() |
|
|
super(Classifier, self).__init__() |
|
|
self.l1 = nn.Linear(2048, n_classes) |
|
|
self.l1 = nn.Linear(2048, n_classes) |
|
|
|
|
|
|
|
|
for m in self.modules(): |
|
|
|
|
|
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): |
|
|
|
|
|
nn.init.kaiming_normal_(m.weight) |
|
|
|
|
|
|
|
|
# for m in self.modules(): |
|
|
|
|
|
# if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): |
|
|
|
|
|
# nn.init.kaiming_normal_(m.weight) |
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
def forward(self, x): |
|
|
x = self.l1(x) |
|
|
x = self.l1(x) |
|
@ -90,14 +90,20 @@ class Discriminator(nn.Module): |
|
|
self.l1 = nn.Linear(2048, h) |
|
|
self.l1 = nn.Linear(2048, h) |
|
|
self.l2 = nn.Linear(h, h) |
|
|
self.l2 = nn.Linear(h, h) |
|
|
self.l3 = nn.Linear(h, 2) |
|
|
self.l3 = nn.Linear(h, 2) |
|
|
|
|
|
### |
|
|
|
|
|
self.l4 = nn.LogSoftmax(dim=1) |
|
|
|
|
|
### |
|
|
self.slope = args.slope |
|
|
self.slope = args.slope |
|
|
|
|
|
|
|
|
for m in self.modules(): |
|
|
|
|
|
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): |
|
|
|
|
|
nn.init.kaiming_normal_(m.weight) |
|
|
|
|
|
|
|
|
# for m in self.modules(): |
|
|
|
|
|
# if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): |
|
|
|
|
|
# nn.init.kaiming_normal_(m.weight) |
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
def forward(self, x): |
|
|
x = F.leaky_relu(self.l1(x), self.slope) |
|
|
x = F.leaky_relu(self.l1(x), self.slope) |
|
|
x = F.leaky_relu(self.l2(x), self.slope) |
|
|
x = F.leaky_relu(self.l2(x), self.slope) |
|
|
x = self.l3(x) |
|
|
x = self.l3(x) |
|
|
|
|
|
### |
|
|
|
|
|
x = self.l4(x) |
|
|
|
|
|
### |
|
|
return x |
|
|
return x |
|
|