import torch
import os
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

__all__ = ['AlexNet', 'alexnet']

# Checkpoint loaded by ``alexnet(pretrained=True)`` when no ``model_path`` is
# given. This is a machine-specific location; pass ``model_path`` to override.
DEFAULT_MODEL_PATH = '/home/wogong/Models/alexnet.pth.tar'


class LRN(nn.Module):
    r"""Local Response Normalization.

    Computes ``x / (1 + alpha * mean(x ** 2 over a window)) ** beta``.

    The window spans ``local_size`` neighbouring channels when
    ``ACROSS_CHANNELS=True``. When ``ACROSS_CHANNELS=False`` it spans a
    ``local_size x local_size`` spatial neighbourhood inside each channel.
    The average pooling layers use their default zero padding, which is
    included in the mean, so the divisor is always the full window size.
    The input is expected to be a 4-D ``(N, C, H, W)`` tensor.

    Args:
        local_size (int): Window size. Must be a positive odd integer so the
            normalizing term has exactly the same shape as the input.
        alpha (float): Scale applied to the windowed mean of squares.
        beta (float): Exponent applied to the normalizing term.
        ACROSS_CHANNELS (bool): If True, normalize across channels.
            Otherwise normalize over a spatial neighbourhood within each
            channel.

    Raises:
        ValueError: If ``local_size`` is not a positive odd integer.
    """

    def __init__(self, local_size=1, alpha=1.0, beta=0.75, ACROSS_CHANNELS=True):
        super(LRN, self).__init__()
        if local_size < 1 or local_size % 2 == 0:
            # With an even window the "same" padding below is one short, so
            # the pooled tensor would be smaller than ``x`` along the
            # normalized dimension and the final division would misalign.
            raise ValueError(
                'local_size must be a positive odd integer, got {}'.format(local_size))
        self.ACROSS_CHANNELS = ACROSS_CHANNELS
        pad = (local_size - 1) // 2  # keeps the pooled size equal to the input size
        if ACROSS_CHANNELS:
            # forward() views the input as (N, 1, C, H, W). The 3-D pooling
            # "depth" axis is then the channel axis, and only it is pooled.
            self.average = nn.AvgPool3d(kernel_size=(local_size, 1, 1),
                                        stride=1,
                                        padding=(pad, 0, 0))
        else:
            self.average = nn.AvgPool2d(kernel_size=local_size,
                                        stride=1,
                                        padding=pad)
        self.alpha = alpha
        self.beta = beta

    def forward(self, x):
        """Return ``x`` divided by its local response normalization term."""
        if self.ACROSS_CHANNELS:
            # Add a singleton dim so AvgPool3d slides along the channels,
            # then drop it again.
            div = self.average(x.pow(2).unsqueeze(1)).squeeze(1)
        else:
            div = self.average(x.pow(2))
        div = div.mul(self.alpha).add(1.0).pow(self.beta)
        return x.div(div)


class AlexNet(nn.Module):
    """AlexNet with LRN layers and two-group convolutions (conv2, conv4, conv5).

    The model takes 3-channel images. ``forward`` flattens the final feature
    map to exactly ``256 * 6 * 6`` values per sample. The input's spatial
    size must therefore reduce to 6x6 after ``features``: 227x227 works,
    224x224 does not.

    Args:
        num_classes (int): Number of logits produced by the last linear layer.
    """

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),
            nn.ReLU(inplace=True),
            LRN(local_size=5, alpha=0.0001, beta=0.75),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2),
            nn.ReLU(inplace=True),
            LRN(local_size=5, alpha=0.0001, beta=0.75),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)`` for images ``x``."""
        x = self.features(x)
        # The explicit size (rather than -1) makes a wrong input resolution
        # fail here instead of inside the classifier.
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x


def alexnet(pretrained=False, model_path=None, **kwargs):
    r"""Build the AlexNet model.

    The layer layout (96-channel first convolution, LRN, two-group
    convolutions) follows "ImageNet Classification with Deep Convolutional
    Neural Networks" (Krizhevsky, Sutskever & Hinton, 2012).

    Args:
        pretrained (bool): If True, load weights from a checkpoint file.
        model_path (str, optional): Path of the checkpoint to load when
            ``pretrained`` is True. Defaults to ``DEFAULT_MODEL_PATH``. The
            file may be a training checkpoint with a ``'state_dict'`` entry
            or a bare state dict.
        **kwargs: Forwarded to :class:`AlexNet` (e.g. ``num_classes``).

    Returns:
        AlexNet: The constructed model.

    Raises:
        FileNotFoundError: If ``pretrained`` is True and the checkpoint does
            not exist.
    """
    model = AlexNet(**kwargs)
    if pretrained:
        if model_path is None:
            model_path = DEFAULT_MODEL_PATH
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        # With map_location='cpu', a checkpoint saved from GPU tensors also
        # loads on a CPU-only machine. load_state_dict copies the values into
        # the model's own parameters either way.
        checkpoint = torch.load(model_path, map_location='cpu')
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
    return model