Created
December 17, 2024 00:14
-
-
Save thisismattmiller/50e77d2af8cd2480732067b75ac65ca3 to your computer and use it in GitHub Desktop.
Using the SAM 2 model with its automatic mask generator to segment images
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import glob
import os
from pathlib import Path

# if using Apple MPS, fall back to CPU for unsupported ops
# (kept ahead of `import torch`, as in the original statement order)
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
# Select the compute device: CUDA when available, then Apple MPS, else CPU.
# `device` is a module-level name that the model construction below reads.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

if device.type == "cuda":
    # use bfloat16 for the entire notebook
    # The autocast context is entered and deliberately never exited, so it
    # stays active for everything that runs after this point.
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    # Only GPU index 0 is inspected (compute capability major version >= 8).
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )

# Fixed seed so the random overlay colours drawn by show_anns are repeatable.
np.random.seed(3)
def show_anns(anns, borders=True):
    """Overlay segmentation masks on the current matplotlib axes.

    Each mask is painted with a random, half-transparent colour onto an RGBA
    canvas that is then drawn with ``imshow`` on ``plt.gca()``. Masks are
    painted from largest to smallest ``'area'`` so that small masks are not
    hidden underneath large ones.

    Args:
        anns: Sequence of dicts, each with an ``'area'`` value and a 2-D
            ``'segmentation'`` mask array. All masks are assumed to share
            the shape of the largest one (the canvas is sized from it).
        borders: When true, the outer contour of every mask is outlined.
            This requires OpenCV (``cv2``), which is imported only if needed.

    Returns:
        None. Nothing is drawn when ``anns`` is empty.
    """
    if len(anns) == 0:
        return

    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)

    ax = plt.gca()
    ax.set_autoscale_on(False)

    # RGBA canvas: white everywhere, alpha 0 so uncovered pixels stay clear.
    first_mask = sorted_anns[0]['segmentation']
    img = np.ones((first_mask.shape[0], first_mask.shape[1], 4))
    img[:, :, 3] = 0

    if borders:
        # Imported once here rather than on every loop iteration, and only
        # when outlines are actually requested.
        import cv2

    for ann in sorted_anns:
        # Force a boolean mask: an integer 0/1 array would otherwise be
        # treated as row indices by ``img[m]`` instead of a pixel selector.
        m = np.asarray(ann['segmentation'], dtype=bool)
        color_mask = np.concatenate([np.random.random(3), [0.5]])
        img[m] = color_mask
        if borders:
            contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            # Try to smooth contours
            contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
            cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1)

    ax.imshow(img)
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Checkpoint and config default to the values this script was written with.
# Set SAM2_CHECKPOINT / SAM2_MODEL_CFG in the environment to run it on another
# machine without editing the file.
sam2_checkpoint = os.environ.get(
    "SAM2_CHECKPOINT",
    "/home/matt/sam2_test/sam2/checkpoints/sam2.1_hiera_large.pt",
)
model_cfg = os.environ.get("SAM2_MODEL_CFG", "sam2.1_hiera_l.yaml")

sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)

# NOTE(review): the per-parameter notes below paraphrase the
# SAM2AutomaticMaskGenerator docstring; confirm against the installed version.
mask_generator = SAM2AutomaticMaskGenerator(
    model=sam2,
    points_per_side=16,  # prompt grid is points_per_side x points_per_side
    points_per_batch=32,  # prompt points run through the model at once
    pred_iou_thresh=0.1,  # drop masks whose predicted quality is below this
    stability_score_thresh=0.92,  # drop masks whose stability score is below this
    stability_score_offset=0.7,  # cutoff shift used to compute the stability score
    crop_n_layers=1,  # additionally run on one layer of image crops
    box_nms_thresh=0.1,  # box IoU cutoff for suppressing duplicate masks
    crop_n_points_downscale_factor=2,  # fewer prompt points per side in crop layers
    min_mask_region_area=99,  # remove small regions/holes below this area
    use_m2m=True,  # refine each mask with one extra mask-to-mask step
)
# Run the automatic mask generator on every PNG in the current working
# directory. glob.glob() returns matches in arbitrary order, so the list is
# sorted to make the processing (and output) order reproducible.
for file in sorted(glob.glob('*.png')):
    # File name without directory or final extension. Path.stem keeps any
    # inner dots ("scan.v2.png" -> "scan.v2") and handles the platform's
    # path separator. Only the commented-out export example below uses it.
    fileid = Path(file).stem

    # Close the file handle as soon as the pixels are copied into an array.
    with Image.open(file) as pil_image:
        image = np.array(pil_image.convert("RGB"))

    masks = mask_generator.generate(image)
    print(len(masks))
    if not masks:
        # No masks for this image: masks[0] below would raise IndexError.
        continue
    print(masks[0].keys())
    print(len(masks[0]['segmentation'][0]))

    # to see the masks uncomment this
    # plt.figure(figsize=(10, 10))
    # plt.imshow(image)
    # show_anns(masks)
    # plt.axis('off')
    # plt.show()

    # how to loop through the masks
    # (example: save each mask as a cropped, transparent-background PNG)
    # counter = 0
    # if len(masks) > 0:
    #     os.makedirs(f"/mnt/f/pbot_masks/{fileid}")
    # for mask in masks:
    #     print(f"{fileid}-{counter}")
    #     image2 = Image.open(f'/mnt/f/woodblocks/{fileid}.jpg')
    #     img = image2.convert("RGBA")
    #     pixdata = img.load()
    #     width, height = img.size
    #     for y in range(height):
    #         # print("y==",y)
    #         for x in range(width):
    #             # print(masks[0]['segmentation'][y][x])
    #             try:
    #                 if mask['segmentation'][y][x] == False:
    #                     pixdata[x, y] = (255, 255, 255, 0)
    #             except:
    #                 continue
    #     x1 = mask['bbox'][0]
    #     y1 = mask['bbox'][1]
    #     x2 = mask['bbox'][0] + mask['bbox'][2]
    #     y2 = mask['bbox'][1] + mask['bbox'][3]
    #     img = img.crop([x1, y1, x2, y2])
    #     img.save(f"/mnt/f/pbot_masks/{fileid}/{counter}.png", "PNG")
    #     counter=counter+1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment