diff --git a/trainer.py b/trainer.py index 3362704..47b37ec 100644 --- a/trainer.py +++ b/trainer.py @@ -83,7 +83,7 @@ def adversarial( args=None ): source_cnn.eval() - target_cnn.train() + target_cnn.encoder.train() discriminator.train() losses, d_losses = AverageMeter(), AverageMeter() @@ -106,19 +106,17 @@ def adversarial( # train Discriminator D_output_source = discriminator(D_input_source) D_output_target = discriminator(D_input_target) - d_loss_source = d_criterion(D_output_source, D_target_source) - d_loss_target = d_criterion(D_output_target, D_target_target) - d_loss = d_loss_source + d_loss_target + D_output = torch.cat([D_output_source, D_output_target], dim=0) + D_target = torch.cat([D_target_source, D_target_target], dim=0) + d_loss = criterion(D_output, D_target) d_optimizer.zero_grad() - d_loss.backward(retain_graph=True) + d_loss.backward() d_optimizer.step() d_losses.update(d_loss.item(), bs) # train Target - ''' D_input_target = target_cnn.encoder(target_data) D_output_target = discriminator(D_input_target) - ''' loss = criterion(D_output_target, D_target_source) optimizer.zero_grad() loss.backward()