@@ -1,14 +1,14 @@
 import os
 
+import sys
+sys.path.append(os.path.abspath('.'))
 import torch
 from torch import nn, optim
 from torch.utils.data import DataLoader
 from torchvision.datasets import SVHN, MNIST
 from torchvision import transforms
 
-from models import CNN, Discriminator
-from trainer import train_target_cnn
-from utils import get_logger
+from models.models import CNN, Discriminator
+from core.trainer import train_target_cnn
+from utils.utils import get_logger
 
 
 def run(args):
@@ -17,6 +17,8 @@ def run(args):
     logger = get_logger(os.path.join(args.logdir, 'main.log'))
     logger.info(args)
 
+    dataset_root = os.environ["DATASETDIR"]
+
     # data
     source_transform = transforms.Compose([
         # transforms.Grayscale(),
@@ -28,11 +30,11 @@ def run(args):
         transforms.Lambda(lambda x: x.repeat(3, 1, 1))
     ])
     source_dataset_train = SVHN(
-        './input', 'train', transform=source_transform, download=True)
+        'input/', 'train', transform=source_transform, download=True)
     target_dataset_train = MNIST(
-        './input', 'train', transform=target_transform, download=True)
+        'input/', 'train', transform=target_transform, download=True)
     target_dataset_test = MNIST(
-        './input', 'test', transform=target_transform, download=True)
+        'input/', 'test', transform=target_transform, download=True)
     source_train_loader = DataLoader(
         source_dataset_train, args.batch_size, shuffle=True,
         drop_last=True,