Skip to content

Instantly share code, notes, and snippets.

@rishikksh20
Created August 21, 2020 17:56
ResUnet.ipynb
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "ResUnet.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyPGw/X+XTkEoWN4a6DxKTyZ",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/rishikksh20/843039a38ab8f770699dfca90c2718e7/resunet.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "S5AjjEi-mwpl",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"import torch.nn as nn"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Cq5vCdbmJ-vB",
"colab_type": "code",
"colab": {}
},
"source": [
"class ResidualConv(nn.Module):\n",
"\n",
" def __init__(self, input_dim, output_dim, stride, padding):\n",
" super(ResidualConv, self).__init__()\n",
"\n",
" self.conv_block = nn.Sequential(nn.BatchNorm2d(input_dim),\n",
" nn.ReLU(),\n",
" nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=padding),\n",
" nn.BatchNorm2d(output_dim),\n",
" nn.ReLU(),\n",
" nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),\n",
" )\n",
" self.conv_skip = nn.Sequential(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),\n",
" nn.BatchNorm2d(output_dim))\n",
"\n",
" def forward(self, x):\n",
"\n",
" return self.conv_block(x) + self.conv_skip(x)"
],
"execution_count": 96,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "wr_EplHHjErx",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "4ded4b26-fb5f-4d20-a1d3-83f0dbb4fe03"
},
"source": [
"# Test ResidualConv Block\n",
"x = torch.ones(1, 64, 224, 224)\n",
"res_conv = ResidualConv(64, 128, 2, 1)\n",
"res_conv(x).shape"
],
"execution_count": 106,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([1, 128, 112, 112])"
]
},
"metadata": {
"tags": []
},
"execution_count": 106
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "O0LwOXx-L4kc",
"colab_type": "code",
"colab": {}
},
"source": [
"class Upsample(nn.Module):\n",
"\n",
" def __init__(self, input_dim, output_dim, kernel, stride):\n",
" super(Upsample, self).__init__()\n",
"\n",
" self.upsample = nn.ConvTranspose2d(input_dim, output_dim, kernel_size=kernel, stride=stride)\n",
"\n",
" def forward(self, x):\n",
" return self.upsample(x)"
],
"execution_count": 45,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "LjEZYML9PoxD",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "6c578426-0d11-458d-afe8-201d670f9c01"
},
"source": [
"# Test UpSample Block\n",
"x = torch.ones(1, 512, 28, 28)\n",
"upsample = Upsample(512, 512, 2, 2)\n",
"upsample(x).shape"
],
"execution_count": 46,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([1, 512, 56, 56])"
]
},
"metadata": {
"tags": []
},
"execution_count": 46
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "CTQ-lYABm42K",
"colab_type": "code",
"colab": {}
},
"source": [
"class ResUnet(nn.Module):\n",
"\n",
" def __init__(self, channel, dim):\n",
" super(ResUnet, self).__init__()\n",
"\n",
" self.input_layer = nn.Sequential(nn.Conv2d(channel, dim, kernel_size=3, padding=1),\n",
" nn.BatchNorm2d(dim),\n",
" nn.ReLU(),\n",
" nn.Conv2d(dim, dim, kernel_size=3, padding=1),\n",
" )\n",
" self.input_skip = nn.Sequential(nn.Conv2d(channel, dim, kernel_size=3, padding=1))\n",
"\n",
" self.residual_conv_1 = ResidualConv(64, 128, 2, 1)\n",
" self.residual_conv_2 = ResidualConv(128, 256, 2, 1)\n",
" \n",
" self.bridge = ResidualConv(256, 512, 2, 1)\n",
" \n",
" self.upsample_1 = Upsample(512, 512, 2, 2)\n",
" self.up_residual_conv1 = ResidualConv(512 + 256, 256, 1, 1)\n",
"\n",
" self.upsample_2 = Upsample(256, 256, 2, 2)\n",
" self.up_residual_conv2 = ResidualConv(256 + 128, 128, 1, 1)\n",
"\n",
" self.upsample_3 = Upsample(128, 128, 2, 2)\n",
" self.up_residual_conv3 = ResidualConv(128 + 64, 64, 1, 1)\n",
"\n",
" self.output_layer = nn.Sequential(nn.Conv2d(64, 1, 1, 1),\n",
" nn.Sigmoid(),\n",
" )\n",
"\n",
" def forward(self, x):\n",
" # Encode\n",
" x1 = self.input_layer(x) + self.input_skip(x)\n",
" x2 = self.residual_conv_1(x1)\n",
" x3 = self.residual_conv_2(x2)\n",
" # Bridge\n",
" x4 = self.bridge(x3)\n",
" # Decode\n",
" x4 = self.upsample_1(x4)\n",
" x5 = torch.cat([x4, x3], dim=1)\n",
"\n",
" x6 = self.up_residual_conv1(x5)\n",
"\n",
" x6 = self.upsample_2(x6)\n",
" x7 = torch.cat([x6, x2], dim=1)\n",
"\n",
" x8 = self.up_residual_conv2(x7)\n",
"\n",
" x8 = self.upsample_3(x8)\n",
" x9 = torch.cat([x8, x1], dim=1)\n",
"\n",
" x10 = self.up_residual_conv3(x9)\n",
"\n",
" output = self.output_layer(x10)\n",
"\n",
" return output "
],
"execution_count": 107,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "XN3sb7y5Yf0t",
"colab_type": "code",
"colab": {}
},
"source": [
"img = torch.ones(1, 3, 224, 224)\n",
"resunet = ResUnet(3, 64)"
],
"execution_count": 108,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "QASqwiPVandU",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "a1e8af07-566a-4c4b-d414-8d8ed7ea5225"
},
"source": [
"out = resunet(img)\n",
"out.shape"
],
"execution_count": 109,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([1, 1, 224, 224])"
]
},
"metadata": {
"tags": []
},
"execution_count": 109
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "4FWS_OmKdX4U",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "6e2fa2ac-6d11-4499-a9fe-2bcfa97acd16"
},
"source": [
"from torchsummary import summary\n",
"summary(resunet, input_size=(3, 224, 224))"
],
"execution_count": 110,
"outputs": [
{
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 64, 224, 224] 1,792\n",
" BatchNorm2d-2 [-1, 64, 224, 224] 128\n",
" ReLU-3 [-1, 64, 224, 224] 0\n",
" Conv2d-4 [-1, 64, 224, 224] 36,928\n",
" Conv2d-5 [-1, 64, 224, 224] 1,792\n",
" BatchNorm2d-6 [-1, 64, 224, 224] 128\n",
" ReLU-7 [-1, 64, 224, 224] 0\n",
" Conv2d-8 [-1, 128, 112, 112] 73,856\n",
" BatchNorm2d-9 [-1, 128, 112, 112] 256\n",
" ReLU-10 [-1, 128, 112, 112] 0\n",
" Conv2d-11 [-1, 128, 112, 112] 147,584\n",
" Conv2d-12 [-1, 128, 112, 112] 73,856\n",
" BatchNorm2d-13 [-1, 128, 112, 112] 256\n",
" ResidualConv-14 [-1, 128, 112, 112] 0\n",
" BatchNorm2d-15 [-1, 128, 112, 112] 256\n",
" ReLU-16 [-1, 128, 112, 112] 0\n",
" Conv2d-17 [-1, 256, 56, 56] 295,168\n",
" BatchNorm2d-18 [-1, 256, 56, 56] 512\n",
" ReLU-19 [-1, 256, 56, 56] 0\n",
" Conv2d-20 [-1, 256, 56, 56] 590,080\n",
" Conv2d-21 [-1, 256, 56, 56] 295,168\n",
" BatchNorm2d-22 [-1, 256, 56, 56] 512\n",
" ResidualConv-23 [-1, 256, 56, 56] 0\n",
" BatchNorm2d-24 [-1, 256, 56, 56] 512\n",
" ReLU-25 [-1, 256, 56, 56] 0\n",
" Conv2d-26 [-1, 512, 28, 28] 1,180,160\n",
" BatchNorm2d-27 [-1, 512, 28, 28] 1,024\n",
" ReLU-28 [-1, 512, 28, 28] 0\n",
" Conv2d-29 [-1, 512, 28, 28] 2,359,808\n",
" Conv2d-30 [-1, 512, 28, 28] 1,180,160\n",
" BatchNorm2d-31 [-1, 512, 28, 28] 1,024\n",
" ResidualConv-32 [-1, 512, 28, 28] 0\n",
" ConvTranspose2d-33 [-1, 512, 56, 56] 1,049,088\n",
" Upsample-34 [-1, 512, 56, 56] 0\n",
" BatchNorm2d-35 [-1, 768, 56, 56] 1,536\n",
" ReLU-36 [-1, 768, 56, 56] 0\n",
" Conv2d-37 [-1, 256, 56, 56] 1,769,728\n",
" BatchNorm2d-38 [-1, 256, 56, 56] 512\n",
" ReLU-39 [-1, 256, 56, 56] 0\n",
" Conv2d-40 [-1, 256, 56, 56] 590,080\n",
" Conv2d-41 [-1, 256, 56, 56] 1,769,728\n",
" BatchNorm2d-42 [-1, 256, 56, 56] 512\n",
" ResidualConv-43 [-1, 256, 56, 56] 0\n",
" ConvTranspose2d-44 [-1, 256, 112, 112] 262,400\n",
" Upsample-45 [-1, 256, 112, 112] 0\n",
" BatchNorm2d-46 [-1, 384, 112, 112] 768\n",
" ReLU-47 [-1, 384, 112, 112] 0\n",
" Conv2d-48 [-1, 128, 112, 112] 442,496\n",
" BatchNorm2d-49 [-1, 128, 112, 112] 256\n",
" ReLU-50 [-1, 128, 112, 112] 0\n",
" Conv2d-51 [-1, 128, 112, 112] 147,584\n",
" Conv2d-52 [-1, 128, 112, 112] 442,496\n",
" BatchNorm2d-53 [-1, 128, 112, 112] 256\n",
" ResidualConv-54 [-1, 128, 112, 112] 0\n",
" ConvTranspose2d-55 [-1, 128, 224, 224] 65,664\n",
" Upsample-56 [-1, 128, 224, 224] 0\n",
" BatchNorm2d-57 [-1, 192, 224, 224] 384\n",
" ReLU-58 [-1, 192, 224, 224] 0\n",
" Conv2d-59 [-1, 64, 224, 224] 110,656\n",
" BatchNorm2d-60 [-1, 64, 224, 224] 128\n",
" ReLU-61 [-1, 64, 224, 224] 0\n",
" Conv2d-62 [-1, 64, 224, 224] 36,928\n",
" Conv2d-63 [-1, 64, 224, 224] 110,656\n",
" BatchNorm2d-64 [-1, 64, 224, 224] 128\n",
" ResidualConv-65 [-1, 64, 224, 224] 0\n",
" Conv2d-66 [-1, 1, 224, 224] 65\n",
" Sigmoid-67 [-1, 1, 224, 224] 0\n",
"================================================================\n",
"Total params: 13,043,009\n",
"Trainable params: 13,043,009\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.57\n",
"Forward/backward pass size (MB): 1087.95\n",
"Params size (MB): 49.76\n",
"Estimated Total Size (MB): 1138.28\n",
"----------------------------------------------------------------\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "aX09xeUedd6F",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment