Browse Source

update

master
fnakamura 6 years ago
parent
commit
1d1b896379
  1. 12
      trainer.py

12
trainer.py

@ -83,7 +83,7 @@ def adversarial(
args=None
):
source_cnn.eval()
target_cnn.train()
target_cnn.encoder.train()
discriminator.train()
losses, d_losses = AverageMeter(), AverageMeter()
@ -106,19 +106,17 @@ def adversarial(
# train Discriminator
D_output_source = discriminator(D_input_source)
D_output_target = discriminator(D_input_target)
d_loss_source = d_criterion(D_output_source, D_target_source)
d_loss_target = d_criterion(D_output_target, D_target_target)
d_loss = d_loss_source + d_loss_target
D_output = torch.cat([D_output_source, D_output_target], dim=0)
D_target = torch.cat([D_target_source, D_target_target], dim=0)
d_loss = criterion(D_output, D_target)
d_optimizer.zero_grad()
d_loss.backward(retain_graph=True)
d_loss.backward()
d_optimizer.step()
d_losses.update(d_loss.item(), bs)
# train Target
'''
D_input_target = target_cnn.encoder(target_data)
D_output_target = discriminator(D_input_target)
'''
loss = criterion(D_output_target, D_target_source)
optimizer.zero_grad()
loss.backward()

Loading…
Cancel
Save