Browse Source

ResNet changes for Office dataset

master
Fazil Altinel 4 years ago
parent
commit
06baf21197
  1. 6
      README.md
  2. 6
      core/experiment_rn50.py
  3. 15
      core/train_source_rn50.py
  4. 130
      core/trainer.py
  5. 3
      main.py
  6. 50
      models/resnet50off.py

6
README.md

@@ -1,5 +1,5 @@
# ADDA.PyTorch-resnet # ADDA.PyTorch-resnet
Implementation of "Adversarial Discriminative Domain Adapation" in PyTorch
Implementation of "Adversarial Discriminative Domain Adaptation" in PyTorch
This repo is mostly based on https://github.com/Fujiki-Nakamura/ADDA.PyTorch This repo is mostly based on https://github.com/Fujiki-Nakamura/ADDA.PyTorch
@@ -19,8 +19,8 @@ $ python main.py --logdir outputs --trained outputs/best_model.pt --slope 0.2
For training on Office dataset using ResNet-50 For training on Office dataset using ResNet-50
``` ```
$ python core/train_source_rn50.py --n_classes 31 --logdir outputs
$ python main.py --n_classes 31 --trained outputs/best_model.pt --logdir outputs --model resnet50 --src-cat amazon --tgt-cat webcam
$ python core/train_source_rn50.py --n_classes 31 --lr 1e-4 --src_cat amazon --tgt_cat webcam
$ python main.py --n_classes 31 --trained outputs/garbage/best_model.pt --lr 1e-5 --d_lr 1e-4 --logdir outputs --model resnet50 --src-cat amazon --tgt-cat webcam
``` ```
## Result ## Result

6
core/experiment_rn50.py

