Browse Source

Delete unnecessary comment lines

master
Fazil Altinel 4 years ago
parent
commit
0738d8eff7
  1. 5
      core/train_source_rn50.py
  2. 18
      models/resnet50off.py
  3. 11
      utils/altutils.py

5
core/train_source_rn50.py

@ -25,12 +25,7 @@ def main(args):
# train source CNN
source_cnn = CNN(in_channels=args.in_channels, srcTrain=True).to(args.device)
# for param in source_cnn.encoder.parameters():
# param.requires_grad = False
criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(
# source_cnn.classifier.parameters(),
# lr=args.lr, weight_decay=args.weight_decay)
optimizer = optim.Adam(
source_cnn.parameters(),
lr=args.lr, weight_decay=args.weight_decay)

18
models/resnet50off.py

@ -32,10 +32,6 @@ class Encoder(nn.Module):
def __init__(self, in_channels=3, h=256, dropout=0.5, srcTrain=False):
super(Encoder, self).__init__()
# resnetModel = models.resnet50(pretrained=True)
# feature_map = list(resnetModel.children())
# feature_map.pop()
# self.feature_extractor = nn.Sequential(*feature_map)
rnMod = ResNet50Mod()
self.feature_extractor = rnMod.freezed_rn50
self.layer4 = rnMod.layer4
@ -47,10 +43,8 @@ class Encoder(nn.Module):
def forward(self, x):
x = x.expand(x.data.shape[0], 3, 224, 224)
x = self.feature_extractor(x)
###
x = self.layer4(x)
x = self.avgpool(x)
###
x = x.view(x.size(0), -1)
return x
@ -60,10 +54,6 @@ class Classifier(nn.Module):
super(Classifier, self).__init__()
self.l1 = nn.Linear(2048, n_classes)
# for m in self.modules():
# if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
# nn.init.kaiming_normal_(m.weight)
def forward(self, x):
x = self.l1(x)
return x
@ -90,20 +80,12 @@ class Discriminator(nn.Module):
self.l1 = nn.Linear(2048, h)
self.l2 = nn.Linear(h, h)
self.l3 = nn.Linear(h, 2)
###
self.l4 = nn.LogSoftmax(dim=1)
###
self.slope = args.slope
# for m in self.modules():
# if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
# nn.init.kaiming_normal_(m.weight)
def forward(self, x):
x = F.leaky_relu(self.l1(x), self.slope)
x = F.leaky_relu(self.l2(x), self.slope)
x = self.l3(x)
###
x = self.l4(x)
###
return x

11
utils/altutils.py

@ -37,7 +37,16 @@ def setLogger(logFilePath):
return logger
def get_office(dataset_root, batch_size, category):
"""Get Office datasets loader."""
"""Get Office datasets loader
Args:
dataset_root (str): path to the dataset folder
batch_size (int): batch size
category (str): category of Office dataset (amazon, webcam, dslr)
Returns:
obj: dataloader object for Office dataset
"""
# image pre-processing
pre_process = transforms.Compose([
transforms.Resize(224),

Loading…
Cancel
Save