Created
December 24, 2018 04:58
-
-
Save soumith/f7e9afbdc561a2cfa9a2c5bdf443aa8b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class simpnet_imgnet_drpall(nn.Module):
    """SimpNet-style plain convolutional classifier.

    The network is a stack of 3x3 Conv -> BatchNorm -> ReLU units built from
    one of two width configurations, followed by global max pooling and a
    single linear classifier.

    Args:
        classes: number of outputs of the final linear layer.
        scale: multiplier applied to every convolution width in the selected
            configuration (the result is rounded to an integer).
        network_idx: index into ``self.networks``; 0 selects ``'simpnet5m'``,
            1 selects ``'simpnet8m'``.
        mode: key (1-5) into ``self.strides`` selecting the per-layer stride
            schedule used by the first convolutions.
        simpnet_name: accepted for call compatibility; not used by this class.
    """

    def __init__(self, classes=1000, scale=1.0, network_idx=0, mode=1,
                 simpnet_name='simpnet_imgnet_drpall'):
        super().__init__()
        # ['C', n] is a conv unit with n output channels (before scaling),
        # ['P'] is an explicit 2x2 max-pool.
        self.cfg = {
            'simpnet5m': [['C', 66], ['C', 128], ['C', 128], ['C', 128], ['C', 192], ['C', 192], ['C', 192],
                          ['C', 192], ['C', 192], ['C', 288], ['P'], ['C', 288], ['C', 355], ['C', 432]],
            'simpnet8m': [['C', 128], ['C', 182], ['C', 182], ['C', 182], ['C', 182], ['C', 182], ['C', 182],
                          ['C', 182], ['C', 182], ['C', 430], ['P'], ['C', 430], ['C', 455], ['C', 600]]}
        self.scale = scale
        self.networks = ['simpnet5m', 'simpnet8m']
        self.network_idx = network_idx
        self.mode = mode
        # Stride of the i-th convolution for each mode; convolutions past the
        # end of the list use stride 1.
        self.strides = {1: [2, 2, 2, 1, 1],  # s1
                        2: [2, 2, 1, 2, 1, 1],  # s4
                        3: [2, 2, 1, 1, 2, 1],  # s3
                        4: [2, 1, 2, 1, 2, 1],  # s5
                        5: [2, 1, 2, 1, 2, 1, 1]}  # s6
        self.features = self._make_layers(scale)
        # The classifier input width is the (scaled) width of the last conv.
        self.classifier = nn.Linear(round(self.cfg[self.networks[network_idx]][-1][1] * scale), classes)

    def load_my_state_dict(self, state_dict):
        """Copy matching entries of ``state_dict`` into this model, best effort.

        A ``'module.'`` substring (as added by ``nn.DataParallel``) is removed
        from every key. Keys that do not exist in this model are ignored.
        Entries whose shape differs from the model's are reported on stdout
        and skipped, leaving the model's current values in place.

        Args:
            state_dict: mapping from parameter/buffer name to tensor (or
                ``nn.Parameter``).
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            name = name.replace('module.', '')
            if name not in own_state:
                continue
            if isinstance(param, nn.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            print("STATE_DICT: {}".format(name))
            target = own_state[name]
            if target.size() != param.size():
                # Check explicitly instead of relying on copy_ to raise: copy_
                # would silently broadcast a compatible-but-wrong shape.
                print('While copying the parameter named {}, whose dimensions in the model are'
                      ' {} and whose dimensions in the checkpoint are {}, ... Using Initial Params'.format(
                          name, target.size(), param.size()))
                continue
            # state_dict() tensors share storage with the live parameters, so
            # the in-place copy updates the model itself.
            target.copy_(param)

    def forward(self, x):
        """Return classifier outputs for the image batch ``x``."""
        out = self.features(x)
        # Global Max Pooling
        out = F.max_pool2d(out, kernel_size=out.size()[2:])
        # training=False makes this call an identity; it is kept as in the
        # original definition.
        out = F.dropout2d(out, 0.01, training=False)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, scale):
        """Build the feature extractor as an ``nn.Sequential``.

        Args:
            scale: width multiplier applied to each configured channel count.

        Returns:
            The assembled ``nn.Sequential``. Its convolution weights are
            initialised with Xavier-uniform using the ReLU gain.
        """
        strides = self.strides[self.mode]
        layers = []
        input_channel = 3
        idx = 0  # counts convolution units only, not 'P' entries
        for x in self.cfg[self.networks[self.network_idx]]:
            # A pool is inserted right after the strided stem (first conv past
            # the stride list) and at every explicit 'P' entry.
            if idx == len(strides) or x[0] == 'P':
                layers += [nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False),
                           nn.Dropout2d(p=0.00)]
            if x[0] != 'C':
                continue
            filters = round(x[1] * scale)
            stride = strides[idx] if idx < len(strides) else 1
            unit = [nn.Conv2d(input_channel, filters, kernel_size=[3, 3], stride=(stride, stride), padding=(1, 1)),
                    nn.BatchNorm2d(filters, eps=1e-05, momentum=0.05, affine=True),
                    nn.ReLU(inplace=True)]
            # Units directly followed by a pool (or the last one) get no
            # trailing dropout; the layer order must stay fixed so that
            # checkpoint keys ("features.<i>...") keep lining up.
            if idx not in (len(strides) - 1, 9, 12):
                unit.append(nn.Dropout2d(p=0.000))
            layers += unit
            input_channel = filters
            idx += 1
        model = nn.Sequential(*layers)
        print(model)
        # Initialise the convolutions of the freshly built stack. Iterating
        # self.modules() here would find nothing, because self.features has
        # not been assigned yet when this method runs.
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
        return model
# Smoke test: build the default network, wrap it in DataParallel and run one
# forward/backward pass on a random ImageNet-sized batch.
# Use CUDA when present; otherwise fall back to CPU instead of crashing on
# the unconditional .cuda() calls.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = simpnet_imgnet_drpall()
model = nn.DataParallel(model).to(device)
x = torch.randn(10, 3, 224, 224).to(device)
y = model(x)
y.sum().backward()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment