Created
August 21, 2020 17:56
ResUnet.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "ResUnet.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyPGw/X+XTkEoWN4a6DxKTyZ", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/rishikksh20/843039a38ab8f770699dfca90c2718e7/resunet.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "S5AjjEi-mwpl", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn" | |
], | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Cq5vCdbmJ-vB", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"class ResidualConv(nn.Module):\n", | |
"\n", | |
" def __init__(self, input_dim, output_dim, stride, padding):\n", | |
" super(ResidualConv, self).__init__()\n", | |
"\n", | |
" self.conv_block = nn.Sequential(nn.BatchNorm2d(input_dim),\n", | |
" nn.ReLU(),\n", | |
" nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=padding),\n", | |
" nn.BatchNorm2d(output_dim),\n", | |
" nn.ReLU(),\n", | |
" nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),\n", | |
" )\n", | |
" self.conv_skip = nn.Sequential(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),\n", | |
" nn.BatchNorm2d(output_dim))\n", | |
"\n", | |
" def forward(self, x):\n", | |
"\n", | |
" return self.conv_block(x) + self.conv_skip(x)" | |
], | |
"execution_count": 96, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "wr_EplHHjErx", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "4ded4b26-fb5f-4d20-a1d3-83f0dbb4fe03" | |
}, | |
"source": [ | |
"# Test ResidualConv Block\n", | |
"x = torch.ones(1, 64, 224, 224)\n", | |
"res_conv = ResidualConv(64, 128, 2, 1)\n", | |
"res_conv(x).shape" | |
], | |
"execution_count": 106, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"torch.Size([1, 128, 112, 112])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 106 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "O0LwOXx-L4kc", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"class Upsample(nn.Module):\n", | |
"\n", | |
" def __init__(self, input_dim, output_dim, kernel, stride):\n", | |
" super(Upsample, self).__init__()\n", | |
"\n", | |
" self.upsample = nn.ConvTranspose2d(input_dim, output_dim, kernel_size=kernel, stride=stride)\n", | |
"\n", | |
" def forward(self, x):\n", | |
" return self.upsample(x)" | |
], | |
"execution_count": 45, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "LjEZYML9PoxD", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "6c578426-0d11-458d-afe8-201d670f9c01" | |
}, | |
"source": [ | |
"# Test UpSample Block\n", | |
"x = torch.ones(1, 512, 28, 28)\n", | |
"upsample = Upsample(512, 512, 2, 2)\n", | |
"upsample(x).shape" | |
], | |
"execution_count": 46, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"torch.Size([1, 512, 56, 56])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 46 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "CTQ-lYABm42K", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"class ResUnet(nn.Module):\n", | |
"\n", | |
" def __init__(self, channel, dim):\n", | |
" super(ResUnet, self).__init__()\n", | |
"\n", | |
" self.input_layer = nn.Sequential(nn.Conv2d(channel, dim, kernel_size=3, padding=1),\n", | |
" nn.BatchNorm2d(dim),\n", | |
" nn.ReLU(),\n", | |
" nn.Conv2d(dim, dim, kernel_size=3, padding=1),\n", | |
" )\n", | |
" self.input_skip = nn.Sequential(nn.Conv2d(channel, dim, kernel_size=3, padding=1))\n", | |
"\n", | |
" self.residual_conv_1 = ResidualConv(64, 128, 2, 1)\n", | |
" self.residual_conv_2 = ResidualConv(128, 256, 2, 1)\n", | |
" \n", | |
" self.bridge = ResidualConv(256, 512, 2, 1)\n", | |
" \n", | |
" self.upsample_1 = Upsample(512, 512, 2, 2)\n", | |
" self.up_residual_conv1 = ResidualConv(512 + 256, 256, 1, 1)\n", | |
"\n", | |
" self.upsample_2 = Upsample(256, 256, 2, 2)\n", | |
" self.up_residual_conv2 = ResidualConv(256 + 128, 128, 1, 1)\n", | |
"\n", | |
" self.upsample_3 = Upsample(128, 128, 2, 2)\n", | |
" self.up_residual_conv3 = ResidualConv(128 + 64, 64, 1, 1)\n", | |
"\n", | |
" self.output_layer = nn.Sequential(nn.Conv2d(64, 1, 1, 1),\n", | |
" nn.Sigmoid(),\n", | |
" )\n", | |
"\n", | |
" def forward(self, x):\n", | |
" # Encode\n", | |
" x1 = self.input_layer(x) + self.input_skip(x)\n", | |
" x2 = self.residual_conv_1(x1)\n", | |
" x3 = self.residual_conv_2(x2)\n", | |
" # Bridge\n", | |
" x4 = self.bridge(x3)\n", | |
" # Decode\n", | |
" x4 = self.upsample_1(x4)\n", | |
" x5 = torch.cat([x4, x3], dim=1)\n", | |
"\n", | |
" x6 = self.up_residual_conv1(x5)\n", | |
"\n", | |
" x6 = self.upsample_2(x6)\n", | |
" x7 = torch.cat([x6, x2], dim=1)\n", | |
"\n", | |
" x8 = self.up_residual_conv2(x7)\n", | |
"\n", | |
" x8 = self.upsample_3(x8)\n", | |
" x9 = torch.cat([x8, x1], dim=1)\n", | |
"\n", | |
" x10 = self.up_residual_conv3(x9)\n", | |
"\n", | |
" output = self.output_layer(x10)\n", | |
"\n", | |
" return output " | |
], | |
"execution_count": 107, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "XN3sb7y5Yf0t", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"img = torch.ones(1, 3, 224, 224)\n", | |
"resunet = ResUnet(3, 64)" | |
], | |
"execution_count": 108, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "QASqwiPVandU", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "a1e8af07-566a-4c4b-d414-8d8ed7ea5225" | |
}, | |
"source": [ | |
"out = resunet(img)\n", | |
"out.shape" | |
], | |
"execution_count": 109, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"torch.Size([1, 1, 224, 224])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 109 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "4FWS_OmKdX4U", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
}, | |
"outputId": "6e2fa2ac-6d11-4499-a9fe-2bcfa97acd16" | |
}, | |
"source": [ | |
"from torchsummary import summary\n", | |
"summary(resunet, input_size=(3, 224, 224))" | |
], | |
"execution_count": 110, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"----------------------------------------------------------------\n", | |
" Layer (type) Output Shape Param #\n", | |
"================================================================\n", | |
" Conv2d-1 [-1, 64, 224, 224] 1,792\n", | |
" BatchNorm2d-2 [-1, 64, 224, 224] 128\n", | |
" ReLU-3 [-1, 64, 224, 224] 0\n", | |
" Conv2d-4 [-1, 64, 224, 224] 36,928\n", | |
" Conv2d-5 [-1, 64, 224, 224] 1,792\n", | |
" BatchNorm2d-6 [-1, 64, 224, 224] 128\n", | |
" ReLU-7 [-1, 64, 224, 224] 0\n", | |
" Conv2d-8 [-1, 128, 112, 112] 73,856\n", | |
" BatchNorm2d-9 [-1, 128, 112, 112] 256\n", | |
" ReLU-10 [-1, 128, 112, 112] 0\n", | |
" Conv2d-11 [-1, 128, 112, 112] 147,584\n", | |
" Conv2d-12 [-1, 128, 112, 112] 73,856\n", | |
" BatchNorm2d-13 [-1, 128, 112, 112] 256\n", | |
" ResidualConv-14 [-1, 128, 112, 112] 0\n", | |
" BatchNorm2d-15 [-1, 128, 112, 112] 256\n", | |
" ReLU-16 [-1, 128, 112, 112] 0\n", | |
" Conv2d-17 [-1, 256, 56, 56] 295,168\n", | |
" BatchNorm2d-18 [-1, 256, 56, 56] 512\n", | |
" ReLU-19 [-1, 256, 56, 56] 0\n", | |
" Conv2d-20 [-1, 256, 56, 56] 590,080\n", | |
" Conv2d-21 [-1, 256, 56, 56] 295,168\n", | |
" BatchNorm2d-22 [-1, 256, 56, 56] 512\n", | |
" ResidualConv-23 [-1, 256, 56, 56] 0\n", | |
" BatchNorm2d-24 [-1, 256, 56, 56] 512\n", | |
" ReLU-25 [-1, 256, 56, 56] 0\n", | |
" Conv2d-26 [-1, 512, 28, 28] 1,180,160\n", | |
" BatchNorm2d-27 [-1, 512, 28, 28] 1,024\n", | |
" ReLU-28 [-1, 512, 28, 28] 0\n", | |
" Conv2d-29 [-1, 512, 28, 28] 2,359,808\n", | |
" Conv2d-30 [-1, 512, 28, 28] 1,180,160\n", | |
" BatchNorm2d-31 [-1, 512, 28, 28] 1,024\n", | |
" ResidualConv-32 [-1, 512, 28, 28] 0\n", | |
" ConvTranspose2d-33 [-1, 512, 56, 56] 1,049,088\n", | |
" Upsample-34 [-1, 512, 56, 56] 0\n", | |
" BatchNorm2d-35 [-1, 768, 56, 56] 1,536\n", | |
" ReLU-36 [-1, 768, 56, 56] 0\n", | |
" Conv2d-37 [-1, 256, 56, 56] 1,769,728\n", | |
" BatchNorm2d-38 [-1, 256, 56, 56] 512\n", | |
" ReLU-39 [-1, 256, 56, 56] 0\n", | |
" Conv2d-40 [-1, 256, 56, 56] 590,080\n", | |
" Conv2d-41 [-1, 256, 56, 56] 1,769,728\n", | |
" BatchNorm2d-42 [-1, 256, 56, 56] 512\n", | |
" ResidualConv-43 [-1, 256, 56, 56] 0\n", | |
" ConvTranspose2d-44 [-1, 256, 112, 112] 262,400\n", | |
" Upsample-45 [-1, 256, 112, 112] 0\n", | |
" BatchNorm2d-46 [-1, 384, 112, 112] 768\n", | |
" ReLU-47 [-1, 384, 112, 112] 0\n", | |
" Conv2d-48 [-1, 128, 112, 112] 442,496\n", | |
" BatchNorm2d-49 [-1, 128, 112, 112] 256\n", | |
" ReLU-50 [-1, 128, 112, 112] 0\n", | |
" Conv2d-51 [-1, 128, 112, 112] 147,584\n", | |
" Conv2d-52 [-1, 128, 112, 112] 442,496\n", | |
" BatchNorm2d-53 [-1, 128, 112, 112] 256\n", | |
" ResidualConv-54 [-1, 128, 112, 112] 0\n", | |
" ConvTranspose2d-55 [-1, 128, 224, 224] 65,664\n", | |
" Upsample-56 [-1, 128, 224, 224] 0\n", | |
" BatchNorm2d-57 [-1, 192, 224, 224] 384\n", | |
" ReLU-58 [-1, 192, 224, 224] 0\n", | |
" Conv2d-59 [-1, 64, 224, 224] 110,656\n", | |
" BatchNorm2d-60 [-1, 64, 224, 224] 128\n", | |
" ReLU-61 [-1, 64, 224, 224] 0\n", | |
" Conv2d-62 [-1, 64, 224, 224] 36,928\n", | |
" Conv2d-63 [-1, 64, 224, 224] 110,656\n", | |
" BatchNorm2d-64 [-1, 64, 224, 224] 128\n", | |
" ResidualConv-65 [-1, 64, 224, 224] 0\n", | |
" Conv2d-66 [-1, 1, 224, 224] 65\n", | |
" Sigmoid-67 [-1, 1, 224, 224] 0\n", | |
"================================================================\n", | |
"Total params: 13,043,009\n", | |
"Trainable params: 13,043,009\n", | |
"Non-trainable params: 0\n", | |
"----------------------------------------------------------------\n", | |
"Input size (MB): 0.57\n", | |
"Forward/backward pass size (MB): 1087.95\n", | |
"Params size (MB): 49.76\n", | |
"Estimated Total Size (MB): 1138.28\n", | |
"----------------------------------------------------------------\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "aX09xeUedd6F", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment