|
@ -16,31 +16,36 @@ def train_source_cnn( |
|
|
source_cnn, train_loader, test_loader, criterion, optimizer, |
|
|
source_cnn, train_loader, test_loader, criterion, optimizer, |
|
|
args=None |
|
|
args=None |
|
|
): |
|
|
): |
|
|
best_score = None |
|
|
|
|
|
for epoch_i in range(1, 1 + args.epochs): |
|
|
|
|
|
start_time = time() |
|
|
|
|
|
training = train( |
|
|
|
|
|
source_cnn, train_loader, criterion, optimizer, args=args) |
|
|
|
|
|
validation = validate( |
|
|
|
|
|
source_cnn, test_loader, criterion, args=args) |
|
|
|
|
|
log = 'Epoch {}/{} '.format(epoch_i, args.epochs) |
|
|
|
|
|
log += '| Train/Loss {:.3f} Acc {:.3f} '.format( |
|
|
|
|
|
training['loss'], training['acc']) |
|
|
|
|
|
log += '| Val/Loss {:.3f} Acc {:.3f} '.format( |
|
|
|
|
|
validation['loss'], validation['acc']) |
|
|
|
|
|
log += 'Time {:.2f}s'.format(time() - start_time) |
|
|
|
|
|
logger.info(log) |
|
|
|
|
|
|
|
|
|
|
|
# save |
|
|
|
|
|
is_best = (best_score is None or validation['acc'] > best_score) |
|
|
|
|
|
best_score = validation['acc'] if is_best else best_score |
|
|
|
|
|
state_dict = { |
|
|
|
|
|
'model': source_cnn.state_dict(), |
|
|
|
|
|
'optimizer': optimizer.state_dict(), |
|
|
|
|
|
'epoch': epoch_i, |
|
|
|
|
|
'val/acc': best_score, |
|
|
|
|
|
} |
|
|
|
|
|
save(args.logdir, state_dict, is_best) |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
best_score = None |
|
|
|
|
|
for epoch_i in range(1, 1 + args.epochs): |
|
|
|
|
|
start_time = time() |
|
|
|
|
|
training = train( |
|
|
|
|
|
source_cnn, train_loader, criterion, optimizer, args=args) |
|
|
|
|
|
validation = validate( |
|
|
|
|
|
source_cnn, test_loader, criterion, args=args) |
|
|
|
|
|
log = 'Epoch {}/{} '.format(epoch_i, args.epochs) |
|
|
|
|
|
log += '| Train/Loss {:.3f} Acc {:.3f} '.format( |
|
|
|
|
|
training['loss'], training['acc']) |
|
|
|
|
|
log += '| Val/Loss {:.3f} Acc {:.3f} '.format( |
|
|
|
|
|
validation['loss'], validation['acc']) |
|
|
|
|
|
log += 'Time {:.2f}s'.format(time() - start_time) |
|
|
|
|
|
logger.info(log) |
|
|
|
|
|
|
|
|
|
|
|
# save |
|
|
|
|
|
is_best = (best_score is None or validation['acc'] > best_score) |
|
|
|
|
|
best_score = validation['acc'] if is_best else best_score |
|
|
|
|
|
state_dict = { |
|
|
|
|
|
'model': source_cnn.state_dict(), |
|
|
|
|
|
'optimizer': optimizer.state_dict(), |
|
|
|
|
|
'epoch': epoch_i, |
|
|
|
|
|
'val/acc': best_score, |
|
|
|
|
|
} |
|
|
|
|
|
save(args.logdir, state_dict, is_best) |
|
|
|
|
|
logger.info('Best val. acc.: {}'.format(best_score)) |
|
|
|
|
|
except KeyboardInterrupt as ke: |
|
|
|
|
|
logger.info('\n============ Summary ============= \n') |
|
|
|
|
|
logger.info('Best val. acc.: {}'.format(best_score)) |
|
|
|
|
|
|
|
|
return source_cnn |
|
|
return source_cnn |
|
|
|
|
|
|
|
@ -54,41 +59,46 @@ def train_target_cnn( |
|
|
validation = validate(source_cnn, target_test_loader, criterion, args=args) |
|
|
validation = validate(source_cnn, target_test_loader, criterion, args=args) |
|
|
log_source = 'Source/Acc {:.3f} '.format(validation['acc']) |
|
|
log_source = 'Source/Acc {:.3f} '.format(validation['acc']) |
|
|
|
|
|
|
|
|
best_score = None |
|
|
|
|
|
for epoch_i in range(1, 1 + args.epochs): |
|
|
|
|
|
start_time = time() |
|
|
|
|
|
training = adversarial( |
|
|
|
|
|
source_cnn, target_cnn, discriminator, |
|
|
|
|
|
source_train_loader, target_train_loader, |
|
|
|
|
|
criterion, criterion, |
|
|
|
|
|
optimizer, d_optimizer, |
|
|
|
|
|
args=args |
|
|
|
|
|
) |
|
|
|
|
|
validation = validate( |
|
|
|
|
|
target_cnn, target_test_loader, criterion, args=args) |
|
|
|
|
|
validation2 = validate( |
|
|
|
|
|
target_cnn, target_train_loader, criterion, args=args) |
|
|
|
|
|
log = 'Epoch {}/{} '.format(epoch_i, args.epochs) |
|
|
|
|
|
log += 'D/Loss {:.3f} Target/Loss {:.3f} '.format( |
|
|
|
|
|
training['d/loss'], training['target/loss']) |
|
|
|
|
|
log += '[Val] Target/Loss {:.3f} Target/Acc {:.3f} '.format( |
|
|
|
|
|
validation['loss'], validation['acc']) |
|
|
|
|
|
log += log_source |
|
|
|
|
|
log += '[Train] Target/Loss {:.3f} Target/Acc {:.3f} '.format( |
|
|
|
|
|
validation2['loss'], validation2['acc']) |
|
|
|
|
|
log += 'Time {:.2f}s'.format(time() - start_time) |
|
|
|
|
|
logger.info(log) |
|
|
|
|
|
|
|
|
|
|
|
# save |
|
|
|
|
|
is_best = (best_score is None or validation['acc'] > best_score) |
|
|
|
|
|
best_score = validation['acc'] if is_best else best_score |
|
|
|
|
|
state_dict = { |
|
|
|
|
|
'model': target_cnn.state_dict(), |
|
|
|
|
|
'optimizer': optimizer.state_dict(), |
|
|
|
|
|
'epoch': epoch_i, |
|
|
|
|
|
'val/acc': best_score, |
|
|
|
|
|
} |
|
|
|
|
|
save(args.logdir, state_dict, is_best) |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
best_score = None |
|
|
|
|
|
for epoch_i in range(1, 1 + args.epochs): |
|
|
|
|
|
start_time = time() |
|
|
|
|
|
training = adversarial( |
|
|
|
|
|
source_cnn, target_cnn, discriminator, |
|
|
|
|
|
source_train_loader, target_train_loader, |
|
|
|
|
|
criterion, criterion, |
|
|
|
|
|
optimizer, d_optimizer, |
|
|
|
|
|
args=args |
|
|
|
|
|
) |
|
|
|
|
|
validation = validate( |
|
|
|
|
|
target_cnn, target_test_loader, criterion, args=args) |
|
|
|
|
|
validation2 = validate( |
|
|
|
|
|
target_cnn, target_train_loader, criterion, args=args) |
|
|
|
|
|
log = 'Epoch {}/{} '.format(epoch_i, args.epochs) |
|
|
|
|
|
log += 'D/Loss {:.3f} Target/Loss {:.3f} '.format( |
|
|
|
|
|
training['d/loss'], training['target/loss']) |
|
|
|
|
|
log += '[Val] Target/Loss {:.3f} Target/Acc {:.3f} '.format( |
|
|
|
|
|
validation['loss'], validation['acc']) |
|
|
|
|
|
log += log_source |
|
|
|
|
|
log += '[Train] Target/Loss {:.3f} Target/Acc {:.3f} '.format( |
|
|
|
|
|
validation2['loss'], validation2['acc']) |
|
|
|
|
|
log += 'Time {:.2f}s'.format(time() - start_time) |
|
|
|
|
|
logger.info(log) |
|
|
|
|
|
|
|
|
|
|
|
# save |
|
|
|
|
|
is_best = (best_score is None or validation['acc'] > best_score) |
|
|
|
|
|
best_score = validation['acc'] if is_best else best_score |
|
|
|
|
|
state_dict = { |
|
|
|
|
|
'model': target_cnn.state_dict(), |
|
|
|
|
|
'optimizer': optimizer.state_dict(), |
|
|
|
|
|
'epoch': epoch_i, |
|
|
|
|
|
'val/acc': best_score, |
|
|
|
|
|
} |
|
|
|
|
|
save(args.logdir, state_dict, is_best) |
|
|
|
|
|
logger.info('Best val. acc.: {}'.format(best_score)) |
|
|
|
|
|
except KeyboardInterrupt as ke: |
|
|
|
|
|
logger.info('\n============ Summary ============= \n') |
|
|
|
|
|
logger.info('Best val. acc.: {}'.format(best_score)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def adversarial( |
|
|
def adversarial( |
|
|