Last active
September 29, 2020 06:51
-
-
Save shijianjian/f2545b5308d3975371900c0b61c708e5 to your computer and use it in GitHub Desktop.
Cirrus OCT ROI segmentation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Cirrus OCT ROI segmentation: visualization of the intermediate steps.

Author: Jian Shi
Email: [email protected]
"""
# NOTE: ``cu`` (the image stack), ``plt``, ``denoise_and_dilate`` and
# ``crop_contours`` are assumed to be defined before this snippet runs.
idx = 8
frame = cu[idx]
c, c_out = denoise_and_dilate(frame)
mask, out = crop_contours(c_out, frame)

# Panels, left to right: input slice, Canny edges, dilated edges, cropped ROI,
# separated by a 5-pixel-wide white column.
separator = np.ones((c.shape[0], 5)) * 255
panels = [frame, c, c_out, out]
strip = [panels[0]]
for panel in panels[1:]:
    strip.extend([separator, panel])

plt.figure(figsize=(16, 12))
plt.imshow(np.concatenate(strip, axis=1), cmap='gray')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Cirrus OCT ROI segmentation. | |
Author: Jian Shi | |
Email: [email protected] | |
""" | |
import numpy as np | |
import pandas as pd | |
import os | |
import cv2 | |
def denoise_and_dilate(input_img):
    """Binarize, denoise and edge-detect one image, then dilate the edges.

    Args:
        input_img: single-channel image accepted by ``cv2.threshold``
            (presumably one uint8 slice of the OCT volume -- TODO confirm).

    Returns:
        ``(c, c_out)``: the Canny edge map and its dilation with a 31x31
        all-ones kernel.
    """
    # Pixels above 65 become 255, all others 0 (cv2.THRESH_BINARY == 0).
    _, thresh = cv2.threshold(input_img, 65, 255, cv2.THRESH_BINARY)
    # The original call was ``fastNlMeansDenoising(thresh, 10, 25, 25)``.
    # In OpenCV's Python binding the second positional parameter is ``dst``,
    # so the 10 was swallowed as an output buffer and the effective arguments
    # were h=25, templateWindowSize=25. Spelled out here so the output is
    # unchanged but the call says what it actually does.
    thresh = cv2.fastNlMeansDenoising(thresh, h=25, templateWindowSize=25)
    c = cv2.Canny(thresh, 100, 200)
    kernel = np.ones((31, 31), np.uint8)
    # The original passed ``30`` as the third positional argument, which is
    # ``dst`` rather than ``iterations``, so OpenCV ran a single iteration.
    # Kept at 1 to preserve results; raise it only if more dilation is wanted.
    c_out = cv2.dilate(c, kernel, iterations=1)
    return c, c_out
