|
|
@ -5,14 +5,14 @@ import torch.nn.functional as F |
|
|
|
class Encoder(nn.Module): |
|
|
|
def __init__(self, in_channels=1, h=256, dropout=0.5): |
|
|
|
super(Encoder, self).__init__() |
|
|
|
self.conv1 = nn.Conv2d(in_channels, 8, kernel_size=5, stride=1) |
|
|
|
self.conv2 = nn.Conv2d(8, 16, kernel_size=5, stride=1) |
|
|
|
self.conv3 = nn.Conv2d(16, 120, kernel_size=4, stride=1) |
|
|
|
self.conv1 = nn.Conv2d(in_channels, 20, kernel_size=5, stride=1) |
|
|
|
self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1) |
|
|
|
# self.conv3 = nn.Conv2d(16, 120, kernel_size=4, stride=1) |
|
|
|
self.pool = nn.MaxPool2d(kernel_size=2, stride=2) |
|
|
|
self.relu = nn.ReLU() |
|
|
|
self.dropout1 = nn.Dropout2d(dropout) |
|
|
|
self.dropout2 = nn.Dropout(dropout) |
|
|
|
self.fc = nn.Linear(480, 500) |
|
|
|
# self.dropout1 = nn.Dropout2d(dropout) |
|
|
|
self.dropout = nn.Dropout(dropout) |
|
|
|
self.fc = nn.Linear(1250, 500) |
|
|
|
|
|
|
|
for m in self.modules(): |
|
|
|
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): |
|
|
@ -23,9 +23,9 @@ class Encoder(nn.Module): |
|
|
|
x = self.pool(self.relu(self.conv1(x))) |
|
|
|
x = self.pool(self.relu(self.conv2(x))) |
|
|
|
# x = self.dropout1(self.relu(self.conv3(x))) |
|
|
|
x = self.relu(self.conv3(x)) |
|
|
|
# x = self.relu(self.conv3(x)) |
|
|
|
x = x.view(bs, -1) |
|
|
|
x = self.dropout2(self.fc(x)) |
|
|
|
x = self.dropout(self.fc(x)) |
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|