@ -8,11 +8,14 @@ from torchvision import transforms
from models import CNN
from trainer import train_source_cnn
from utils import get_logger
def main(args):
if not os.path.exists(args.logdir):
os.makedirs(args.logdir)
logger = get_logger(os.path.join(args.logdir, 'train_source.log'))
logger.info(args)
# data
source_transform = transforms.Compose([