Created April 25, 2021 05:57
-
-
Save magesh-technovator/b3bcfb56d11609e9ac506ac109cc9ee7 to your computer and use it in GitHub Desktop.
CRAFT Architecture
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Copyright (c) 2019-present NAVER Corp. | |
MIT License | |
""" | |
# -*- coding: utf-8 -*- | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from basenet.vgg16_bn import vgg16_bn, init_weights | |
class double_conv(nn.Module):
    """Squeeze-then-expand convolution block used by the CRAFT decoder.

    The block expects ``in_ch + mid_ch`` input channels, i.e. the channel-wise
    concatenation of two feature maps. A 1x1 convolution first reduces them to
    ``mid_ch``. A 3x3 convolution (padding 1, so H and W are unchanged) then
    maps them to ``out_ch``. Each convolution is followed by BatchNorm2d and an
    in-place ReLU.

    All six layers live in ``self.conv`` (an ``nn.Sequential`` indexed 0..5).
    That layout fixes the state_dict keys (``conv.0.weight``,
    ``conv.1.running_mean``, ...), so it must stay stable for existing
    checkpoints to load.
    """

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        # 1x1 "squeeze" stage: (in_ch + mid_ch) -> mid_ch channels.
        squeeze = [
            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
        ]
        # 3x3 "expand" stage: mid_ch -> out_ch channels, spatial size preserved.
        expand = [
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*squeeze, *expand)

    def forward(self, x):
        """Apply both stages; output has ``out_ch`` channels and x's H and W."""
        return self.conv(x)
class CRAFT(nn.Module):
    """CRAFT network: a ``vgg16_bn`` backbone followed by a U-shaped decoder.

    The backbone returns a sequence of feature maps ``sources``. Four
    ``double_conv`` stages fuse them one after another. Before each fusion
    after the first, the running map is bilinearly resized to the skip map's
    H x W. A small convolutional head (``conv_cls``) then maps the final
    32-channel feature map to 2 output channels.

    Channel widths implied by the decoder layers:
    - ``sources[0]`` and ``sources[1]`` together carry 1536 channels;
    - ``sources[2]`` carries 512 channels;
    - ``sources[3]`` carries 256 channels;
    - ``sources[4]`` carries 128 channels.

    Submodule names and ``nn.Sequential`` indices define the state_dict keys
    (e.g. ``upconv1.conv.0.weight``, ``conv_cls.8.bias``). Keep them stable.

    Args:
        pretrained: passed as the first positional argument to ``vgg16_bn``
            (presumably selects pretrained backbone weights -- confirm in basenet).
        freeze: passed as the second positional argument to ``vgg16_bn``
            (presumably freezes part of the backbone -- confirm in basenet).
    """

    def __init__(self, pretrained=False, freeze=False):
        super().__init__()

        # Backbone feature extractor.
        self.basenet = vgg16_bn(pretrained, freeze)

        # Decoder stages. double_conv(in_ch, mid_ch, out_ch) consumes
        # in_ch + mid_ch channels: the concatenation built in forward().
        self.upconv1 = double_conv(1024, 512, 256)
        self.upconv2 = double_conv(512, 256, 128)
        self.upconv3 = double_conv(256, 128, 64)
        self.upconv4 = double_conv(128, 64, 32)

        num_class = 2
        # Head: one (in, out, kernel) entry per Conv2d + ReLU pair.
        # padding = k // 2 keeps H and W unchanged (1 for 3x3, 0 for 1x1).
        # The final 1x1 projection has no activation.
        head = []
        for c_in, c_out, k in ((32, 32, 3), (32, 32, 3), (32, 16, 3), (16, 16, 1)):
            head.append(nn.Conv2d(c_in, c_out, kernel_size=k, padding=k // 2))
            head.append(nn.ReLU(inplace=True))
        head.append(nn.Conv2d(16, num_class, kernel_size=1))
        self.conv_cls = nn.Sequential(*head)

        # Custom initialisation is applied to the decoder and head only.
        # The backbone is left as vgg16_bn built it.
        for block in (self.upconv1, self.upconv2, self.upconv3, self.upconv4, self.conv_cls):
            init_weights(block.modules())

    def forward(self, x):
        """Run backbone, decoder and head.

        Returns:
            A tuple ``(scores, feature)``.
            - ``scores`` is the head output with the channel axis moved last,
              i.e. (N, H, W, 2).
            - ``feature`` is the 32-channel ``upconv4`` output in (N, 32, H, W)
              layout.
            H and W equal the spatial size of ``sources[4]``.
        """
        sources = self.basenet(x)

        # First stage: fuse the first two backbone outputs directly.
        # There is no resize here, so they must already agree in N, H and W
        # for torch.cat.
        y = self.upconv1(torch.cat([sources[0], sources[1]], dim=1))

        # Later stages: resize to the skip's H x W, concatenate, decode.
        for idx, stage in enumerate((self.upconv2, self.upconv3, self.upconv4), start=2):
            skip = sources[idx]
            y = F.interpolate(y, size=skip.size()[2:], mode='bilinear', align_corners=False)
            y = stage(torch.cat([y, skip], dim=1))

        feature = y
        scores = self.conv_cls(feature)
        return scores.permute(0, 2, 3, 1), feature
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment