Browse Source

Changes for new file organization

master
Fazil Altinel 4 years ago
parent
commit
506c25df6d
  1. 3
      README.md
  2. 0
      __init__.py
  3. 0
      core/__init__.py
  4. 18
      core/experiment.py
  5. 18
      core/train_source.py
  6. 16
      core/trainer.py
  7. 4
      main.py
  8. 0
      models/__init__.py
  9. 0
      models/models.py
  10. 0
      utils/__init__.py
  11. 0
      utils/altutils.py
  12. 0
      utils/utils.py

3
README.md

@ -3,6 +3,9 @@ implement Adversarial Discriminative Domain Adapation in PyTorch
This repo is mostly based on https://github.com/Fujiki-Nakamura/ADDA.PyTorch This repo is mostly based on https://github.com/Fujiki-Nakamura/ADDA.PyTorch
## Note
Before running the training code, make sure that `DATASETDIR` environment variable is set to dataset directory.
## Example ## Example
``` ```
$ python train_source.py --logdir outputs $ python train_source.py --logdir outputs

0
__init__.py

0
core/__init__.py

18
experiment.py → core/experiment.py

@ -1,14 +1,14 @@
import os import os
import sys
sys.path.append(os.path.abspath('.'))
import torch import torch
from torch import nn, optim from torch import nn, optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision.datasets import SVHN, MNIST from torchvision.datasets import SVHN, MNIST
from torchvision import transforms from torchvision import transforms
from models import CNN, Discriminator
from trainer import train_target_cnn
from utils import get_logger
from models.models import CNN, Discriminator
from core.trainer import train_target_cnn
from utils.utils import get_logger
def run(args): def run(args):
@ -17,6 +17,8 @@ def run(args):
logger = get_logger(os.path.join(args.logdir, 'main.log')) logger = get_logger(os.path.join(args.logdir, 'main.log'))
logger.info(args) logger.info(args)
dataset_root = os.environ["DATASETDIR"]
# data # data
source_transform = transforms.Compose([ source_transform = transforms.Compose([
# transforms.Grayscale(), # transforms.Grayscale(),
@ -28,11 +30,11 @@ def run(args):
transforms.Lambda(lambda x: x.repeat(3, 1, 1)) transforms.Lambda(lambda x: x.repeat(3, 1, 1))
]) ])
source_dataset_train = SVHN( source_dataset_train = SVHN(
'./input', 'train', transform=source_transform, download=True)
'input/', 'train', transform=source_transform, download=True)
target_dataset_train = MNIST( target_dataset_train = MNIST(
'./input', 'train', transform=target_transform, download=True)
'input/', 'train', transform=target_transform, download=True)
target_dataset_test = MNIST( target_dataset_test = MNIST(
'./input', 'test', transform=target_transform, download=True)
'input/', 'test', transform=target_transform, download=True)
source_train_loader = DataLoader( source_train_loader = DataLoader(
source_dataset_train, args.batch_size, shuffle=True, source_dataset_train, args.batch_size, shuffle=True,
drop_last=True, drop_last=True,

18
train_source.py → core/train_source.py

@ -1,14 +1,14 @@
import argparse import argparse
import os import os
import sys
sys.path.append(os.path.abspath('.'))
from torch import nn, optim from torch import nn, optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision.datasets import SVHN from torchvision.datasets import SVHN
from torchvision import transforms from torchvision import transforms
from models import CNN
from trainer import train_source_cnn
from utils import get_logger
from models.models import CNN
from core.trainer import train_source_cnn
from utils.utils import get_logger
def main(args): def main(args):
@ -17,14 +17,18 @@ def main(args):
logger = get_logger(os.path.join(args.logdir, 'train_source.log')) logger = get_logger(os.path.join(args.logdir, 'train_source.log'))
logger.info(args) logger.info(args)
dataset_root = os.environ["DATASETDIR"]
# data # data
source_transform = transforms.Compose([ source_transform = transforms.Compose([
transforms.ToTensor()] transforms.ToTensor()]
) )
# source_dataset_train = SVHN(
# dataset_root, 'train', transform=source_transform, download=True)
source_dataset_train = SVHN( source_dataset_train = SVHN(
'./input', 'train', transform=source_transform, download=True)
'input/', 'train', transform=source_transform, download=True)
source_dataset_test = SVHN( source_dataset_test = SVHN(
'./input', 'test', transform=source_transform, download=True)
'input/', 'test', transform=source_transform, download=True)
source_train_loader = DataLoader( source_train_loader = DataLoader(
source_dataset_train, args.batch_size, shuffle=True, source_dataset_train, args.batch_size, shuffle=True,
drop_last=True, drop_last=True,

16
trainer.py → core/trainer.py

@ -1,12 +1,12 @@
import os
import sys
sys.path.append(os.path.abspath('.'))
from logging import getLogger from logging import getLogger
from time import time from time import time
import numpy as np import numpy as np
from sklearn.metrics import accuracy_score from sklearn.metrics import accuracy_score
from tensorboardX import SummaryWriter
import torch import torch
from utils import AverageMeter, save
from utils.utils import AverageMeter, save
logger = getLogger('adda.trainer') logger = getLogger('adda.trainer')
@ -54,7 +54,6 @@ def train_target_cnn(
validation = validate(source_cnn, target_test_loader, criterion, args=args) validation = validate(source_cnn, target_test_loader, criterion, args=args)
log_source = 'Source/Acc {:.3f} '.format(validation['acc']) log_source = 'Source/Acc {:.3f} '.format(validation['acc'])
writer = SummaryWriter(args.logdir)
best_score = None best_score = None
for epoch_i in range(1, 1 + args.epochs): for epoch_i in range(1, 1 + args.epochs):
start_time = time() start_time = time()
@ -91,13 +90,6 @@ def train_target_cnn(
} }
save(args.logdir, state_dict, is_best) save(args.logdir, state_dict, is_best)
# tensorboard
writer.add_scalar('Adv/D/Loss', training['d/loss'], epoch_i)
writer.add_scalar('Adv/Target/Loss', training['target/loss'], epoch_i)
writer.add_scalar('Val/Target/Loss', validation['loss'], epoch_i)
writer.add_scalar('Val/Target/Acc', validation['acc'], epoch_i)
writer.add_scalar('Train/Target/Acc', validation2['acc'], epoch_i)
def adversarial( def adversarial(
source_cnn, target_cnn, discriminator, source_cnn, target_cnn, discriminator,

4
main.py

@ -1,5 +1,5 @@
import argparse import argparse
import experiment
from core.experiment import run
if __name__ == '__main__': if __name__ == '__main__':
@ -21,4 +21,4 @@ if __name__ == '__main__':
parser.add_argument('--logdir', type=str, default='outputs/garbage') parser.add_argument('--logdir', type=str, default='outputs/garbage')
parser.add_argument('--message', '-m', type=str, default='') parser.add_argument('--message', '-m', type=str, default='')
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
experiment.run(args)
run(args)

0
models/__init__.py

0
models.py → models/models.py

0
utils/__init__.py

0
altutils.py → utils/altutils.py

0
utils.py → utils/utils.py

Loading…
Cancel
Save