From 506c25df6d61daa625ec3de8ece3c28d7da4a374 Mon Sep 17 00:00:00 2001 From: fazilaltinel Date: Thu, 10 Dec 2020 15:13:47 +0300 Subject: [PATCH] Changes for new file organization --- README.md | 3 +++ __init__.py | 0 core/__init__.py | 0 experiment.py => core/experiment.py | 18 ++++++++++-------- train_source.py => core/train_source.py | 18 +++++++++++------- trainer.py => core/trainer.py | 16 ++++------------ main.py | 4 ++-- models/__init__.py | 0 models.py => models/models.py | 0 utils/__init__.py | 0 altutils.py => utils/altutils.py | 0 utils.py => utils/utils.py | 0 12 files changed, 30 insertions(+), 29 deletions(-) create mode 100644 __init__.py create mode 100644 core/__init__.py rename experiment.py => core/experiment.py (84%) rename train_source.py => core/train_source.py (82%) rename trainer.py => core/trainer.py (92%) create mode 100644 models/__init__.py rename models.py => models/models.py (100%) create mode 100644 utils/__init__.py rename altutils.py => utils/altutils.py (100%) rename utils.py => utils/utils.py (100%) diff --git a/README.md b/README.md index a2365e4..10fb985 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,9 @@ implement Adversarial Discriminative Domain Adapation in PyTorch This repo is mostly based on https://github.com/Fujiki-Nakamura/ADDA.PyTorch +## Note +Before running the training code, make sure that `DATASETDIR` environment variable is set to dataset directory. + ## Example ``` $ python train_source.py --logdir outputs diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiment.py b/core/experiment.py similarity index 84% rename from experiment.py rename to core/experiment.py index a4ecff9..827f337 100644 --- a/experiment.py +++ b/core/experiment.py @@ -1,14 +1,14 @@ import os - +import sys +sys.path.append(os.path.abspath('.')) import torch from torch import nn, optim from torch.utils.data import DataLoader from torchvision.datasets import SVHN, MNIST from torchvision import transforms - -from models import CNN, Discriminator -from trainer import train_target_cnn -from utils import get_logger +from models.models import CNN, Discriminator +from core.trainer import train_target_cnn +from utils.utils import get_logger def run(args): @@ -17,6 +17,8 @@ def run(args): logger = get_logger(os.path.join(args.logdir, 'main.log')) logger.info(args) + dataset_root = os.environ["DATASETDIR"] + # data source_transform = transforms.Compose([ # transforms.Grayscale(), @@ -28,11 +30,11 @@ def run(args): transforms.Lambda(lambda x: x.repeat(3, 1, 1)) ]) source_dataset_train = SVHN( - './input', 'train', transform=source_transform, download=True) + 'input/', 'train', transform=source_transform, download=True) target_dataset_train = MNIST( - './input', 'train', transform=target_transform, download=True) + 'input/', 'train', transform=target_transform, download=True) target_dataset_test = MNIST( - './input', 'test', transform=target_transform, download=True) + 'input/', 'test', transform=target_transform, download=True) source_train_loader = DataLoader( source_dataset_train, args.batch_size, shuffle=True, drop_last=True, diff --git a/train_source.py b/core/train_source.py similarity index 82% rename from train_source.py rename to core/train_source.py index 0481722..8317685 100644 --- a/train_source.py +++ b/core/train_source.py @@ -1,14 +1,14 @@ import argparse import os - +import sys +sys.path.append(os.path.abspath('.')) from torch import nn, optim from torch.utils.data import DataLoader from torchvision.datasets import SVHN from torchvision import transforms - -from models import CNN -from trainer import train_source_cnn -from utils import get_logger +from models.models import CNN +from core.trainer import train_source_cnn +from utils.utils import get_logger def main(args): @@ -17,14 +17,18 @@ def main(args): logger = get_logger(os.path.join(args.logdir, 'train_source.log')) logger.info(args) + dataset_root = os.environ["DATASETDIR"] + # data source_transform = transforms.Compose([ transforms.ToTensor()] ) + # source_dataset_train = SVHN( + # dataset_root, 'train', transform=source_transform, download=True) source_dataset_train = SVHN( - './input', 'train', transform=source_transform, download=True) + 'input/', 'train', transform=source_transform, download=True) source_dataset_test = SVHN( - './input', 'test', transform=source_transform, download=True) + 'input/', 'test', transform=source_transform, download=True) source_train_loader = DataLoader( source_dataset_train, args.batch_size, shuffle=True, drop_last=True, diff --git a/trainer.py b/core/trainer.py similarity index 92% rename from trainer.py rename to core/trainer.py index 309de24..ca1fe55 100644 --- a/trainer.py +++ b/core/trainer.py @@ -1,12 +1,12 @@ +import os +import sys +sys.path.append(os.path.abspath('.')) from logging import getLogger from time import time - import numpy as np from sklearn.metrics import accuracy_score -from tensorboardX import SummaryWriter import torch - -from utils import AverageMeter, save +from utils.utils import AverageMeter, save logger = getLogger('adda.trainer') @@ -54,7 +54,6 @@ def train_target_cnn( validation = validate(source_cnn, target_test_loader, criterion, args=args) log_source = 'Source/Acc {:.3f} '.format(validation['acc']) - writer = SummaryWriter(args.logdir) best_score = None for epoch_i in range(1, 1 + args.epochs): start_time = time() @@ -91,13 +90,6 @@ def train_target_cnn( } save(args.logdir, state_dict, is_best) - # tensorboard - writer.add_scalar('Adv/D/Loss', training['d/loss'], epoch_i) - writer.add_scalar('Adv/Target/Loss', training['target/loss'], epoch_i) - writer.add_scalar('Val/Target/Loss', validation['loss'], epoch_i) - writer.add_scalar('Val/Target/Acc', validation['acc'], epoch_i) - writer.add_scalar('Train/Target/Acc', validation2['acc'], epoch_i) - def adversarial( source_cnn, target_cnn, discriminator, diff --git a/main.py b/main.py index c47af4f..ac43bee 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,5 @@ import argparse -import experiment +from core.experiment import run if __name__ == '__main__': @@ -21,4 +21,4 @@ if __name__ == '__main__': parser.add_argument('--logdir', type=str, default='outputs/garbage') parser.add_argument('--message', '-m', type=str, default='') args, unknown = parser.parse_known_args() - experiment.run(args) + run(args) diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models.py b/models/models.py similarity index 100% rename from models.py rename to models/models.py diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/altutils.py b/utils/altutils.py similarity index 100% rename from altutils.py rename to utils/altutils.py diff --git a/utils.py b/utils/utils.py similarity index 100% rename from utils.py rename to utils/utils.py