|
@ -8,11 +8,14 @@ from torchvision import transforms |
|
|
|
|
|
|
|
|
from models import CNN |
|
|
from models import CNN |
|
|
from trainer import train_source_cnn |
|
|
from trainer import train_source_cnn |
|
|
|
|
|
from utils import get_logger |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(args): |
|
|
def main(args): |
|
|
if not os.path.exists(args.logdir): |
|
|
if not os.path.exists(args.logdir): |
|
|
os.makedirs(args.logdir) |
|
|
os.makedirs(args.logdir) |
|
|
|
|
|
logger = get_logger(os.path.join(args.logdir, 'train_source.log')) |
|
|
|
|
|
logger.info(args) |
|
|
|
|
|
|
|
|
# data |
|
|
# data |
|
|
source_transform = transforms.Compose([ |
|
|
source_transform = transforms.Compose([ |
|
|