58% office task commit.

master · wogong committed 7 years ago · parent commit e751b5984d

3 changed files:

  core/dann.py                  22
  models/model.py               43
  main_office.py → office.py    16

core/dann.py

@@ -18,12 +18,20 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
     ####################
     # setup criterion and optimizer
-    # parameter_list = [
-    #     {"params": model.feature.parameters(), "lr": 1e-5},
-    #     {"params": model.classifier.parameters(), "lr": 1e-4},
-    #     {"params": model.discriminator.parameters(), "lr": 1e-4}
-    # ]
-    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
+    if params.src_dataset == 'mnist' or params.tgt_dataset == 'mnist':
+        print("training mnist task")
+        optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
+    else:
+        print("training office task")
+        parameter_list = [
+            {"params": model.features.parameters(), "lr": 1e-5},
+            {"params": model.fc.parameters(), "lr": 1e-5},
+            {"params": model.bottleneck.parameters(), "lr": 1e-4},
+            {"params": model.classifier.parameters(), "lr": 1e-4},
+            {"params": model.discriminator.parameters(), "lr": 1e-4}
+        ]
+        optimizer = optim.SGD(parameter_list)

     criterion = nn.CrossEntropyLoss()
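The office branch above replaces the single global learning rate with per-module parameter groups: the pretrained AlexNet layers (features, fc) get 1e-5 while the freshly initialized heads (bottleneck, classifier, discriminator) get 1e-4. A minimal runnable sketch of the same pattern; the submodule names come from the diff, the wrapper function is ours:

    import torch.nn as nn
    import torch.optim as optim

    def build_office_optimizer(model: nn.Module) -> optim.SGD:
        parameter_list = [
            # Pretrained AlexNet blocks: small rate, preserve ImageNet features.
            {"params": model.features.parameters(), "lr": 1e-5},
            {"params": model.fc.parameters(), "lr": 1e-5},
            # Newly initialized heads: 10x larger rate.
            {"params": model.bottleneck.parameters(), "lr": 1e-4},
            {"params": model.classifier.parameters(), "lr": 1e-4},
            {"params": model.discriminator.parameters(), "lr": 1e-4},
        ]
        return optim.SGD(parameter_list)

Note that optim.SGD(parameter_list) works without a global lr only because every group supplies its own "lr"; leave it out of any group and the constructor raises a ValueError.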
@@ -44,6 +52,8 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
             p = float(step + epoch * len_dataloader) / params.num_epochs / len_dataloader
             alpha = 2. / (1. + np.exp(-10 * p)) - 1

-            adjust_learning_rate(optimizer, p)
+            if params.src_dataset == 'mnist' or params.tgt_dataset == 'mnist':
+                print("training mnist task")
+                adjust_learning_rate(optimizer, p)

             # prepare domain label
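p and alpha implement the original DANN schedules: p is the fraction of training completed and alpha ramps the gradient-reversal strength from 0 toward 1. A self-contained sketch of the values they produce (the helper name is ours, not the repo's):

    import numpy as np

    def dann_schedule(step, epoch, len_dataloader, num_epochs):
        # Fraction of training completed, in [0, 1).
        p = float(step + epoch * len_dataloader) / num_epochs / len_dataloader
        # GRL coefficient from the DANN paper: 2 / (1 + e^(-10p)) - 1.
        alpha = 2. / (1. + np.exp(-10 * p)) - 1
        return p, alpha

    # alpha at p = 0, 0.1, 0.5, 1.0 is roughly 0.00, 0.46, 0.99, 1.00,
    # so domain confusion is phased in while the label classifier stabilizes.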
@@ -90,7 +100,7 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval):
         # eval model on test set
         if ((epoch + 1) % params.eval_step == 0):
             print("eval on target domain")
-            eval(model, tgt_data_loader_eval)
+            eval(model, tgt_data_loader)
             print("eval on source domain")
             eval_src(model, src_data_loader)

models/model.py

@@ -154,34 +154,28 @@ class AlexModel(nn.Module):
         self.features = model_alexnet.features

-        # self.classifier = nn.Sequential()
-        # for i in range(5):
-        #     self.classifier.add_module(
-        #         "classifier" + str(i), model_alexnet.classifier[i])
-        # self.__in_features = model_alexnet.classifier[4].in_features
-        # self.classifier.add_module('classifier5', nn.Dropout())
-        # self.classifier.add_module('classifier6', nn.Linear(self.__in_features, 256))
-        # self.classifier.add_module('classifier7', nn.BatchNorm2d(256))
-        # self.classifier.add_module('classifier8', nn.ReLU())
-        # self.classifier.add_module('classifier9', nn.Dropout(0.5))
-        # self.classifier.add_module('classifier10', nn.Linear(256, params.class_num_src))
-        self.classifier = nn.Sequential(
-            nn.Dropout(0.5),
-            nn.Linear(256 * 6 * 6, 4096),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.5),
+        self.fc = nn.Sequential()
+        for i in range(6):
+            self.fc.add_module("classifier" + str(i), model_alexnet.classifier[i])
+        self.__in_features = model_alexnet.classifier[6].in_features  # 4096
+        self.bottleneck = nn.Sequential(
             nn.Linear(4096, 256),
             nn.ReLU(inplace=True),
-            nn.Linear(256, 31),
+            nn.Dropout(),
+        )
+        self.classifier = nn.Sequential(
+            nn.Linear(256, 31)
         )
         self.discriminator = nn.Sequential(
-            nn.Linear(256 * 6 * 6, 1024),
+            nn.Linear(256, 1024),
             nn.ReLU(),
-            nn.Dropout(0.5),
+            nn.Dropout(),
             nn.Linear(1024, 1024),
             nn.ReLU(),
-            nn.Dropout(0.5),
+            nn.Dropout(),
             nn.Linear(1024, 2),
         )
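The rewritten constructor keeps torchvision AlexNet's first six classifier layers as self.fc (ending in a 4096-d ReLU activation), compresses them to a 256-d bottleneck, and hangs both the 31-way Office classifier and the 2-way domain discriminator off that bottleneck instead of the raw 256*6*6 feature map. A quick shape check, illustrative rather than repo code:

    import torch
    from torchvision.models import alexnet

    net = alexnet()  # random weights are fine for checking shapes
    x = torch.randn(1, 3, 227, 227)
    print(net.features(x).shape)           # torch.Size([1, 256, 6, 6])
    print(net.classifier[6].in_features)   # 4096, the bottleneck's input size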
@@ -189,11 +183,12 @@ class AlexModel(nn.Module):
         input_data = input_data.expand(input_data.data.shape[0], 3, 227, 227)
         feature = self.features(input_data)
         feature = feature.view(-1, 256*6*6)
-        reverse_feature = ReverseLayerF.apply(feature, alpha)
-        class_output = self.classifier(feature)
-        domain_output = self.discriminator(reverse_feature)
+        fc = self.fc(feature)
+        bottleneck = self.bottleneck(fc)
+        reverse_bottleneck = ReverseLayerF.apply(bottleneck, alpha)
+        class_output = self.classifier(bottleneck)
+        domain_output = self.discriminator(reverse_bottleneck)

         return class_output, domain_output
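forward() now routes both heads through the bottleneck, and only the discriminator sees the gradient-reversed copy. ReverseLayerF itself lives elsewhere in the repo; a minimal sketch of the standard gradient-reversal Function it presumably matches (our reconstruction, not the repo file):

    from torch.autograd import Function

    class ReverseLayerF(Function):
        # Identity on the forward pass; multiplies the incoming gradient
        # by -alpha on the backward pass, so the feature extractor is
        # pushed to fool the domain discriminator.
        @staticmethod
        def forward(ctx, x, alpha):
            ctx.alpha = alpha
            return x.view_as(x)

        @staticmethod
        def backward(ctx, grad_output):
            # One gradient per forward input; alpha itself gets none.
            return grad_output.neg() * ctx.alpha, None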

main_office.py → office.py

@@ -12,7 +12,7 @@ class Config(object):
     model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN'))

     # params for datasets and data loader
-    batch_size = 128
+    batch_size = 32

     # params for source dataset
     src_dataset = "amazon31"

@@ -33,10 +33,10 @@ class Config(object):
     # params for training dann
     ## for office
-    num_epochs = 1000
-    log_step = 10 # iters
+    num_epochs = 4000
+    log_step = 25 # iters
     save_step = 500
-    eval_step = 5 # epochs
+    eval_step = 50 # epochs

     manual_seed = 8888
     alpha = 0

@@ -50,17 +50,17 @@ params = Config()
 init_random_seed(params.manual_seed)

 # load dataset
-src_data_loader = get_data_loader(params.src_dataset)
-tgt_data_loader = get_data_loader(params.tgt_dataset)
+src_data_loader = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size)
+tgt_data_loader = get_data_loader(params.tgt_dataset, params.dataset_root, params.batch_size)

 # load dann model
-dann = init_model(net=AlexModel(), restore=params.dann_restore)
+dann = init_model(net=AlexModel(), restore=None)

 # train dann model
 print("Start training dann model.")
 if not (dann.restored and params.dann_restore):
-    dann = train_dann(dann, src_data_loader, tgt_data_loader, tgt_data_loader)
+    dann = train_dann(dann, params, src_data_loader, tgt_data_loader, tgt_data_loader)

 # eval dann model
 print("Evaluating dann for source domain")