Skip to content

Instantly share code, notes, and snippets.

@magesh-technovator
Last active May 20, 2021 10:42
Show Gist options
  • Save magesh-technovator/dcf793dfaf0f2e3a3a6bfb99ee3c51a4 to your computer and use it in GitHub Desktop.
Save magesh-technovator/dcf793dfaf0f2e3a3a6bfb99ee3c51a4 to your computer and use it in GitHub Desktop.
PyTorch segmentation inference
import torch
import matplotlib.pyplot as plt
import cv2
import pandas as pd
import time
# Placeholders: point these at a saved (pickled) model and an input image.
model_path = "<model path>"
image_path = "<image_path>"

# Choose the inference device once and use it for both the model and the input,
# so the two can never end up on different devices.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Load the trained model. map_location remaps every stored tensor onto `device`,
# so a model saved on CPU or on a different GPU still loads onto the right one.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# model files from a trusted source.
model = torch.load(model_path, map_location=device)
# Ensure the parameters live on the same device as the input tensor below.
model = model.to(device)
# Set the model to evaluation mode (affects layers such as Dropout/BatchNorm).
model.eval()

# Read a sample image from disk. cv2.imread returns None instead of raising
# when the file is missing or unreadable, so fail with a clear message here.
originalImage = cv2.imread(image_path)
if originalImage is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
# NOTE(review): OpenCV loads images in BGR channel order; confirm the model was
# trained on BGR input, otherwise convert with cv2.cvtColor(..., cv2.COLOR_BGR2RGB).

# Resize to 256x256. The interpolation flag must be passed by keyword: the
# third positional parameter of cv2.resize is `dst`, not `interpolation`.
img = cv2.resize(originalImage, (256, 256), interpolation=cv2.INTER_AREA)
# HWC -> CHW, then add a leading batch dimension of size 1.
img = img.transpose(2, 0, 1)
img = img.reshape(1, 3, img.shape[1], img.shape[2])

# Move to the target device as float32 and scale 8-bit pixel values to [0, 1].
input_tensor = torch.from_numpy(img).to(device=device, dtype=torch.float32) / 255

start_time = time.perf_counter()
with torch.no_grad():
    a = model(input_tensor)
# CUDA kernels execute asynchronously; wait for them to finish so the timing
# covers the actual forward pass rather than just the kernel launches.
if device.type == "cuda":
    torch.cuda.synchronize(device)
print("--- %s seconds ---" % (time.perf_counter() - start_time))

# The model returns a mapping with an 'out' entry. Indexing [0] and then
# transposing three axes assumes a 4-D (batch, C, H, W) output -- TODO confirm.
outImage = a['out'].cpu().detach().numpy()[0]
# CHW -> HWC for matplotlib.
plt.imshow(outImage.transpose(1, 2, 0))
# Without show(), a plain (non-notebook) script would exit without displaying.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment