|
|
@ -83,7 +83,7 @@ def adversarial( |
|
|
|
args=None |
|
|
|
): |
|
|
|
source_cnn.eval() |
|
|
|
target_cnn.train() |
|
|
|
target_cnn.encoder.train() |
|
|
|
discriminator.train() |
|
|
|
|
|
|
|
losses, d_losses = AverageMeter(), AverageMeter() |
|
|
@ -106,19 +106,17 @@ def adversarial( |
|
|
|
# train Discriminator |
|
|
|
D_output_source = discriminator(D_input_source) |
|
|
|
D_output_target = discriminator(D_input_target) |
|
|
|
d_loss_source = d_criterion(D_output_source, D_target_source) |
|
|
|
d_loss_target = d_criterion(D_output_target, D_target_target) |
|
|
|
d_loss = d_loss_source + d_loss_target |
|
|
|
D_output = torch.cat([D_output_source, D_output_target], dim=0) |
|
|
|
D_target = torch.cat([D_target_source, D_target_target], dim=0) |
|
|
|
d_loss = criterion(D_output, D_target) |
|
|
|
d_optimizer.zero_grad() |
|
|
|
d_loss.backward(retain_graph=True) |
|
|
|
d_loss.backward() |
|
|
|
d_optimizer.step() |
|
|
|
d_losses.update(d_loss.item(), bs) |
|
|
|
|
|
|
|
# train Target |
|
|
|
''' |
|
|
|
D_input_target = target_cnn.encoder(target_data) |
|
|
|
D_output_target = discriminator(D_input_target) |
|
|
|
''' |
|
|
|
loss = criterion(D_output_target, D_target_source) |
|
|
|
optimizer.zero_grad() |
|
|
|
loss.backward() |
|
|
|