fnakamura
6 years ago
2 changed files with 67 additions and 18 deletions
@ -0,0 +1,62 @@ |
|||
import argparse |
|||
import os |
|||
|
|||
from torch import nn, optim |
|||
from torch.utils.data import DataLoader |
|||
from torchvision.datasets import SVHN |
|||
from torchvision import transforms |
|||
|
|||
from models import CNN |
|||
from trainer import train_source_cnn |
|||
|
|||
|
|||
def main(args): |
|||
if not os.path.exists(args.logdir): |
|||
os.makedirs(args.logdir) |
|||
|
|||
# data |
|||
source_transform = transforms.Compose([ |
|||
transforms.ToTensor()] |
|||
) |
|||
source_dataset_train = SVHN( |
|||
'./input', 'train', transform=source_transform, download=True) |
|||
source_dataset_test = SVHN( |
|||
'./input', 'test', transform=source_transform, download=True) |
|||
source_train_loader = DataLoader( |
|||
source_dataset_train, args.batch_size, shuffle=True, |
|||
drop_last=True, |
|||
num_workers=args.n_workers) |
|||
source_test_loader = DataLoader( |
|||
source_dataset_test, args.batch_size, shuffle=False, |
|||
num_workers=args.n_workers) |
|||
|
|||
# train source CNN |
|||
source_cnn = CNN(in_channels=args.in_channels).to(args.device) |
|||
criterion = nn.CrossEntropyLoss() |
|||
optimizer = optim.Adam( |
|||
source_cnn.parameters(), |
|||
lr=args.lr, weight_decay=args.weight_decay) |
|||
source_cnn = train_source_cnn( |
|||
source_cnn, source_train_loader, source_test_loader, |
|||
criterion, optimizer, args=args) |
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
parser = argparse.ArgumentParser() |
|||
# NN |
|||
parser.add_argument('--in_channels', type=int, default=3) |
|||
parser.add_argument('--n_classes', type=int, default=10) |
|||
parser.add_argument('--trained', type=str, default='') |
|||
parser.add_argument('--slope', type=float, default=0.2) |
|||
# train |
|||
parser.add_argument('--lr', type=float, default=1e-3) |
|||
parser.add_argument('--weight_decay', type=float, default=2.5e-5) |
|||
parser.add_argument('--epochs', type=int, default=50) |
|||
parser.add_argument('--batch_size', type=int, default=128) |
|||
# misc |
|||
parser.add_argument('--device', type=str, default='cuda:0') |
|||
parser.add_argument('--n_workers', type=int, default=0) |
|||
parser.add_argument('--logdir', type=str, default='outputs/garbage') |
|||
parser.add_argument('--message', '-m', type=str, default='') |
|||
args, unknown = parser.parse_known_args() |
|||
main(args) |
Loading…
Reference in new issue