Implementation of "Adversarial Discriminative Domain Adaptation" in PyTorch
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

66 lines
2.3 KiB

6 years ago
import argparse
import os
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import SVHN
from torchvision import transforms
from models import CNN
from trainer import train_source_cnn
6 years ago
from utils import get_logger
6 years ago
def main(args):
if not os.path.exists(args.logdir):
os.makedirs(args.logdir)
6 years ago
logger = get_logger(os.path.join(args.logdir, 'train_source.log'))
logger.info(args)
6 years ago
# data
source_transform = transforms.Compose([
transforms.ToTensor()]
)
source_dataset_train = SVHN(
'./input', 'train', transform=source_transform, download=True)
source_dataset_test = SVHN(
'./input', 'test', transform=source_transform, download=True)
source_train_loader = DataLoader(
source_dataset_train, args.batch_size, shuffle=True,
drop_last=True,
num_workers=args.n_workers)
source_test_loader = DataLoader(
source_dataset_test, args.batch_size, shuffle=False,
num_workers=args.n_workers)
# train source CNN
source_cnn = CNN(in_channels=args.in_channels).to(args.device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
source_cnn.parameters(),
lr=args.lr, weight_decay=args.weight_decay)
source_cnn = train_source_cnn(
source_cnn, source_train_loader, source_test_loader,
criterion, optimizer, args=args)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# NN
parser.add_argument('--in_channels', type=int, default=3)
parser.add_argument('--n_classes', type=int, default=10)
parser.add_argument('--trained', type=str, default='')
parser.add_argument('--slope', type=float, default=0.2)
# train
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--weight_decay', type=float, default=2.5e-5)
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--batch_size', type=int, default=128)
# misc
parser.add_argument('--device', type=str, default='cuda:0')
parser.add_argument('--n_workers', type=int, default=0)
parser.add_argument('--logdir', type=str, default='outputs/garbage')
parser.add_argument('--message', '-m', type=str, default='')
args, unknown = parser.parse_known_args()
main(args)