Last active
March 24, 2021 23:38
-
-
Save ismael-elatifi/228397a4d1383b6d07e691071f0bc4a1 to your computer and use it in GitHub Desktop.
Compute minimum spanning tree in binary image (nodes are white pixels) using Prim's algorithm
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cv2 | |
import numpy as np | |
from heapq import heappush, heappop | |
def find_pixels_in_window(img, x, y, window_size):
    """Return the coordinates of the non-zero pixels in a window around (x, y).

    The window extends ``window_size // 2`` pixels on each side of (x, y) and
    is clipped at the image borders. Coordinates are expressed in the frame of
    the full image, not relative to the window.

    Returns:
        A pair ``(ax, ay)`` of index arrays: columns first, then rows.
    """
    half = window_size // 2
    # Clamp the top-left corner by hand; slicing clips the bottom-right one.
    top = max(y - half, 0)
    left = max(x - half, 0)
    patch = img[top:y + half + 1, left:x + half + 1]
    rows, cols = np.nonzero(patch > 0)
    # Shift window-relative indices back into image coordinates.
    return cols + left, rows + top
def get_neighbors_and_distance(img, x, y, window_size):
    """List the non-zero pixels around (x, y) together with their distance.

    Candidates are the pixels reported by ``find_pixels_in_window``; the pixel
    at (x, y) itself is left out. The distance is the Manhattan (L1) distance
    between (x, y) and the neighbor.

    Returns:
        A list of ``((x2, y2), distance)`` pairs.
    """
    cols, rows = find_pixels_in_window(img, x, y, window_size=window_size)
    neighbors = []
    for x2, y2 in zip(cols, rows):
        if x == x2 and y == y2:
            continue  # the centre pixel is not its own neighbor
        neighbors.append(((x2, y2), abs(x - x2) + abs(y - y2)))
    return neighbors
def compute_minimum_spanning_tree(img, window_size):
    """
    Compute a minimum spanning forest of a binary image with Prim's algorithm.

    Nodes are the non-zero pixels of img. The neighbors of a node are the
    non-zero pixels found inside the (window_size, window_size) window around
    it; edge weights are the distances reported by get_neighbors_and_distance.
    When the pixels form several disconnected groups, one tree is built per
    group.

    Returns a tuple (dict_tree, total_distance, n_trees):
      - dict_tree maps every node (x, y) to its father (x, y), or to None for
        the root of each tree;
      - total_distance is the sum of the weights of all tree edges;
      - n_trees is the number of trees (disconnected groups) found.
    An image without any non-zero pixel yields (dict(), 0, 0).
    """
    ay, ax = np.where(img > 0)  # nodes are non null pixels
    nodes = list(zip(ax, ay))
    n_nodes = len(nodes)
    if n_nodes == 0:
        return dict(), 0, 0
    # Nodes in scan order, consumed lazily each time a new tree needs a root.
    root_candidates = iter(nodes)
    cur = next(root_candidates)
    dict_tree = {cur: None}  # dict node -> father
    n_trees = 1
    total_distance = 0
    heap = []  # candidate edges stored as (dist, father, node)
    # Each iteration adds exactly one node, so the loop ends after
    # n_nodes - 1 iterations.
    while len(dict_tree) < n_nodes:
        for x_y_adj, dist in get_neighbors_and_distance(img, *cur, window_size=window_size):
            if x_y_adj not in dict_tree:
                heappush(heap, (dist, cur, x_y_adj))
        # Drop stale edges whose far end joined the tree after being queued.
        while heap and heap[0][2] in dict_tree:
            heappop(heap)
        if heap:
            dist, father, cur = heappop(heap)
            total_distance += dist
        else:
            # The current tree cannot grow any more: start a new tree at the
            # first node not reached yet and record it as a root.
            cur = next(node for node in root_candidates if node not in dict_tree)
            father = None
            n_trees += 1
        dict_tree[cur] = father
    return dict_tree, total_distance, n_trees
def draw_tree(img, dict_tree, show_terminal_edge=False):
    """Draw the edges of a tree on a blank canvas derived from img.

    The canvas is a zero-filled copy of img stacked three times along the
    depth axis; img itself is not modified. One line is drawn per
    (node, father) pair of dict_tree; nodes whose father is None are skipped.

    When show_terminal_edge is true, edges leading to a node that is nobody's
    father are drawn with color (0, 0, 255) and thickness 2; every other edge
    uses (255, 255, 255) and thickness 1.

    Returns the canvas.
    """
    canvas = np.dstack([np.zeros_like(img)] * 3)
    fathers = set(dict_tree.values())
    for node, parent in dict_tree.items():
        if parent is not None:
            is_terminal = show_terminal_edge and node not in fathers
            if is_terminal:
                color, thickness = (0, 0, 255), 2
            else:
                color, thickness = (255, 255, 255), 1
            cv2.line(canvas, node, parent, color, thickness=thickness)
    return canvas
def make_binary_image_random(height, width, proba_black_pixel):
    """Build a random binary uint8 image of shape (height, width).

    Each pixel is drawn independently with np.random.rand: it is white (255)
    when the draw exceeds proba_black_pixel and black (0) otherwise.
    """
    is_white = np.random.rand(height, width) > proba_black_pixel
    return np.where(is_white, 255, 0).astype("uint8")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Usage example: