|
|
@ -1,3 +1,4 @@ |
|
|
|
from logging import getLogger |
|
|
|
from time import time |
|
|
|
|
|
|
|
import numpy as np |
|
|
@ -7,6 +8,9 @@ import torch |
|
|
|
from utils import AverageMeter, save |
|
|
|
|
|
|
|
|
|
|
|
logger = getLogger('adda.trainer') |
|
|
|
|
|
|
|
|
|
|
|
def train_source_cnn( |
|
|
|
source_cnn, train_loader, test_loader, criterion, optimizer, |
|
|
|
args=None |
|
|
@ -24,7 +28,7 @@ def train_source_cnn( |
|
|
|
log += 'Val/Loss {:.3f} Val/Acc {:.3f} '.format( |
|
|
|
validation['loss'], validation['acc']) |
|
|
|
log += 'Time {:.2f}s'.format(time() - start_time) |
|
|
|
print(log) |
|
|
|
logger.info(log) |
|
|
|
|
|
|
|
# save |
|
|
|
is_best = (best_score is None or validation['acc'] > best_score) |
|
|
@ -68,7 +72,7 @@ def train_target_cnn( |
|
|
|
validation['loss'], validation['acc']) |
|
|
|
log += log_source |
|
|
|
log += 'Time {:.2f}s'.format(time() - start_time) |
|
|
|
print(log) |
|
|
|
logger.info(log) |
|
|
|
|
|
|
|
|
|
|
|
def adversarial( |
|
|
|