def crop_contours(processed_img, input_img):
    """Build a column-wise ROI mask from an edge map and crop ``input_img`` with it.

    All contours found in ``processed_img`` are filled into a temporary mask.
    For every column, the ROI then spans from the first to the last non-zero
    row of that mask, so gaps inside a column are filled as well; columns
    with no non-zero pixel get an empty span. That outline is drawn as one
    filled polygon to produce the final mask.

    Args:
        processed_img: edge map passed to ``cv2.findContours`` (in this file,
            the dilated Canny output of ``denoise_and_dilate``).
        input_img: image to crop; it is indexed with a boolean mask of
            ``processed_img``'s shape, so it must have the same height/width.

    Returns:
        ``(mask, out)``: ``mask`` is 255 inside the ROI and 0 elsewhere;
        ``out`` has ``processed_img``'s dtype and holds ``input_img``'s pixels
        inside the ROI, 0 outside.
    """
    # hierarchy is unused; only the contour point lists matter here.
    contours, hierarchy = cv2.findContours(processed_img, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    # Paint every contour (index -1 = all) filled (thickness -1) with 255.
    mask = np.zeros_like(processed_img)
    cv2.drawContours(mask, contours, -1, 255, -1)
    # Fill the inner holes: per column, keep everything between the first and
    # last non-zero row.
    # upper[x]: first non-zero row of column x (argmax returns 0 for an
    # all-zero column, which is corrected on the next line).
    upper = mask.argmax(axis=0)
    # Columns with no selected pixel are pushed to len(mask) (the row count)
    # so they yield an empty vertical span instead of starting at row 0.
    upper = np.where(mask.sum(axis=0) == 0, len(mask), upper)
    # bottom[x]: one past the last non-zero row of column x (len(mask) for
    # all-zero columns, matching ``upper`` there).
    bottom = len(mask) - np.flip(mask, axis=0).argmax(axis=0)
    # Polygon vertices are (x, y) = (column, row) pairs shaped (N, 1, 2).
    # len(mask[1]) is the image width (length of row 1); needs >= 2 rows.
    upper_coords = np.expand_dims(np.stack([np.linspace(0, len(mask[1]) - 1, num=len(mask[1])), upper], axis=1), 1).astype(np.int32)
    bot_coords = np.expand_dims(np.stack([np.linspace(0, len(mask[1]) - 1, num=len(mask[1])), bottom], axis=1), 1).astype(np.int32)
    # Left edge: column 0, rows upper[0] .. bottom[0] - 1, top to bottom.
    left_coords = np.expand_dims(np.stack([np.zeros(bottom[0] - upper[0]), np.linspace(upper[0], bottom[0] - 1, num=bottom[0] - upper[0])], axis=1), 1).astype(np.int32)
    # Right edge: rows bottom[-1] - 1 .. upper[-1], bottom to top.
    # NOTE(review): x is len(bottom), i.e. the width -- one column past the
    # last index used by upper/bot_coords. Looks like an off-by-one; confirm
    # visually before changing it, the fill result may depend on it.
    right_coords = np.expand_dims(np.stack([
        np.ones(bottom[-1] - upper[-1]) * len(bottom),
        np.linspace(upper[-1], bottom[-1] - 1, num=bottom[-1] - upper[-1])[::-1]
    ], axis=1), 1).astype(np.int32)
    # Redraw from scratch: left edge -> bottom edge -> right edge -> top edge.
    # NOTE(review): upper_coords runs left-to-right (same direction as
    # bot_coords), so this is not a simple closed outline; the filled area
    # depends on OpenCV's polygon fill rule. Verify the mask before reordering.
    mask = np.zeros_like(processed_img)
    cv2.drawContours(mask, [np.concatenate([left_coords, bot_coords, right_coords, upper_coords], axis=0)], -1, 255, -1)
    # Copy input pixels inside the ROI; everything else stays 0.
    out = np.zeros_like(processed_img)
    out[mask == 255] = input_img[mask == 255]
    return mask, out
def get_output(input_image):
    """Segment the ROI of every slice in a volume.

    Each ``input_image[i]`` is passed through ``denoise_and_dilate`` and
    ``crop_contours``; the cropped slices are stacked with ``np.array``.

    Args:
        input_image: sequence of 2-D images iterated along its first axis
            (presumably the uint8 volume returned by ``get_image`` -- TODO
            confirm).

    Returns:
        ``np.ndarray`` stacking the per-slice cropped images in input order.
    """
    # The original made a full ``.copy()`` of the volume first. Neither helper
    # writes to the slice it receives (cv2.threshold returns a new array and
    # crop_contours only reads it), so iterating the input directly avoids
    # duplicating the whole volume in memory.
    segmented = []
    for frame in input_image:
        _, dilated = denoise_and_dilate(frame)
        _, cropped = crop_contours(dilated, frame)
        segmented.append(cropped)
    return np.array(segmented)
def crop_volume(cu_out, tol=2):
    """Trim leading/trailing indices of axis 1 whose mean intensity is <= ``tol``.

    The mean is taken over axes 0 and 2, giving one value per index of axis 1.
    Everything from the first to the last index whose mean strictly exceeds
    ``tol`` is kept; low-mean indices in between are retained.

    Args:
        cu_out: 3-D array, such as the stacked output of ``get_output``.
        tol: mean-intensity threshold an index must strictly exceed.

    Returns:
        A new array (independent of ``cu_out``) restricted to the kept
        range of axis 1.

    Raises:
        ValueError: if no index along axis 1 has a mean above ``tol``.
    """
    keep = np.flatnonzero(cu_out.mean(axis=(0, 2)) > tol)
    if keep.size == 0:
        # The original failed here with an opaque "True is not in list".
        raise ValueError(f"No rows with mean intensity above tol={tol}; nothing to crop.")
    # Copy only the kept slab rather than the whole volume.
    return cu_out[:, keep[0]:keep[-1] + 1].copy()
def get_image(data_path):
    """Load a volume from a ``.npy`` file or from a raw 8-bit byte dump.

    Args:
        data_path: path to a ``.npy`` array or to a raw uint8 file.

    Returns:
        The volume as a numpy array. A flat (1-D) array whose length matches
        a known size is reshaped to ``(128, 1024, 512)`` or
        ``(200, 1024, 200)``; arrays that are already multi-dimensional
        (only possible via ``.npy``) are returned unchanged.

    Raises:
        ValueError: if a flat array does not match any known size.
    """
    if data_path.endswith(".npy"):
        image = np.load(data_path)
    else:
        # Read raw bytes straight from the path; the original opened the
        # file in text mode, which is the wrong mode for binary data.
        image = np.fromfile(data_path, np.uint8)
    if image.ndim != 1:
        return image
    # Known raw volume shapes, keyed by total byte count.
    known_shapes = {
        128 * 1024 * 512: (128, 1024, 512),  # 67108864 bytes
        200 * 1024 * 200: (200, 1024, 200),  # 40960000 bytes
    }
    shape = known_shapes.get(len(image))
    if shape is None:
        raise ValueError(f"Unexpected length {len(image)}, for {data_path}")
    return image.reshape(shape)
def main_process(data_path):
    """Segment one volume and save the result next to it as ``<stem>_seg.npy``.

    If the output file already exists, nothing is computed and ``Skipping``
    is printed instead.

    Args:
        data_path: path to an input volume readable by ``get_image``
            (normally a ``.img`` file).
    """
    # Replace only the trailing extension. The original used
    # ``data_path.replace('.img', '_seg.npy')``, which also rewrote any
    # ".img" inside directory names and, for inputs without a ".img"
    # suffix, produced the input path itself (so they were always skipped).
    seg_path = os.path.splitext(data_path)[0] + '_seg.npy'
    if os.path.exists(seg_path):
        print('Skipping', data_path)
        return
    out = get_output(get_image(data_path))
    np.save(seg_path, out)
    print('Finished', data_path)
def main(csv, path_prefix=None, process=20,
         protocols=('Macular Cube 512x128', 'Macular Cube 200x200', 'Optic Disc Cube 200x200')):
    """Segment every scan listed in ``csv`` using a process pool.

    Args:
        csv: path to a CSV containing at least ``area``, ``root`` and
            ``filename`` columns.
        path_prefix: string prepended verbatim (no separator inserted) to each
            ``root``; ``None`` means no prefix.
        process: number of worker processes.
        protocols: ``area`` values to keep; other rows are ignored.
    """
    from multiprocessing import Pool

    df = pd.read_csv(csv)
    df = df[df['area'].isin(protocols)].reset_index(drop=True)
    # The original wrote
    #   ("" if path_prefix is None else path_prefix + df['root'] + ...).values
    # where the conditional swallowed the whole concatenation, so the default
    # path_prefix=None evaluated to ``"".values`` and raised AttributeError.
    prefix = "" if path_prefix is None else path_prefix
    datapaths = (prefix + df['root'] + os.path.sep + df['filename']).values
    with Pool(process) as p:
        print(p.map(main_process, datapaths))
if __name__ == '__main__':
    # Placeholders: replace with the real index CSV and data root. The root is
    # prepended to each CSV ``root`` value with no separator added.
    main(csv="YOUR.csv", path_prefix="YOUR ROOT", process=35)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment