wogong
7 years ago
5 changed files with 123 additions and 23 deletions
@@ -0,0 +1,88 @@
import torch
import os
import torch.nn as nn
import torch.utils.model_zoo as model_zoo


__all__ = ['AlexNet', 'alexnet']


class LRN(nn.Module):
    def __init__(self, local_size=1, alpha=1.0, beta=0.75, ACROSS_CHANNELS=True):
        super(LRN, self).__init__()
        self.ACROSS_CHANNELS = ACROSS_CHANNELS
        if ACROSS_CHANNELS:
            self.average = nn.AvgPool3d(kernel_size=(local_size, 1, 1),
                                        stride=1,
                                        padding=(int((local_size-1.0)/2), 0, 0))
        else:
            self.average = nn.AvgPool2d(kernel_size=local_size,
                                        stride=1,
                                        padding=int((local_size-1.0)/2))
        self.alpha = alpha
        self.beta = beta

    def forward(self, x):
        if self.ACROSS_CHANNELS:
            div = x.pow(2).unsqueeze(1)
            div = self.average(div).squeeze(1)
            div = div.mul(self.alpha).add(1.0).pow(self.beta)
        else:
            div = x.pow(2)
            div = self.average(div)
            div = div.mul(self.alpha).add(1.0).pow(self.beta)
        x = x.div(div)
        return x


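# --- Illustrative sketch, not part of the original file ---
# The LRN module above rescales each activation by
# (1 + alpha * avg(x^2)) ** (-beta), where the average runs over `local_size`
# neighbouring channels (ACROSS_CHANNELS=True) or a spatial window otherwise.
# The helper below is a hypothetical shape check; the name `_lrn_shape_check`
# is an assumption for illustration only.
def _lrn_shape_check():
    # Fake batch shaped like the output of the first conv layer below:
    # 2 images, 96 channels, 55x55 spatial map.
    x = torch.randn(2, 96, 55, 55)
    lrn = LRN(local_size=5, alpha=0.0001, beta=0.75)
    y = lrn(x)
    assert y.shape == x.shape  # LRN is an element-wise rescaling, shape unchanged
    return y

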
class AlexNet(nn.Module):

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),
            nn.ReLU(inplace=True),
            LRN(local_size=5, alpha=0.0001, beta=0.75),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2),
            nn.ReLU(inplace=True),
            LRN(local_size=5, alpha=0.0001, beta=0.75),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x


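# --- Illustrative sketch, not part of the original file ---
# The flatten size 256 * 6 * 6 in AlexNet.forward assumes a 227x227 input
# (BVLC/Caffe AlexNet convention: 227 -> 55 -> 27 -> 13 -> 6 across the three
# max-pool stages). The helper below is a hypothetical sanity check; the name
# `_alexnet_shape_check` is an assumption for illustration only.
def _alexnet_shape_check():
    model = AlexNet(num_classes=1000)
    x = torch.randn(1, 3, 227, 227)
    feats = model.features(x)
    assert feats.shape == (1, 256, 6, 6)  # matches the view() in forward()
    return model(x)  # logits of shape (1, 1000)

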
def alexnet(pretrained=False, **kwargs):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = AlexNet(**kwargs)
    if pretrained:
        model_path = '/home/wogong/Models/alexnet.pth.tar'
        pretrained_model = torch.load(model_path)
        model.load_state_dict(pretrained_model['state_dict'])
    return model
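

# --- Usage sketch, not part of the original file ---
# Builds the model without the machine-specific checkpoint above and runs a
# dummy forward pass; the batch size 4 is arbitrary.
if __name__ == '__main__':
    net = alexnet(pretrained=False, num_classes=1000)
    out = net(torch.randn(4, 3, 227, 227))
    print(out.shape)  # expected: torch.Size([4, 1000])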