|
@ -12,7 +12,7 @@ class ReverseLayerF(Function): |
|
|
@staticmethod |
|
|
@staticmethod |
|
|
def backward(ctx, grad_output): |
|
|
def backward(ctx, grad_output): |
|
|
output = grad_output.neg() * ctx.alpha |
|
|
output = grad_output.neg() * ctx.alpha |
|
|
|
|
|
|
|
|
|
|
|
#print("reverse gradient is {}".format(output)) |
|
|
return output, None |
|
|
return output, None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|