
Changes for new file organization

master
Fazil Altinel committed 4 years ago
commit 506c25df6d
12 changed files (changed lines per file in parentheses):
  1. README.md (3)
  2. __init__.py (0)
  3. core/__init__.py (0)
  4. core/experiment.py (18)
  5. core/train_source.py (18)
  6. core/trainer.py (16)
  7. main.py (4)
  8. models/__init__.py (0)
  9. models/models.py (0)
  10. utils/__init__.py (0)
  11. utils/altutils.py (0)
  12. utils/utils.py (0)

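Taken together, the moves in this commit give the repository the following layout (reconstructed only from the file list above; files not touched by the commit are omitted):

```
.
├── README.md
├── __init__.py
├── main.py
├── core/
│   ├── __init__.py
│   ├── experiment.py
│   ├── train_source.py
│   └── trainer.py
├── models/
│   ├── __init__.py
│   └── models.py
└── utils/
    ├── __init__.py
    ├── altutils.py
    └── utils.py
```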
README.md (3)

@@ -3,6 +3,9 @@ implement Adversarial Discriminative Domain Adaptation in PyTorch
This repo is mostly based on https://github.com/Fujiki-Nakamura/ADDA.PyTorch
## Note
Before running the training code, make sure that the `DATASETDIR` environment variable is set to the dataset directory.
## Example
```
$ python core/train_source.py --logdir outputs
```

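The note above requires `DATASETDIR` to point at the dataset directory; `core/experiment.py` and `core/train_source.py` read it with `os.environ["DATASETDIR"]`, which raises a `KeyError` if the variable is unset. A minimal sketch of setting and reading it from Python (the `/data/datasets` path is only an example; in practice you would export the variable in your shell before running the command above):

```python
import os

# Example value only; point this at your own dataset directory,
# e.g. `export DATASETDIR=/data/datasets` in the shell.
os.environ.setdefault("DATASETDIR", "/data/datasets")

# Mirrors the lookup done in core/experiment.py and core/train_source.py.
dataset_root = os.environ["DATASETDIR"]
print(f"datasets will be read from {dataset_root}")
```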
__init__.py (0)

core/__init__.py (0)

experiment.py → core/experiment.py (18)

@@ -1,14 +1,14 @@
import os
import sys
sys.path.append(os.path.abspath('.'))
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import SVHN, MNIST
from torchvision import transforms
from models import CNN, Discriminator
from trainer import train_target_cnn
from utils import get_logger
from models.models import CNN, Discriminator
from core.trainer import train_target_cnn
from utils.utils import get_logger
def run(args):
@@ -17,6 +17,8 @@ def run(args):
logger = get_logger(os.path.join(args.logdir, 'main.log'))
logger.info(args)
dataset_root = os.environ["DATASETDIR"]
# data
source_transform = transforms.Compose([
# transforms.Grayscale(),
@@ -28,11 +30,11 @@
transforms.Lambda(lambda x: x.repeat(3, 1, 1))
])
source_dataset_train = SVHN(
'./input', 'train', transform=source_transform, download=True)
'input/', 'train', transform=source_transform, download=True)
target_dataset_train = MNIST(
'./input', 'train', transform=target_transform, download=True)
'input/', 'train', transform=target_transform, download=True)
target_dataset_test = MNIST(
'./input', 'test', transform=target_transform, download=True)
'input/', 'test', transform=target_transform, download=True)
source_train_loader = DataLoader(
source_dataset_train, args.batch_size, shuffle=True,
drop_last=True,
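In `core/experiment.py` the target (MNIST) transform ends with `transforms.Lambda(lambda x: x.repeat(3, 1, 1))`, which repeats the single MNIST channel so that target images have the same three channels as the SVHN source images. A self-contained sketch of that part of the pipeline; the batch size of 128 is illustrative, the real value comes from `args.batch_size`:

```python
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

target_transform = transforms.Compose([
    transforms.ToTensor(),
    # MNIST tensors are 1 x 28 x 28; repeating the channel dimension
    # turns them into 3 x 28 x 28 so they match 3-channel SVHN inputs.
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
])

target_dataset_train = MNIST('input/', train=True,
                             transform=target_transform, download=True)
target_train_loader = DataLoader(target_dataset_train, batch_size=128,
                                 shuffle=True, drop_last=True)

images, labels = next(iter(target_train_loader))
print(images.shape)  # torch.Size([128, 3, 28, 28])
```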

train_source.py → core/train_source.py (18)

@@ -1,14 +1,14 @@
import argparse
import os
import sys
sys.path.append(os.path.abspath('.'))
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import SVHN
from torchvision import transforms
from models import CNN
from trainer import train_source_cnn
from utils import get_logger
from models.models import CNN
from core.trainer import train_source_cnn
from utils.utils import get_logger
def main(args):
@@ -17,14 +17,18 @@ def main(args):
logger = get_logger(os.path.join(args.logdir, 'train_source.log'))
logger.info(args)
dataset_root = os.environ["DATASETDIR"]
# data
source_transform = transforms.Compose([
transforms.ToTensor()]
)
# source_dataset_train = SVHN(
# dataset_root, 'train', transform=source_transform, download=True)
source_dataset_train = SVHN(
'./input', 'train', transform=source_transform, download=True)
'input/', 'train', transform=source_transform, download=True)
source_dataset_test = SVHN(
'./input', 'test', transform=source_transform, download=True)
'input/', 'test', transform=source_transform, download=True)
source_train_loader = DataLoader(
source_dataset_train, args.batch_size, shuffle=True,
drop_last=True,
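Both relocated scripts now begin with `sys.path.append(os.path.abspath('.'))`. When `python core/train_source.py` is launched, Python puts `core/` (the script's own directory) on `sys.path` rather than the repository root, so the new package-style imports (`models.models`, `core.trainer`, `utils.utils`) would not resolve on their own. Appending the current working directory fixes that, assuming the script is run from the repository root as in the README example. A small self-contained illustration:

```python
import os
import sys

# sys.path[0] is the directory of the script being run (core/ in this repo),
# not the directory the command was launched from.
print(sys.path[0])

# Appending the current working directory (assumed to be the repository root)
# makes top-level packages such as models/ and utils/ importable, e.g.
#   from models.models import CNN
#   from utils.utils import get_logger
sys.path.append(os.path.abspath('.'))
print(sys.path[-1])
```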

trainer.py → core/trainer.py (16)

@@ -1,12 +1,12 @@
import os
import sys
sys.path.append(os.path.abspath('.'))
from logging import getLogger
from time import time
import numpy as np
from sklearn.metrics import accuracy_score
from tensorboardX import SummaryWriter
import torch
from utils import AverageMeter, save
from utils.utils import AverageMeter, save
logger = getLogger('adda.trainer')
@@ -54,7 +54,6 @@ def train_target_cnn(
validation = validate(source_cnn, target_test_loader, criterion, args=args)
log_source = 'Source/Acc {:.3f} '.format(validation['acc'])
writer = SummaryWriter(args.logdir)
best_score = None
for epoch_i in range(1, 1 + args.epochs):
start_time = time()
@@ -91,13 +90,6 @@ def train_target_cnn(
}
save(args.logdir, state_dict, is_best)
# tensorboard
writer.add_scalar('Adv/D/Loss', training['d/loss'], epoch_i)
writer.add_scalar('Adv/Target/Loss', training['target/loss'], epoch_i)
writer.add_scalar('Val/Target/Loss', validation['loss'], epoch_i)
writer.add_scalar('Val/Target/Acc', validation['acc'], epoch_i)
writer.add_scalar('Train/Target/Acc', validation2['acc'], epoch_i)
def adversarial(
source_cnn, target_cnn, discriminator,

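The target-training loop above tracks a `best_score`, builds a `state_dict`, and hands it to `save(args.logdir, state_dict, is_best)` from `utils.utils`. That helper is not part of this diff, so the sketch below is only a hypothetical version of the usual keep-the-best-checkpoint pattern; the file names `checkpoint.pt` and `best_model.pt` are made up:

```python
import os
import shutil
import torch

def save(logdir, state_dict, is_best):
    # Hypothetical helper matching the call site in core/trainer.py:
    # always write the latest checkpoint and keep a copy of the best one.
    os.makedirs(logdir, exist_ok=True)
    path = os.path.join(logdir, 'checkpoint.pt')
    torch.save(state_dict, path)
    if is_best:
        shutil.copyfile(path, os.path.join(logdir, 'best_model.pt'))
```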
main.py (4)

@@ -1,5 +1,5 @@
import argparse
import experiment
from core.experiment import run
if __name__ == '__main__':
@@ -21,4 +21,4 @@ if __name__ == '__main__':
parser.add_argument('--logdir', type=str, default='outputs/garbage')
parser.add_argument('--message', '-m', type=str, default='')
args, unknown = parser.parse_known_args()
experiment.run(args)
run(args)
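`main.py` keeps `parser.parse_known_args()`, which, unlike `parse_args()`, does not abort on options the parser has not defined; it returns them separately. A short illustration using the two arguments visible in this diff (the `--extra` flag is made up):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--logdir', type=str, default='outputs/garbage')
parser.add_argument('--message', '-m', type=str, default='')

# parse_known_args() tolerates unrecognized options instead of erroring out.
args, unknown = parser.parse_known_args(['--logdir', 'outputs', '--extra', '1'])
print(args.logdir)  # outputs
print(unknown)      # ['--extra', '1']
```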

models/__init__.py (0)

models.py → models/models.py (0)

utils/__init__.py (0)

altutils.py → utils/altutils.py (0)

utils.py → utils/utils.py (0)
