Created
November 21, 2018 12:13
-
-
Save Con-Mi/4d92af62adb784a5353ff7cf19d6d099 to your computer and use it in GitHub Desktop.
Convert a PyTorch binary to C++ readable.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build script for the `testing` executable, which links the C++ inference
# program (cpp_model.cc) against LibTorch and OpenCV.
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)

# Append the default LibTorch location instead of overwriting the variable, so
# that a -DCMAKE_PREFIX_PATH=... given on the command line is still honoured.
list(APPEND CMAKE_PREFIX_PATH "/home/marios-cellink/libtorch")

set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)

find_package(Torch REQUIRED)
find_package(OpenCV 3.4.3 REQUIRED)

# Report what was found to make configuration problems easier to diagnose.
message(STATUS "OpenCV library status:")
message(STATUS " config: ${OpenCV_DIR}")
message(STATUS " version: ${OpenCV_VERSION}")
message(STATUS " libraries: ${OpenCV_LIBS}")
message(STATUS " include path: ${OpenCV_INCLUDE_DIRS}")
message(STATUS "TORCHLIB: ${TORCH_LIBRARIES}")

add_executable(testing cpp_model.cc)
target_link_libraries(testing ${OpenCV_LIBS} "${TORCH_LIBRARIES}")
set_property(TARGET testing PROPERTY CXX_STANDARD 11)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
from torchvision import transforms | |
from helper import load_model | |
from var_dense_linknet_model import denseLinkModel | |
# Build the segmentation network and restore its trained weights in one step;
# load_model returns the same model instance it was given.
segm_model = load_model(
    denseLinkModel(input_channels=3, pretrained=True),
    model_dir="./var_dense_linknet_384_sgd_bce_20epchs.pt",
)

# Dummy batch (1 image, 3 channels, 384x384) used only to record the graph.
example = torch.ones((1, 3, 384, 384))

# Trace the forward pass and serialise it so it can be loaded outside Python.
traced_script_module = torch.jit.trace(segm_model, example)
traced_script_module.save("./jit_pred_model.pt")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <exception>
#include <iostream>
#include <memory>
#include <vector>

#include <torch/script.h>

#include <opencv2/core.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/opencv.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/imgcodecs.hpp>
int main(int argc, const char* argv[]) { | |
cv::Mat img = cv::imread("../0010_r.png", cv::IMREAD_UNCHANGED); | |
// img.convertTo(img, CV_32FC3, 1/255.0); // Convert the image into floats | |
at::Tensor tensor_img = torch::from_blob(img.data, {1, 3, img.rows, img.cols}, at::kByte).clone(); | |
tensor_img = tensor_img.to(at::kFloat); | |
std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("../jit_pred_model.pt"); | |
std::vector<torch::jit::IValue> inputs; // Converts the image into floats. This does the same. | |
inputs.push_back(tensor_img); | |
auto output = module->forward(inputs).toTensor(); | |
std::vector<int> size = {384, 384}; | |
output = output.squeeze(); | |
output = at::sigmoid(output); | |
std::cout << output << std::endl; | |
cv::Mat img_out(size, CV_32F, output.data<float>()); | |
std::cout << img_out << std::endl; | |
cv::imshow("Original Image", img); | |
cv::imshow("Output", img_out); | |
cv::waitKey(21000); | |
return 0; | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
def jaccard(y_true, y_pred):
    """Smoothed Jaccard index (intersection over union) of two masks.

    Args:
        y_true: Ground-truth values; must support ``*`` and ``.sum()``.
        y_pred: Predicted values with a shape compatible with ``y_true``.

    Returns:
        ``(intersection + 1e-15) / (union + 1e-15)``, where the union is
        ``sum(y_true) + sum(y_pred) - intersection``.
    """
    smooth = 1e-15  # keeps the ratio defined when both inputs sum to zero
    overlap = (y_true * y_pred).sum()
    combined = y_true.sum() + y_pred.sum() - overlap
    numerator = overlap + smooth
    denominator = combined + smooth
    return numerator / denominator
def dice(y_true, y_pred):
    """Smoothed Dice coefficient of two masks.

    Args:
        y_true: Ground-truth values; must support ``*`` and ``.sum()``.
        y_pred: Predicted values with a shape compatible with ``y_true``.

    Returns:
        ``(2 * intersection + 1e-15) / (sum(y_true) + sum(y_pred) + 1e-15)``.
    """
    smooth = 1e-15  # keeps the ratio defined when both inputs sum to zero
    overlap = (y_true * y_pred).sum()
    numerator = 2 * overlap + smooth
    denominator = y_true.sum() + y_pred.sum() + smooth
    return numerator / denominator
def load_model(cust_model, model_dir="dense_segm.pt", map_location_device="cpu"):
    """Load a saved state dict into ``cust_model`` and switch it to eval mode.

    Args:
        cust_model: Module whose ``load_state_dict`` receives the weights.
        model_dir: Path of the checkpoint file written with ``torch.save``.
        map_location_device: ``"gpu"`` loads the checkpoint with the devices it
            was saved with (no remapping). Any other value is forwarded to
            ``torch.load`` as ``map_location``; the default ``"cpu"`` maps all
            tensors to the CPU.

    Returns:
        The same ``cust_model`` instance, with weights loaded, in eval mode.
    """
    if isinstance(map_location_device, str) and map_location_device == "gpu":
        state_dict = torch.load(model_dir)
    else:
        # Previously any value other than "cpu"/"gpu" skipped loading entirely
        # and silently returned the model with its original weights.
        state_dict = torch.load(model_dir, map_location=map_location_device)
    cust_model.load_state_dict(state_dict)
    cust_model.eval()
    return cust_model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torchvision import transforms | |
from helper import load_model | |
from var_dense_linknet_model import denseLinkModel | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
from matplotlib import pyplot as plt | |
import seaborn as sns | |
# Build the segmentation network and restore its trained weights.
segm_model = denseLinkModel(input_channels=3, pretrained=True)
segm_model = load_model(segm_model, model_dir="./var_dense_linknet_384_sgd_bce_20epchs.pt")

# Preprocessing: resize to 384x384 and convert to a float tensor in [0, 1].
trf = transforms.Compose([transforms.Resize(size=(384, 384)), transforms.ToTensor()])

# The OpenCV copy is only used for display.
img = cv2.imread("./0010_.png", cv2.IMREAD_UNCHANGED)
if img is None:
    raise FileNotFoundError("Could not read image ./0010_.png")

# Image.fromarray expects an array, not a file path; open the file with PIL
# and force three channels to match input_channels=3.
pil_img = Image.open("./0010_.png").convert("RGB")
img_in = trf(pil_img)
img_in = img_in.unsqueeze(dim=0)  # add the batch dimension

# Inference only: no autograd graph is needed.
with torch.no_grad():
    out = segm_model(img_in)
out = out.squeeze()
out = torch.sigmoid(out)
out = out.detach().numpy()

cv2.imshow("Original Image", img)
cv2.imshow("Gray Scale Image", out)
cv2.waitKey(7000)

sns.heatmap(out)
plt.show()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torchvision import models | |
from torch import nn | |
import torch | |
class ConvEluGrNorm(nn.Module):
    """3x3 convolution followed by group normalisation and an ELU activation.

    Despite the class name, the forward pass applies the layers in the order
    convolution -> GroupNorm -> ELU.

    Args:
        inp_chnl: Number of input channels of the convolution.
        out_chnl: Number of output channels; GroupNorm uses 16 groups, so this
            must be divisible by 16.
    """

    def __init__(self, inp_chnl, out_chnl):
        super(ConvEluGrNorm, self).__init__()
        # padding=1 with a 3x3 kernel keeps the spatial size unchanged.
        self.conv = nn.Conv2d(inp_chnl, out_chnl, kernel_size=3, padding=1, bias=False)
        self.norm = nn.GroupNorm(16, out_chnl)
        self.elu = nn.ELU(inplace=True)

    def forward(self, x):
        """Return ``elu(norm(conv(x)))``."""
        return self.elu(self.norm(self.conv(x)))
class UpsampleLayer(nn.Sequential):
    """Decoder stage that doubles the spatial resolution of its input.

    The layers are stored in a single child module named ``block``; since this
    class derives from ``nn.Sequential``, the inherited forward pass simply
    feeds the input through ``block``.

    Args:
        in_chnl: Channels of the incoming feature map.
        mid_chnl: Channels after the first convolution.
        out_chnl: Channels of the produced feature map.
        transp: When true, upsample with a stride-2 transposed convolution;
            otherwise (default) use nearest-neighbour upsampling followed by
            two ConvEluGrNorm blocks.
    """

    def __init__(self, in_chnl, mid_chnl, out_chnl, transp=False):
        super(UpsampleLayer, self).__init__()
        if transp:
            stages = [
                ConvEluGrNorm(in_chnl, mid_chnl),
                nn.ConvTranspose2d(mid_chnl, out_chnl, kernel_size=4, stride=2, padding=1, bias=False),
                nn.ELU(inplace=True),
            ]
        else:
            stages = [
                nn.GroupNorm(16, in_chnl),
                nn.Upsample(scale_factor=2, mode="nearest"),
                ConvEluGrNorm(in_chnl, mid_chnl),
                ConvEluGrNorm(mid_chnl, out_chnl),
            ]
        self.block = nn.Sequential(*stages)
class TransitionLayer(nn.Sequential):
    """GroupNorm -> ELU -> 1x1 convolution -> 2x2 average pooling.

    The 1x1 convolution maps ``in_chnl`` to ``out_chnl`` channels and the
    pooling halves the spatial resolution. The layers live in a single child
    module named ``block``, which the inherited ``nn.Sequential`` forward runs.

    Args:
        in_chnl: Channels of the incoming feature map (divisible by 16, the
            number of GroupNorm groups).
        out_chnl: Channels of the produced feature map.
    """

    def __init__(self, in_chnl, out_chnl):
        super(TransitionLayer, self).__init__()
        stages = [
            nn.GroupNorm(16, in_chnl),
            nn.ELU(inplace=True),
            nn.Conv2d(in_chnl, out_chnl, kernel_size=1, padding=0, bias=False),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
        ]
        self.block = nn.Sequential(*stages)
class Bottleneck(nn.Sequential):
    """Channel-reducing block that keeps the spatial resolution.

    Layer order: GroupNorm -> ELU -> 1x1 convolution (``in_chnl`` to
    ``in_chnl``) -> GroupNorm -> ELU -> ConvEluGrNorm (``in_chnl`` to
    ``out_chnl``). The layers live in a single child module named ``block``,
    which the inherited ``nn.Sequential`` forward runs.

    Args:
        in_chnl: Channels of the incoming feature map (divisible by 16, the
            number of GroupNorm groups).
        out_chnl: Channels of the produced feature map.
    """

    def __init__(self, in_chnl, out_chnl):
        super(Bottleneck, self).__init__()
        stages = [
            nn.GroupNorm(16, in_chnl),
            nn.ELU(),
            nn.Conv2d(in_chnl, in_chnl, kernel_size=1, padding=0, bias=False),
            nn.GroupNorm(16, in_chnl),
            nn.ELU(),
            ConvEluGrNorm(inp_chnl=in_chnl, out_chnl=out_chnl),
        ]
        self.block = nn.Sequential(*stages)
class DenseSegmModel(nn.Module):
    """Encoder/decoder segmentation network on top of DenseNet-121 features.

    The encoder reuses slices of ``torchvision.models.densenet121().features``
    behind a custom stem; the decoder upsamples step by step and concatenates
    the matching encoder output at every stage (skip connections).

    Args:
        input_channels: Channels of the input image.
        num_filters: Base width of the decoder (default 32).
        num_classes: Channels of the output map (default 1).
        pretrained: Passed to ``densenet121`` to request pretrained weights.

    Note:
        ``dec5`` and ``pool`` are constructed but not used in ``forward``.
        They are kept so the set of registered modules (and the ``dec5``
        entries in the state dict) stays the same.
    """

    def __init__(self, input_channels, num_filters=32, num_classes=1, pretrained=False):
        super(DenseSegmModel, self).__init__()
        backbone = models.densenet121(pretrained=pretrained).features
        wide = num_filters * 8
        narrow = num_filters * 2

        # Custom stem (so any number of input channels works) followed by the
        # backbone module at index 3.
        self.layer1 = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.GroupNorm(16, 64),
            nn.ELU(inplace=True),
            backbone[3],
        )
        self.layer2 = backbone[4:6]
        self.layer3 = backbone[6:8]
        self.layer4 = backbone[8:10]
        self.layer5 = backbone[10]

        self.transition = TransitionLayer(in_chnl=1024, out_chnl=1024)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.bottleneck = Bottleneck(in_chnl=1024 + wide, out_chnl=wide)
        self.center = UpsampleLayer(in_chnl=1024, mid_chnl=wide, out_chnl=wide)

        # Each decoder input is the previous decoder output concatenated with
        # the encoder output of the same resolution.
        self.dec5 = UpsampleLayer(1024 + wide, wide, wide)
        self.dec4 = UpsampleLayer(512 + wide, wide, wide)
        self.dec3 = UpsampleLayer(256 + wide, wide, wide)
        self.dec2 = UpsampleLayer(128 + wide, narrow, narrow)
        self.dec1 = UpsampleLayer(64 + narrow, num_filters, num_filters)
        self.dec0 = UpsampleLayer(num_filters, num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        """Return the raw (pre-sigmoid) output of the final 1x1 convolution."""
        # Encoder: keep every stage's output for the skip connections.
        skip1 = self.layer1(x)
        skip2 = self.layer2(skip1)
        skip3 = self.layer3(skip2)
        skip4 = self.layer4(skip3)
        skip5 = self.layer5(skip4)

        # Downsample once more, then upsample back to skip5's resolution.
        center = self.center(self.transition(skip5))

        # Decoder: the bottleneck takes the place of the unused dec5 stage.
        up = self.bottleneck(torch.cat([center, skip5], 1))
        up = self.dec4(torch.cat([up, skip4], 1))
        up = self.dec3(torch.cat([up, skip3], 1))
        up = self.dec2(torch.cat([up, skip2], 1))
        up = self.dec1(torch.cat([up, skip1], 1))
        up = self.dec0(up)
        return self.final(up)
def denseLinkModel(input_channels, pretrained=False, num_classes=1):
    """Factory for ``DenseSegmModel``.

    Args:
        input_channels: Channels of the input image.
        pretrained: Whether the DenseNet-121 encoder requests pretrained weights.
        num_classes: Channels of the output map (default 1).

    Returns:
        A freshly constructed ``DenseSegmModel``.
    """
    # Forward the caller's num_classes; it used to be hard-coded to 1, so the
    # argument was silently ignored.
    return DenseSegmModel(input_channels=input_channels, pretrained=pretrained, num_classes=num_classes)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment