|
|
@ -3,7 +3,6 @@ |
|
|
|
import torch.nn as nn |
|
|
|
from .functions import ReverseLayerF |
|
|
|
from torchvision import models |
|
|
|
import params |
|
|
|
|
|
|
|
|
|
|
|
class Classifier(nn.Module): |
|
|
@ -43,6 +42,45 @@ class Classifier(nn.Module): |
|
|
|
|
|
|
|
return class_output |
|
|
|
|
|
|
|
class MNISTmodel(nn.Module): |
|
|
|
""" MNIST architecture""" |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
super(MNISTmodel, self).__init__() |
|
|
|
self.restored = False |
|
|
|
|
|
|
|
self.feature = nn.Sequential( |
|
|
|
nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(5, 5)), # 1 28 28, 32 24 24 |
|
|
|
nn.ReLU(inplace=True), |
|
|
|
nn.MaxPool2d(kernel_size=(2, 2)), # 32 12 12 |
|
|
|
nn.Conv2d(in_channels=32, out_channels=48, kernel_size=(5, 5)), # 48 8 8 |
|
|
|
nn.ReLU(inplace=True), |
|
|
|
nn.MaxPool2d(kernel_size=(2, 2)), # 48 4 4 |
|
|
|
) |
|
|
|
|
|
|
|
self.classifier = nn.Sequential( |
|
|
|
nn.Linear(48*4*4, 100), |
|
|
|
nn.ReLU(inplace=True), |
|
|
|
nn.Linear(100, 100), |
|
|
|
nn.ReLU(inplace=True), |
|
|
|
nn.Linear(100, 10), |
|
|
|
) |
|
|
|
|
|
|
|
self.discriminator = nn.Sequential( |
|
|
|
nn.Linear(48*4*4, 100), |
|
|
|
nn.ReLU(inplace=True), |
|
|
|
nn.Linear(100, 2), |
|
|
|
) |
|
|
|
|
|
|
|
def forward(self, input_data, alpha): |
|
|
|
input_data = input_data.expand(input_data.data.shape[0], 1, 28, 28) |
|
|
|
feature = self.feature(input_data) |
|
|
|
feature = feature.view(-1, 48 * 4 * 4) |
|
|
|
reverse_feature = ReverseLayerF.apply(feature, alpha) |
|
|
|
class_output = self.classifier(feature) |
|
|
|
domain_output = self.discriminator(reverse_feature) |
|
|
|
|
|
|
|
return class_output, domain_output |
|
|
|
|
|
|
|
class SVHNmodel(nn.Module): |
|
|
|
""" SVHN architecture""" |
|
|
@ -90,7 +128,6 @@ class SVHNmodel(nn.Module): |
|
|
|
|
|
|
|
return class_output, domain_output |
|
|
|
|
|
|
|
|
|
|
|
class AlexModel(nn.Module): |
|
|
|
""" AlexNet pretrained on imagenet for Office dataset""" |
|
|
|
|
|
|
@ -119,7 +156,7 @@ class AlexModel(nn.Module): |
|
|
|
nn.Dropout(0.5), |
|
|
|
nn.Linear(4096, 256), |
|
|
|
nn.ReLU(inplace=True), |
|
|
|
nn.Linear(256, params.class_num_src), |
|
|
|
nn.Linear(256, 31), |
|
|
|
) |
|
|
|
|
|
|
|
self.discriminator = nn.Sequential( |
|
|
|