Created January 27, 2023 21:55
UNet model to use in ISIM
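The gist defines three related models: Net, a UNet encoder with an image-level classifier head; CAM, which reuses that encoder to emit class activation maps; and Segmentation, which adds the UNet decoder and returns both a classification and a segmentation prediction. A minimal usage sketch follows; the models.unet.segmentation import path is a guess (only the models.unet.net and models.unet.cam paths appear in the code), and the shapes in the comments assume the defaults used below (num_classes=20, 320x320 inputs).

import torch

from models.unet.net import Net
from models.unet.cam import CAM
from models.unet.segmentation import Segmentation  # assumed path, not shown in the gist

x = torch.rand(2, 3, 320, 320)

cls_model = Net(mid_ch=32, num_classes=20)   # image-level classification only
cls_scores = cls_model(x)                    # -> (2, 20)

cam_model = CAM(mid_ch=32, num_classes=20)   # class activation maps at bottleneck resolution
cam = cam_model(x)                           # -> (20, 20, 20); batch is treated as [image, flipped image]

seg_model = Segmentation(mid_ch=32, out_ch=21, num_classes=20)
cls_scores, seg_logits = seg_model(x)        # -> (2, 20) and (2, 21, 320, 320)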
The CAM model (imported elsewhere in the gist as models.unet.cam):
import torch
import torch.nn.functional as F

from models.unet.net import Net


class CAM(Net):
    def __init__(self, *args, **kwargs):
        super(CAM, self).__init__(*args, **kwargs)

    def forward(self, x):
        return self.forward_cam(x)

    def forward_cam(self, x):
        # Encoder path only; the activation map is computed on the bottleneck features.
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))
        bottleneck = self.bottleneck(self.pool4(enc4))
        # Apply the classifier weights as a 1x1 convolution without tracking gradients.
        weights = torch.zeros_like(self.classifier.weight)
        with torch.no_grad():
            weights.set_(self.classifier.weight.detach())
        x = F.relu(F.conv2d(bottleneck, weight=weights))
        # The batch is expected to hold an image and its horizontal flip; merge the two maps.
        x = x[0] + x[1].flip(-1)
        return x


if __name__ == '__main__':
    from torchsummary import summary

    model = CAM()
    summary(model, input_size=(3, 320, 320))
    x = torch.rand([2, 3, 320, 320])
    y = model(x)
    print(y.shape)
    assert y.shape == (20, 20, 20)

    model = CAM(mid_ch=32)
    summary(model, input_size=(3, 320, 320))
    x = torch.rand([2, 3, 320, 320])
    y = model(x)
    assert y.shape == (20, 20, 20)

    model = CAM(mid_ch=64)
    summary(model, input_size=(3, 320, 320))
    x = torch.rand([2, 3, 320, 320])
    y = model(x)
    assert y.shape == (20, 20, 20)
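The CAM above is produced at bottleneck resolution (20x20 for a 320x320 input). In the segmentation file below, the generate_pseudo_label helper (imported from models.generate_label, which is not part of the gist) turns such maps into full-size pseudo labels. Its implementation is not shown; the sketch below is only an assumption about the usual upsample-and-normalize step, with a hypothetical function name.

import torch
import torch.nn.functional as F

def upsample_and_normalize_cam(cam, size):
    # cam: (num_classes, h, w) activation map from CAM.forward_cam
    # size: (H, W) target resolution, e.g. the original image size
    cam = F.interpolate(cam.unsqueeze(0), size=size, mode='bilinear', align_corners=False)[0]
    # Per-class max-normalization so each class map lies in [0, 1].
    cam = cam / (cam.flatten(1).max(dim=1).values.view(-1, 1, 1) + 1e-5)
    return cam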
The Net backbone (imported above as models.unet.net):
from collections import OrderedDict

import torch
import torch.nn as nn

from models.layers import gap2d


def _block(in_channels, features, name):
    # Two 3x3 convolutions, each followed by batch norm and ReLU.
    return nn.Sequential(
        OrderedDict(
            [
                (
                    name + "conv1",
                    nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=features,
                        kernel_size=3,
                        padding=1,
                        bias=False,
                    ),
                ),
                (name + "norm1", nn.BatchNorm2d(num_features=features)),
                (name + "relu1", nn.ReLU(inplace=True)),
                (
                    name + "conv2",
                    nn.Conv2d(
                        in_channels=features,
                        out_channels=features,
                        kernel_size=3,
                        padding=1,
                        bias=False,
                    ),
                ),
                (name + "norm2", nn.BatchNorm2d(num_features=features)),
                (name + "relu2", nn.ReLU(inplace=True)),
            ]
        )
    )


class Net(nn.Module):
    def __init__(self, in_ch=3, mid_ch=32, out_ch=1, num_classes=20, *args, **kwargs):
        super(Net, self).__init__()
        self.out_channels = out_ch
        self.num_classes = num_classes

        # Encoder: four down-sampling stages plus the bottleneck.
        self.encoder1 = _block(in_ch, mid_ch, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = _block(mid_ch, mid_ch * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = _block(mid_ch * 2, mid_ch * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = _block(mid_ch * 4, mid_ch * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bottleneck = _block(mid_ch * 8, mid_ch * 16, name="bottleneck")

        # Decoder: four up-sampling stages with skip connections, then a 1x1 output conv.
        self.upconv4 = nn.ConvTranspose2d(mid_ch * 16, mid_ch * 8, kernel_size=2, stride=2)
        self.decoder4 = _block((mid_ch * 8) * 2, mid_ch * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(mid_ch * 8, mid_ch * 4, kernel_size=2, stride=2)
        self.decoder3 = _block((mid_ch * 4) * 2, mid_ch * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(mid_ch * 4, mid_ch * 2, kernel_size=2, stride=2)
        self.decoder2 = _block((mid_ch * 2) * 2, mid_ch * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(mid_ch * 2, mid_ch, kernel_size=2, stride=2)
        self.decoder1 = _block(mid_ch * 2, mid_ch, name="dec1")
        self.conv = nn.Conv2d(in_channels=mid_ch, out_channels=out_ch, kernel_size=1)

        # Image-level classifier applied to the bottleneck features.
        self.classifier = nn.Conv2d(mid_ch * 16, num_classes, 1, bias=False)

        self.encoder_modules = nn.ModuleList(
            [self.encoder1, self.pool1, self.encoder2, self.pool2, self.encoder3, self.pool3,
             self.encoder4, self.pool4, self.bottleneck])
        self.decoder_modules = nn.ModuleList(
            [self.upconv4, self.decoder4, self.upconv3, self.decoder3, self.upconv2, self.decoder2,
             self.upconv1, self.decoder1, self.conv])
        self.classifier_modules = nn.ModuleList([self.classifier])

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))
        bottleneck = self.bottleneck(self.pool4(enc4))
        # Global average pooling plus the 1x1 classifier gives image-level class scores.
        cls_label_pred = gap2d(bottleneck, keepdims=True)
        cls_label_pred = self.classifier(cls_label_pred)
        cls_label_pred = cls_label_pred.view(-1, self.num_classes)
        return cls_label_pred

    def trainable_parameters(self):
        return (list(self.encoder_modules.parameters()), list(self.classifier_modules.parameters()))


if __name__ == '__main__':
    from torchsummary import summary

    model = Net()
    summary(model, input_size=(3, 320, 320))
    x = torch.rand([2, 3, 320, 320])
    y = model(x)
    print(y.shape)
    assert y.shape == (2, 20)

    model = Net(mid_ch=32)
    summary(model, input_size=(3, 320, 320))
    x = torch.rand([2, 3, 320, 320])
    y = model(x)
    assert y.shape == (2, 20)

    model = Net(mid_ch=64)
    summary(model, input_size=(3, 320, 320))
    x = torch.rand([2, 3, 320, 320])
    y = model(x)
    assert y.shape == (2, 20)
The Segmentation model, built on top of CAM:
import torch

from models.generate_label import generate_pseudo_label
from models.layers import gap2d
from models.unet.cam import CAM
from utils import count_parameters


class Segmentation(CAM):
    def __init__(self, *args, **kwargs):
        super(Segmentation, self).__init__(*args, **kwargs)

    def forward(self, x):
        # Encoder path shared with the classification branch.
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))
        bottleneck = self.bottleneck(self.pool4(enc4))

        # Classification head: global average pooling plus the 1x1 classifier.
        cls_label_pred = gap2d(bottleneck, keepdims=True)
        cls_label_pred = self.classifier(cls_label_pred)
        cls_label_pred = cls_label_pred.view(-1, self.num_classes)

        # Decoder path with skip connections producing the segmentation map.
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)
        seg_label_pred = self.conv(dec1)
        return cls_label_pred, seg_label_pred

    def trainable_parameters(self):
        return (list(self.encoder_modules.parameters()), list(self.classifier_modules.parameters()),
                list(self.decoder_modules.parameters()))


if __name__ == '__main__':
    from torchsummary import summary

    model = Segmentation(in_ch=3, out_ch=21, mid_ch=32, num_classes=20)
    summary(model, input_size=(3, 320, 320))
    x = torch.rand([2, 3, 320, 320])
    cls_pred, seg_pred = model(x)
    assert cls_pred.shape == (2, 20)
    assert seg_pred.shape == (2, 21, 320, 320)

    model = Segmentation(mid_ch=32)
    summary(model, input_size=(3, 320, 320))
    x = torch.rand([2, 3, 320, 320])
    cls_pred, seg_pred = model(x)
    assert cls_pred.shape == (2, 20)
    assert seg_pred.shape == (2, 1, 320, 320)

    model = Segmentation(mid_ch=64)
    summary(model, input_size=(3, 320, 320))
    x = torch.rand([2, 3, 320, 320])
    cls_pred, seg_pred = model(x)
    assert cls_pred.shape == (2, 20)
    assert seg_pred.shape == (2, 1, 320, 320)

    # Test generating pseudo labels.
    imgs = torch.rand([1, 1, 2, 3, 320, 320])
    cam, keys = generate_pseudo_label(model, imgs, torch.Tensor([1, 0, 0, 1, 0, 0, 0]), (512, 512))
    assert cam.shape == (512, 512)

    # Sanity check: trainable_parameters covers every parameter counted by count_parameters.
    count1 = count_parameters(model)
    count2 = 0
    for p in model.trainable_parameters():
        for pp in p:
            if pp.requires_grad:
                count2 += pp.numel()
    assert count1 == count2
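The split of trainable_parameters into encoder, classifier, and decoder groups suggests per-group learning rates during training. The ISIM training code is not part of the gist, so the following optimizer setup is only a hypothetical illustration of how the groups could be consumed; the import path and the learning-rate values are assumptions.

import torch

from models.unet.segmentation import Segmentation  # assumed path, not shown in the gist

model = Segmentation(in_ch=3, out_ch=21, mid_ch=32, num_classes=20)
enc_params, cls_params, dec_params = model.trainable_parameters()
optimizer = torch.optim.SGD(
    [
        {"params": enc_params, "lr": 0.01},  # encoder
        {"params": cls_params, "lr": 0.1},   # classifier head
        {"params": dec_params, "lr": 0.1},   # decoder
    ],
    lr=0.01,          # default for any group without an explicit lr
    momentum=0.9,
    weight_decay=1e-4,
)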