Browse Source

update

master
fnakamura 6 years ago
parent
commit
4ef88ffd6d
  1. 11
      main.py
  2. 16
      models.py
  3. 2
      trainer.py

11
main.py

@ -5,15 +5,16 @@ import experiment
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# NN
parser.add_argument('--in_channels', type=int, default=1)
parser.add_argument('--in_channels', type=int, default=3)
parser.add_argument('--n_classes', type=int, default=10)
parser.add_argument('--trained', type=str, default='')
parser.add_argument('--slope', type=float, default=0.2)
parser.add_argument('--slope', type=float, default=0.1)
# train
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--weight_decay', type=float, default=0.)
parser.add_argument('--lr', type=float, default=2e-4)
parser.add_argument('--weight_decay', type=float, default=2.5e-5)
parser.add_argument('--epochs', type=int, default=512)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--betas', type=float, nargs='+', default=(.5, .999))
# misc
parser.add_argument('--device', type=str, default='cuda:0')
parser.add_argument('--n_workers', type=int, default=0)

16
models.py

@ -5,14 +5,14 @@ import torch.nn.functional as F
class Encoder(nn.Module):
def __init__(self, in_channels=1, h=256, dropout=0.5):
super(Encoder, self).__init__()
self.conv1 = nn.Conv2d(in_channels, 8, kernel_size=5, stride=1)
self.conv2 = nn.Conv2d(8, 16, kernel_size=5, stride=1)
self.conv3 = nn.Conv2d(16, 120, kernel_size=4, stride=1)
self.conv1 = nn.Conv2d(in_channels, 20, kernel_size=5, stride=1)
self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1)
# self.conv3 = nn.Conv2d(16, 120, kernel_size=4, stride=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.relu = nn.ReLU()
self.dropout1 = nn.Dropout2d(dropout)
self.dropout2 = nn.Dropout(dropout)
self.fc = nn.Linear(480, 500)
# self.dropout1 = nn.Dropout2d(dropout)
self.dropout = nn.Dropout(dropout)
self.fc = nn.Linear(1250, 500)
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
@ -23,9 +23,9 @@ class Encoder(nn.Module):
x = self.pool(self.relu(self.conv1(x)))
x = self.pool(self.relu(self.conv2(x)))
# x = self.dropout1(self.relu(self.conv3(x)))
x = self.relu(self.conv3(x))
# x = self.relu(self.conv3(x))
x = x.view(bs, -1)
x = self.dropout2(self.fc(x))
x = self.dropout(self.fc(x))
return x

2
trainer.py

@ -108,7 +108,7 @@ def adversarial(
D_output_target = discriminator(D_input_target)
d_loss_source = d_criterion(D_output_source, D_target_source)
d_loss_target = d_criterion(D_output_target, D_target_target)
d_loss = 0.5 * (d_loss_source + d_loss_target)
d_loss = d_loss_source + d_loss_target
d_optimizer.zero_grad()
d_loss.backward(retain_graph=True)
d_optimizer.step()

Loading…
Cancel
Save