@@ -33,6 +33,10 @@ def run(args):
# train target CNN # train target CNN
target_cnn = CNN(in_channels=args.in_channels, target=True).to(args.device) target_cnn = CNN(in_channels=args.in_channels, target=True).to(args.device)
target_cnn.load_state_dict(source_cnn.state_dict()) target_cnn.load_state_dict(source_cnn.state_dict())
for param in source_cnn.parameters():
param.requires_grad = False
for param in target_cnn.classifier.parameters():
param.requires_grad = False
discriminator = Discriminator(args=args).to(args.device) discriminator = Discriminator(args=args).to(args.device)
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam( optimizer = optim.Adam(
@@ -40,7 +44,7 @@ def run(args):
lr=args.lr, betas=args.betas, weight_decay=args.weight_decay) lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
d_optimizer = optim.Adam( d_optimizer = optim.Adam(
discriminator.parameters(), discriminator.parameters(),
lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
lr=args.d_lr, betas=args.betas, weight_decay=args.weight_decay)
train_target_cnn( train_target_cnn(
source_cnn, target_cnn, discriminator, source_cnn, target_cnn, discriminator,
criterion, optimizer, d_optimizer, criterion, optimizer, d_optimizer,

15
core/train_source_rn50.py

@@ -21,15 +21,21 @@ def main(args):
# data loaders # data loaders
dataset_root = os.environ["DATASETDIR"] dataset_root = os.environ["DATASETDIR"]
source_loader = get_office(dataset_root, args.batch_size, args.src_cat) source_loader = get_office(dataset_root, args.batch_size, args.src_cat)
target_loader = get_office(dataset_root, args.batch_size, args.tgt_cat)
# train source CNN # train source CNN
source_cnn = CNN(in_channels=args.in_channels).to(args.device)
source_cnn = CNN(in_channels=args.in_channels, srcTrain=True).to(args.device)
# for param in source_cnn.encoder.parameters():
# param.requires_grad = False
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(
# source_cnn.classifier.parameters(),
# lr=args.lr, weight_decay=args.weight_decay)
optimizer = optim.Adam( optimizer = optim.Adam(
source_cnn.parameters(), source_cnn.parameters(),
lr=args.lr, weight_decay=args.weight_decay) lr=args.lr, weight_decay=args.weight_decay)
source_cnn = train_source_cnn( source_cnn = train_source_cnn(
source_cnn, source_loader, source_loader,
source_cnn, source_loader, target_loader,
criterion, optimizer, args=args) criterion, optimizer, args=args)
@@ -41,9 +47,9 @@ if __name__ == '__main__':
parser.add_argument('--trained', type=str, default='') parser.add_argument('--trained', type=str, default='')
parser.add_argument('--slope', type=float, default=0.2) parser.add_argument('--slope', type=float, default=0.2)
# train # train
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--weight_decay', type=float, default=2.5e-5) parser.add_argument('--weight_decay', type=float, default=2.5e-5)
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=32) parser.add_argument('--batch_size', type=int, default=32)
# misc # misc
parser.add_argument('--device', type=str, default='cuda:0') parser.add_argument('--device', type=str, default='cuda:0')
@@ -52,5 +58,6 @@ if __name__ == '__main__':
parser.add_argument('--message', '-m', type=str, default='') parser.add_argument('--message', '-m', type=str, default='')
# office dataset categories # office dataset categories
parser.add_argument('--src_cat', type=str, default='amazon') parser.add_argument('--src_cat', type=str, default='amazon')
parser.add_argument('--tgt_cat', type=str, default='webcam')
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
main(args) main(args)

130
core/trainer.py

@@ -16,31 +16,36 @@ def train_source_cnn(
source_cnn, train_loader, test_loader, criterion, optimizer, source_cnn, train_loader, test_loader, criterion, optimizer,
args=None args=None
): ):
best_score = None
for epoch_i in range(1, 1 + args.epochs):
start_time = time()
training = train(
source_cnn, train_loader, criterion, optimizer, args=args)
validation = validate(
source_cnn, test_loader, criterion, args=args)
log = 'Epoch {}/{} '.format(epoch_i, args.epochs)
log += '| Train/Loss {:.3f} Acc {:.3f} '.format(
training['loss'], training['acc'])
log += '| Val/Loss {:.3f} Acc {:.3f} '.format(
validation['loss'], validation['acc'])
log += 'Time {:.2f}s'.format(time() - start_time)
logger.info(log)
# save
is_best = (best_score is None or validation['acc'] > best_score)
best_score = validation['acc'] if is_best else best_score
state_dict = {
'model': source_cnn.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch_i,
'val/acc': best_score,
}
save(args.logdir, state_dict, is_best)
try:
best_score = None
for epoch_i in range(1, 1 + args.epochs):
start_time = time()
training = train(
source_cnn, train_loader, criterion, optimizer, args=args)
validation = validate(
source_cnn, test_loader, criterion, args=args)
log = 'Epoch {}/{} '.format(epoch_i, args.epochs)
log += '| Train/Loss {:.3f} Acc {:.3f} '.format(
training['loss'], training['acc'])
log += '| Val/Loss {:.3f} Acc {:.3f} '.format(
validation['loss'], validation['acc'])
log += 'Time {:.2f}s'.format(time() - start_time)
logger.info(log)
# save
is_best = (best_score is None or validation['acc'] > best_score)
best_score = validation['acc'] if is_best else best_score
state_dict = {
'model': source_cnn.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch_i,
'val/acc': best_score,
}
save(args.logdir, state_dict, is_best)
logger.info('Best val. acc.: {}'.format(best_score))
except KeyboardInterrupt as ke:
logger.info('\n============ Summary ============= \n')
logger.info('Best val. acc.: {}'.format(best_score))
return source_cnn return source_cnn
@@ -54,41 +59,46 @@ def train_target_cnn(
validation = validate(source_cnn, target_test_loader, criterion, args=args) validation = validate(source_cnn, target_test_loader, criterion, args=args)
log_source = 'Source/Acc {:.3f} '.format(validation['acc']) log_source = 'Source/Acc {:.3f} '.format(validation['acc'])
best_score = None
for epoch_i in range(1, 1 + args.epochs):
start_time = time()
training = adversarial(
source_cnn, target_cnn, discriminator,
source_train_loader, target_train_loader,
criterion, criterion,
optimizer, d_optimizer,
args=args
)
validation = validate(
target_cnn, target_test_loader, criterion, args=args)
validation2 = validate(
target_cnn, target_train_loader, criterion, args=args)
log = 'Epoch {}/{} '.format(epoch_i, args.epochs)
log += 'D/Loss {:.3f} Target/Loss {:.3f} '.format(
training['d/loss'], training['target/loss'])
log += '[Val] Target/Loss {:.3f} Target/Acc {:.3f} '.format(
validation['loss'], validation['acc'])
log += log_source
log += '[Train] Target/Loss {:.3f} Target/Acc {:.3f} '.format(
validation2['loss'], validation2['acc'])
log += 'Time {:.2f}s'.format(time() - start_time)
logger.info(log)
# save
is_best = (best_score is None or validation['acc'] > best_score)
best_score = validation['acc'] if is_best else best_score
state_dict = {
'model': target_cnn.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch_i,
'val/acc': best_score,
}
save(args.logdir, state_dict, is_best)
try:
best_score = None
for epoch_i in range(1, 1 + args.epochs):
start_time = time()
training = adversarial(
source_cnn, target_cnn, discriminator,
source_train_loader, target_train_loader,
criterion, criterion,
optimizer, d_optimizer,
args=args
)
validation = validate(
target_cnn, target_test_loader, criterion, args=args)
validation2 = validate(
target_cnn, target_train_loader, criterion, args=args)
log = 'Epoch {}/{} '.format(epoch_i, args.epochs)
log += 'D/Loss {:.3f} Target/Loss {:.3f} '.format(
training['d/loss'], training['target/loss'])
log += '[Val] Target/Loss {:.3f} Target/Acc {:.3f} '.format(
validation['loss'], validation['acc'])
log += log_source
log += '[Train] Target/Loss {:.3f} Target/Acc {:.3f} '.format(
validation2['loss'], validation2['acc'])
log += 'Time {:.2f}s'.format(time() - start_time)
logger.info(log)
# save
is_best = (best_score is None or validation['acc'] > best_score)
best_score = validation['acc'] if is_best else best_score
state_dict = {
'model': target_cnn.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch_i,
'val/acc': best_score,
}
save(args.logdir, state_dict, is_best)
logger.info('Best val. acc.: {}'.format(best_score))
except KeyboardInterrupt as ke:
logger.info('\n============ Summary ============= \n')
logger.info('Best val. acc.: {}'.format(best_score))
def adversarial( def adversarial(

3
main.py

@@ -10,7 +10,8 @@ if __name__ == '__main__':
parser.add_argument('--slope', type=float, default=0.2) parser.add_argument('--slope', type=float, default=0.2)
parser.add_argument('--model', type=str, default='default') parser.add_argument('--model', type=str, default='default')
# train # train
parser.add_argument('--lr', type=float, default=2e-4)
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--d_lr', type=float, default=1e-3)
parser.add_argument('--weight_decay', type=float, default=2.5e-5) parser.add_argument('--weight_decay', type=float, default=2.5e-5)
parser.add_argument('--epochs', type=int, default=500) parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--batch_size', type=int, default=32) parser.add_argument('--batch_size', type=int, default=32)

50
models/resnet50off.py

@@ -3,18 +3,54 @@ import torch.nn.functional as F
from torchvision import models from torchvision import models
class ResNet50Mod(nn.Module):
def __init__(self):
super(ResNet50Mod, self).__init__()
model_resnet50 = models.resnet50(pretrained=True)
self.freezed_rn50 = nn.Sequential(
model_resnet50.conv1,
model_resnet50.bn1,
model_resnet50.relu,
model_resnet50.maxpool,
model_resnet50.layer1,
model_resnet50.layer2,
model_resnet50.layer3,
)
self.layer4 = model_resnet50.layer4
self.avgpool = model_resnet50.avgpool
self.__in_features = model_resnet50.fc.in_features
def forward(self, x):
x = self.freezed_rn50(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
return x
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__(self, in_channels=3, h=256, dropout=0.5):
def __init__(self, in_channels=3, h=256, dropout=0.5, srcTrain=False):
super(Encoder, self).__init__() super(Encoder, self).__init__()
resnetModel = models.resnet50(pretrained=True)
feature_map = list(resnetModel.children())
feature_map.pop()
self.feature_extractor = nn.Sequential(*feature_map)
# resnetModel = models.resnet50(pretrained=True)
# feature_map = list(resnetModel.children())
# feature_map.pop()
# self.feature_extractor = nn.Sequential(*feature_map)
rnMod = ResNet50Mod()
self.feature_extractor = rnMod.freezed_rn50
self.layer4 = rnMod.layer4
self.avgpool = rnMod.avgpool
if srcTrain:
for param in self.feature_extractor.parameters():
param.requires_grad = False
def forward(self, x): def forward(self, x):
x = x.expand(x.data.shape[0], 3, 227, 227) x = x.expand(x.data.shape[0], 3, 227, 227)
x = self.feature_extractor(x) x = self.feature_extractor(x)
###
x = self.layer4(x)
x = self.avgpool(x)
###
x = x.view(x.size(0), -1) x = x.view(x.size(0), -1)
return x return x
@@ -34,9 +70,9 @@ class Classifier(nn.Module):
class CNN(nn.Module): class CNN(nn.Module):
def __init__(self, in_channels=3, n_classes=31, target=False):
def __init__(self, in_channels=3, n_classes=31, target=False, srcTrain=False):
super(CNN, self).__init__() super(CNN, self).__init__()
self.encoder = Encoder(in_channels=in_channels)
self.encoder = Encoder(in_channels=in_channels, srcTrain=srcTrain)
self.classifier = Classifier(n_classes) self.classifier = Classifier(n_classes)
if target: if target:
for param in self.classifier.parameters(): for param in self.classifier.parameters():

Loading…
Cancel
Save