From a77b1b1f3eb7b6cf57aafc1a01e0bf27660178e0 Mon Sep 17 00:00:00 2001 From: fnakamura Date: Wed, 20 Feb 2019 19:44:58 +0900 Subject: [PATCH] refactor --- experiment.py | 23 ++++-------------- train_source.py | 62 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 18 deletions(-) create mode 100644 train_source.py diff --git a/experiment.py b/experiment.py index 6e532d8..785100b 100644 --- a/experiment.py +++ b/experiment.py @@ -26,8 +26,6 @@ def run(args): ]) source_dataset_train = SVHN( './input', 'train', transform=source_transform, download=True) - source_dataset_test = SVHN( - './input', 'test', transform=source_transform, download=True) target_dataset_train = MNIST( './input', 'train', transform=target_transform, download=True) target_dataset_test = MNIST( @@ -36,9 +34,6 @@ def run(args): source_dataset_train, args.batch_size, shuffle=True, drop_last=True, num_workers=args.n_workers) - source_test_loader = DataLoader( - source_dataset_test, args.batch_size, shuffle=False, - num_workers=args.n_workers) target_train_loader = DataLoader( target_dataset_train, args.batch_size, shuffle=True, drop_last=True, @@ -52,27 +47,19 @@ def run(args): if os.path.isfile(args.trained): c = torch.load(args.trained) source_cnn.load_state_dict(c['model']) - print('Loaded `{}`'.format(args.trained)) - else: - criterion = nn.CrossEntropyLoss() - optimizer = optim.Adam( - source_cnn.parameters(), - lr=args.lr, weight_decay=args.weight_decay) - source_cnn = train_source_cnn( - source_cnn, source_train_loader, source_test_loader, - criterion, optimizer, args=args) + logger.info('Loaded `{}`'.format(args.trained)) # train target CNN target_cnn = CNN(in_channels=args.in_channels, target=True).to(args.device) target_cnn.load_state_dict(source_cnn.state_dict()) discriminator = Discriminator(args=args).to(args.device) criterion = nn.CrossEntropyLoss() - optimizer = optim.RMSprop( # optim.Adam( + optimizer = optim.Adam( target_cnn.encoder.parameters(), - lr=args.lr, weight_decay=args.weight_decay) - d_optimizer = optim.RMSprop( # optim.Adam( + lr=args.lr, betas=args.betas, weight_decay=args.weight_decay) + d_optimizer = optim.Adam( discriminator.parameters(), - lr=args.lr, weight_decay=args.weight_decay) + lr=args.lr, betas=args.betas, weight_decay=args.weight_decay) train_target_cnn( source_cnn, target_cnn, discriminator, criterion, optimizer, d_optimizer, diff --git a/train_source.py b/train_source.py new file mode 100644 index 0000000..edcd74c --- /dev/null +++ b/train_source.py @@ -0,0 +1,62 @@ +import argparse +import os + +from torch import nn, optim +from torch.utils.data import DataLoader +from torchvision.datasets import SVHN +from torchvision import transforms + +from models import CNN +from trainer import train_source_cnn + + +def main(args): + if not os.path.exists(args.logdir): + os.makedirs(args.logdir) + + # data + source_transform = transforms.Compose([ + transforms.ToTensor()] + ) + source_dataset_train = SVHN( + './input', 'train', transform=source_transform, download=True) + source_dataset_test = SVHN( + './input', 'test', transform=source_transform, download=True) + source_train_loader = DataLoader( + source_dataset_train, args.batch_size, shuffle=True, + drop_last=True, + num_workers=args.n_workers) + source_test_loader = DataLoader( + source_dataset_test, args.batch_size, shuffle=False, + num_workers=args.n_workers) + + # train source CNN + source_cnn = CNN(in_channels=args.in_channels).to(args.device) + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam( + source_cnn.parameters(), + lr=args.lr, weight_decay=args.weight_decay) + source_cnn = train_source_cnn( + source_cnn, source_train_loader, source_test_loader, + criterion, optimizer, args=args) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + # NN + parser.add_argument('--in_channels', type=int, default=3) + parser.add_argument('--n_classes', type=int, default=10) + parser.add_argument('--trained', type=str, default='') + parser.add_argument('--slope', type=float, default=0.2) + # train + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--weight_decay', type=float, default=2.5e-5) + parser.add_argument('--epochs', type=int, default=50) + parser.add_argument('--batch_size', type=int, default=128) + # misc + parser.add_argument('--device', type=str, default='cuda:0') + parser.add_argument('--n_workers', type=int, default=0) + parser.add_argument('--logdir', type=str, default='outputs/garbage') + parser.add_argument('--message', '-m', type=str, default='') + args, unknown = parser.parse_known_args() + main(args)