diff --git a/core/train_source_rn50.py b/core/train_source_rn50.py index db3ebc1..2a0b8ea 100644 --- a/core/train_source_rn50.py +++ b/core/train_source_rn50.py @@ -25,12 +25,7 @@ def main(args): # train source CNN source_cnn = CNN(in_channels=args.in_channels, srcTrain=True).to(args.device) - # for param in source_cnn.encoder.parameters(): - # param.requires_grad = False criterion = nn.CrossEntropyLoss() - # optimizer = optim.Adam( - # source_cnn.classifier.parameters(), - # lr=args.lr, weight_decay=args.weight_decay) optimizer = optim.Adam( source_cnn.parameters(), lr=args.lr, weight_decay=args.weight_decay) diff --git a/models/resnet50off.py b/models/resnet50off.py index ffeca90..e426ccf 100644 --- a/models/resnet50off.py +++ b/models/resnet50off.py @@ -32,10 +32,6 @@ class Encoder(nn.Module): def __init__(self, in_channels=3, h=256, dropout=0.5, srcTrain=False): super(Encoder, self).__init__() - # resnetModel = models.resnet50(pretrained=True) - # feature_map = list(resnetModel.children()) - # feature_map.pop() - # self.feature_extractor = nn.Sequential(*feature_map) rnMod = ResNet50Mod() self.feature_extractor = rnMod.freezed_rn50 self.layer4 = rnMod.layer4 @@ -47,10 +43,8 @@ class Encoder(nn.Module): def forward(self, x): x = x.expand(x.data.shape[0], 3, 224, 224) x = self.feature_extractor(x) - ### x = self.layer4(x) x = self.avgpool(x) - ### x = x.view(x.size(0), -1) return x @@ -60,10 +54,6 @@ class Classifier(nn.Module): super(Classifier, self).__init__() self.l1 = nn.Linear(2048, n_classes) - # for m in self.modules(): - # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): - # nn.init.kaiming_normal_(m.weight) - def forward(self, x): x = self.l1(x) return x @@ -90,20 +80,12 @@ class Discriminator(nn.Module): self.l1 = nn.Linear(2048, h) self.l2 = nn.Linear(h, h) self.l3 = nn.Linear(h, 2) - ### self.l4 = nn.LogSoftmax(dim=1) - ### self.slope = args.slope - # for m in self.modules(): - # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): - # nn.init.kaiming_normal_(m.weight) - def forward(self, x): x = F.leaky_relu(self.l1(x), self.slope) x = F.leaky_relu(self.l2(x), self.slope) x = self.l3(x) - ### x = self.l4(x) - ### return x diff --git a/utils/altutils.py b/utils/altutils.py index 1bd05b1..6d094fb 100644 --- a/utils/altutils.py +++ b/utils/altutils.py @@ -37,7 +37,16 @@ def setLogger(logFilePath): return logger def get_office(dataset_root, batch_size, category): - """Get Office datasets loader.""" + """Get Office datasets loader + + Args: + dataset_root (str): path to the dataset folder + batch_size (int): batch size + category (str): category of Office dataset (amazon, webcam, dslr) + + Returns: + obj: dataloader object for Office dataset + """ # image pre-processing pre_process = transforms.Compose([ transforms.Resize(224),