diff --git a/train_source.py b/train_source.py index edcd74c..0481722 100644 --- a/train_source.py +++ b/train_source.py @@ -8,11 +8,14 @@ from torchvision import transforms from models import CNN from trainer import train_source_cnn +from utils import get_logger def main(args): if not os.path.exists(args.logdir): os.makedirs(args.logdir) + logger = get_logger(os.path.join(args.logdir, 'train_source.log')) + logger.info(args) # data source_transform = transforms.Compose([