From 06baf211975b331f827534b2a250dd26201afbaf Mon Sep 17 00:00:00 2001 From: fazilaltinel Date: Tue, 15 Dec 2020 14:57:37 +0300 Subject: [PATCH] ResNet changes for Office dataset --- README.md | 6 +- core/experiment_rn50.py | 6 +- core/train_source_rn50.py | 15 +++-- core/trainer.py | 130 ++++++++++++++++++++------------------ main.py | 3 +- models/resnet50off.py | 50 +++++++++++++-- 6 files changed, 134 insertions(+), 76 deletions(-) diff --git a/README.md b/README.md index 49d682c..605710b 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # ADDA.PyTorch-resnet -Implementation of "Adversarial Discriminative Domain Adapation" in PyTorch +Implementation of "Adversarial Discriminative Domain Adaptation" in PyTorch This repo is mostly based on https://github.com/Fujiki-Nakamura/ADDA.PyTorch @@ -19,8 +19,8 @@ $ python main.py --logdir outputs --trained outputs/best_model.pt --slope 0.2 For training on Office dataset using ResNet-50 ``` -$ python core/train_source_rn50.py --n_classes 31 --logdir outputs -$ python main.py --n_classes 31 --trained outputs/best_model.pt --logdir outputs --model resnet50 --src-cat amazon --tgt-cat webcam +$ python core/train_source_rn50.py --n_classes 31 --lr 1e-4 --src_cat amazon --tgt_cat webcam +$ python main.py --n_classes 31 --trained outputs/garbage/best_model.pt --lr 1e-5 --d_lr 1e-4 --logdir outputs --model resnet50 --src-cat amazon --tgt-cat webcam ``` ## Result diff --git a/core/experiment_rn50.py b/core/experiment_rn50.py index c4a57cd..2c882cc 100644 --- a/core/experiment_rn50.py +++ b/core/experiment_rn50.py @@ -33,6 +33,10 @@ def run(args): # train target CNN target_cnn = CNN(in_channels=args.in_channels, target=True).to(args.device) target_cnn.load_state_dict(source_cnn.state_dict()) + for param in source_cnn.parameters(): + param.requires_grad = False + for param in target_cnn.classifier.parameters(): + param.requires_grad = False discriminator = Discriminator(args=args).to(args.device) criterion = nn.CrossEntropyLoss() optimizer = optim.Adam( @@ -40,7 +44,7 @@ def run(args): lr=args.lr, betas=args.betas, weight_decay=args.weight_decay) d_optimizer = optim.Adam( discriminator.parameters(), - lr=args.lr, betas=args.betas, weight_decay=args.weight_decay) + lr=args.d_lr, betas=args.betas, weight_decay=args.weight_decay) train_target_cnn( source_cnn, target_cnn, discriminator, criterion, optimizer, d_optimizer, diff --git a/core/train_source_rn50.py b/core/train_source_rn50.py index 4a21077..db3ebc1 100644 --- a/core/train_source_rn50.py +++ b/core/train_source_rn50.py @@ -21,15 +21,21 @@ def main(args): # data loaders dataset_root = os.environ["DATASETDIR"] source_loader = get_office(dataset_root, args.batch_size, args.src_cat) + target_loader = get_office(dataset_root, args.batch_size, args.tgt_cat) # train source CNN - source_cnn = CNN(in_channels=args.in_channels).to(args.device) + source_cnn = CNN(in_channels=args.in_channels, srcTrain=True).to(args.device) + # for param in source_cnn.encoder.parameters(): + # param.requires_grad = False criterion = nn.CrossEntropyLoss() + # optimizer = optim.Adam( + # source_cnn.classifier.parameters(), + # lr=args.lr, weight_decay=args.weight_decay) optimizer = optim.Adam( source_cnn.parameters(), lr=args.lr, weight_decay=args.weight_decay) source_cnn = train_source_cnn( - source_cnn, source_loader, source_loader, + source_cnn, source_loader, target_loader, criterion, optimizer, args=args) @@ -41,9 +47,9 @@ if __name__ == '__main__': parser.add_argument('--trained', type=str, default='') parser.add_argument('--slope', type=float, default=0.2) # train - parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--lr', type=float, default=1e-4) parser.add_argument('--weight_decay', type=float, default=2.5e-5) - parser.add_argument('--epochs', type=int, default=50) + parser.add_argument('--epochs', type=int, default=100) parser.add_argument('--batch_size', type=int, default=32) # misc parser.add_argument('--device', type=str, default='cuda:0') @@ -52,5 +58,6 @@ if __name__ == '__main__': parser.add_argument('--message', '-m', type=str, default='') # office dataset categories parser.add_argument('--src_cat', type=str, default='amazon') + parser.add_argument('--tgt_cat', type=str, default='webcam') args, unknown = parser.parse_known_args() main(args) diff --git a/core/trainer.py b/core/trainer.py index ca1fe55..b5fb7d7 100644 --- a/core/trainer.py +++ b/core/trainer.py @@ -16,31 +16,36 @@ def train_source_cnn( source_cnn, train_loader, test_loader, criterion, optimizer, args=None ): - best_score = None - for epoch_i in range(1, 1 + args.epochs): - start_time = time() - training = train( - source_cnn, train_loader, criterion, optimizer, args=args) - validation = validate( - source_cnn, test_loader, criterion, args=args) - log = 'Epoch {}/{} '.format(epoch_i, args.epochs) - log += '| Train/Loss {:.3f} Acc {:.3f} '.format( - training['loss'], training['acc']) - log += '| Val/Loss {:.3f} Acc {:.3f} '.format( - validation['loss'], validation['acc']) - log += 'Time {:.2f}s'.format(time() - start_time) - logger.info(log) - - # save - is_best = (best_score is None or validation['acc'] > best_score) - best_score = validation['acc'] if is_best else best_score - state_dict = { - 'model': source_cnn.state_dict(), - 'optimizer': optimizer.state_dict(), - 'epoch': epoch_i, - 'val/acc': best_score, - } - save(args.logdir, state_dict, is_best) + try: + best_score = None + for epoch_i in range(1, 1 + args.epochs): + start_time = time() + training = train( + source_cnn, train_loader, criterion, optimizer, args=args) + validation = validate( + source_cnn, test_loader, criterion, args=args) + log = 'Epoch {}/{} '.format(epoch_i, args.epochs) + log += '| Train/Loss {:.3f} Acc {:.3f} '.format( + training['loss'], training['acc']) + log += '| Val/Loss {:.3f} Acc {:.3f} '.format( + validation['loss'], validation['acc']) + log += 'Time {:.2f}s'.format(time() - start_time) + logger.info(log) + + # save + is_best = (best_score is None or validation['acc'] > best_score) + best_score = validation['acc'] if is_best else best_score + state_dict = { + 'model': source_cnn.state_dict(), + 'optimizer': optimizer.state_dict(), + 'epoch': epoch_i, + 'val/acc': best_score, + } + save(args.logdir, state_dict, is_best) + logger.info('Best val. acc.: {}'.format(best_score)) + except KeyboardInterrupt as ke: + logger.info('\n============ Summary ============= \n') + logger.info('Best val. acc.: {}'.format(best_score)) return source_cnn @@ -54,41 +59,46 @@ def train_target_cnn( validation = validate(source_cnn, target_test_loader, criterion, args=args) log_source = 'Source/Acc {:.3f} '.format(validation['acc']) - best_score = None - for epoch_i in range(1, 1 + args.epochs): - start_time = time() - training = adversarial( - source_cnn, target_cnn, discriminator, - source_train_loader, target_train_loader, - criterion, criterion, - optimizer, d_optimizer, - args=args - ) - validation = validate( - target_cnn, target_test_loader, criterion, args=args) - validation2 = validate( - target_cnn, target_train_loader, criterion, args=args) - log = 'Epoch {}/{} '.format(epoch_i, args.epochs) - log += 'D/Loss {:.3f} Target/Loss {:.3f} '.format( - training['d/loss'], training['target/loss']) - log += '[Val] Target/Loss {:.3f} Target/Acc {:.3f} '.format( - validation['loss'], validation['acc']) - log += log_source - log += '[Train] Target/Loss {:.3f} Target/Acc {:.3f} '.format( - validation2['loss'], validation2['acc']) - log += 'Time {:.2f}s'.format(time() - start_time) - logger.info(log) - - # save - is_best = (best_score is None or validation['acc'] > best_score) - best_score = validation['acc'] if is_best else best_score - state_dict = { - 'model': target_cnn.state_dict(), - 'optimizer': optimizer.state_dict(), - 'epoch': epoch_i, - 'val/acc': best_score, - } - save(args.logdir, state_dict, is_best) + try: + best_score = None + for epoch_i in range(1, 1 + args.epochs): + start_time = time() + training = adversarial( + source_cnn, target_cnn, discriminator, + source_train_loader, target_train_loader, + criterion, criterion, + optimizer, d_optimizer, + args=args + ) + validation = validate( + target_cnn, target_test_loader, criterion, args=args) + validation2 = validate( + target_cnn, target_train_loader, criterion, args=args) + log = 'Epoch {}/{} '.format(epoch_i, args.epochs) + log += 'D/Loss {:.3f} Target/Loss {:.3f} '.format( + training['d/loss'], training['target/loss']) + log += '[Val] Target/Loss {:.3f} Target/Acc {:.3f} '.format( + validation['loss'], validation['acc']) + log += log_source + log += '[Train] Target/Loss {:.3f} Target/Acc {:.3f} '.format( + validation2['loss'], validation2['acc']) + log += 'Time {:.2f}s'.format(time() - start_time) + logger.info(log) + + # save + is_best = (best_score is None or validation['acc'] > best_score) + best_score = validation['acc'] if is_best else best_score + state_dict = { + 'model': target_cnn.state_dict(), + 'optimizer': optimizer.state_dict(), + 'epoch': epoch_i, + 'val/acc': best_score, + } + save(args.logdir, state_dict, is_best) + logger.info('Best val. acc.: {}'.format(best_score)) + except KeyboardInterrupt as ke: + logger.info('\n============ Summary ============= \n') + logger.info('Best val. acc.: {}'.format(best_score)) def adversarial( diff --git a/main.py b/main.py index f2ec273..48a5e28 100644 --- a/main.py +++ b/main.py @@ -10,7 +10,8 @@ if __name__ == '__main__': parser.add_argument('--slope', type=float, default=0.2) parser.add_argument('--model', type=str, default='default') # train - parser.add_argument('--lr', type=float, default=2e-4) + parser.add_argument('--lr', type=float, default=1e-5) + parser.add_argument('--d_lr', type=float, default=1e-3) parser.add_argument('--weight_decay', type=float, default=2.5e-5) parser.add_argument('--epochs', type=int, default=500) parser.add_argument('--batch_size', type=int, default=32) diff --git a/models/resnet50off.py b/models/resnet50off.py index aac8d84..a8a51bb 100644 --- a/models/resnet50off.py +++ b/models/resnet50off.py @@ -3,18 +3,54 @@ import torch.nn.functional as F from torchvision import models +class ResNet50Mod(nn.Module): + def __init__(self): + super(ResNet50Mod, self).__init__() + model_resnet50 = models.resnet50(pretrained=True) + self.freezed_rn50 = nn.Sequential( + model_resnet50.conv1, + model_resnet50.bn1, + model_resnet50.relu, + model_resnet50.maxpool, + model_resnet50.layer1, + model_resnet50.layer2, + model_resnet50.layer3, + ) + self.layer4 = model_resnet50.layer4 + self.avgpool = model_resnet50.avgpool + self.__in_features = model_resnet50.fc.in_features + + def forward(self, x): + x = self.freezed_rn50(x) + x = self.layer4(x) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + return x + + class Encoder(nn.Module): - def __init__(self, in_channels=3, h=256, dropout=0.5): + def __init__(self, in_channels=3, h=256, dropout=0.5, srcTrain=False): super(Encoder, self).__init__() - resnetModel = models.resnet50(pretrained=True) - feature_map = list(resnetModel.children()) - feature_map.pop() - self.feature_extractor = nn.Sequential(*feature_map) + # resnetModel = models.resnet50(pretrained=True) + # feature_map = list(resnetModel.children()) + # feature_map.pop() + # self.feature_extractor = nn.Sequential(*feature_map) + rnMod = ResNet50Mod() + self.feature_extractor = rnMod.freezed_rn50 + self.layer4 = rnMod.layer4 + self.avgpool = rnMod.avgpool + if srcTrain: + for param in self.feature_extractor.parameters(): + param.requires_grad = False def forward(self, x): x = x.expand(x.data.shape[0], 3, 227, 227) x = self.feature_extractor(x) + ### + x = self.layer4(x) + x = self.avgpool(x) + ### x = x.view(x.size(0), -1) return x @@ -34,9 +70,9 @@ class Classifier(nn.Module): class CNN(nn.Module): - def __init__(self, in_channels=3, n_classes=31, target=False): + def __init__(self, in_channels=3, n_classes=31, target=False, srcTrain=False): super(CNN, self).__init__() - self.encoder = Encoder(in_channels=in_channels) + self.encoder = Encoder(in_channels=in_channels, srcTrain=srcTrain) self.classifier = Classifier(n_classes) if target: for param in self.classifier.parameters():