wogong
7 years ago
5 changed files with 123 additions and 23 deletions
@ -0,0 +1,88 @@ |
|||||
|
import torch |
||||
|
import os |
||||
|
import torch.nn as nn |
||||
|
import torch.utils.model_zoo as model_zoo |
||||
|
|
||||
|
|
||||
|
__all__ = ['AlexNet', 'alexnet'] |
||||
|
|
||||
|
|
||||
|
class LRN(nn.Module): |
||||
|
def __init__(self, local_size=1, alpha=1.0, beta=0.75, ACROSS_CHANNELS=True): |
||||
|
super(LRN, self).__init__() |
||||
|
self.ACROSS_CHANNELS = ACROSS_CHANNELS |
||||
|
if ACROSS_CHANNELS: |
||||
|
self.average = nn.AvgPool3d(kernel_size=(local_size, 1, 1), |
||||
|
stride=1, |
||||
|
padding=(int((local_size-1.0)/2), 0, 0)) |
||||
|
else: |
||||
|
self.average = nn.AvgPool2d(kernel_size=local_size, |
||||
|
stride=1, |
||||
|
padding=int((local_size-1.0)/2)) |
||||
|
self.alpha = alpha |
||||
|
self.beta = beta |
||||
|
|
||||
|
def forward(self, x): |
||||
|
if self.ACROSS_CHANNELS: |
||||
|
div = x.pow(2).unsqueeze(1) |
||||
|
div = self.average(div).squeeze(1) |
||||
|
div = div.mul(self.alpha).add(1.0).pow(self.beta) |
||||
|
else: |
||||
|
div = x.pow(2) |
||||
|
div = self.average(div) |
||||
|
div = div.mul(self.alpha).add(1.0).pow(self.beta) |
||||
|
x = x.div(div) |
||||
|
return x |
||||
|
|
||||
|
|
||||
|
class AlexNet(nn.Module): |
||||
|
|
||||
|
def __init__(self, num_classes=1000): |
||||
|
super(AlexNet, self).__init__() |
||||
|
self.features = nn.Sequential( |
||||
|
nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0), |
||||
|
nn.ReLU(inplace=True), |
||||
|
LRN(local_size=5, alpha=0.0001, beta=0.75), |
||||
|
nn.MaxPool2d(kernel_size=3, stride=2), |
||||
|
nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2), |
||||
|
nn.ReLU(inplace=True), |
||||
|
LRN(local_size=5, alpha=0.0001, beta=0.75), |
||||
|
nn.MaxPool2d(kernel_size=3, stride=2), |
||||
|
nn.Conv2d(256, 384, kernel_size=3, padding=1), |
||||
|
nn.ReLU(inplace=True), |
||||
|
nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2), |
||||
|
nn.ReLU(inplace=True), |
||||
|
nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2), |
||||
|
nn.ReLU(inplace=True), |
||||
|
nn.MaxPool2d(kernel_size=3, stride=2), |
||||
|
) |
||||
|
self.classifier = nn.Sequential( |
||||
|
nn.Linear(256 * 6 * 6, 4096), |
||||
|
nn.ReLU(inplace=True), |
||||
|
nn.Dropout(), |
||||
|
nn.Linear(4096, 4096), |
||||
|
nn.ReLU(inplace=True), |
||||
|
nn.Dropout(), |
||||
|
nn.Linear(4096, num_classes), |
||||
|
) |
||||
|
|
||||
|
def forward(self, x): |
||||
|
x = self.features(x) |
||||
|
x = x.view(x.size(0), 256 * 6 * 6) |
||||
|
x = self.classifier(x) |
||||
|
return x |
||||
|
|
||||
|
|
||||
|
def alexnet(pretrained=False, **kwargs): |
||||
|
r"""AlexNet model architecture from the |
||||
|
`"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper. |
||||
|
|
||||
|
Args: |
||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet |
||||
|
""" |
||||
|
model = AlexNet(**kwargs) |
||||
|
if pretrained: |
||||
|
model_path = '/home/wogong/Models/alexnet.pth.tar' |
||||
|
pretrained_model = torch.load(model_path) |
||||
|
model.load_state_dict(pretrained_model['state_dict']) |
||||
|
return model |
Loading…
Reference in new issue