Skip to content

Instantly share code, notes, and snippets.

@thisismattmiller
Created December 17, 2024 00:14
Show Gist options
  • Save thisismattmiller/50e77d2af8cd2480732067b75ac65ca3 to your computer and use it in GitHub Desktop.
Save thisismattmiller/50e77d2af8cd2480732067b75ac65ca3 to your computer and use it in GitHub Desktop.
Using the SAM 2 model on images with automatic mask generation
import os
# If running on Apple MPS, let PyTorch fall back to the CPU for ops that MPS
# does not implement. Set here, before torch is imported below.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
import glob
# Prefer CUDA, then Apple MPS, and fall back to the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"using device: {device}")

if device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )
elif device.type == "cuda":
    # Enter a bfloat16 autocast context and deliberately never exit it, so
    # mixed precision stays on for all later CUDA work on this thread.
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # GPUs with compute capability >= 8 (Ampere and newer) can use TF32 for
    # matmuls and cuDNN convolutions; enable both.
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True

# Fixed seed for NumPy's global RNG, e.g. so the random mask colours drawn
# with np.random are the same on every run.
np.random.seed(3)
def show_anns(anns, borders=True, ax=None):
    """Overlay annotation masks on a matplotlib axes.

    Each annotation is painted in a random RGB colour at 50% opacity.
    Annotations are painted largest-area first, so smaller masks end up on top
    of larger ones where they overlap.

    Args:
        anns: Sequence of dicts, each with an ``'area'`` number and a boolean
            ``'segmentation'`` array (presumably H x W, as produced by
            ``SAM2AutomaticMaskGenerator.generate``). All masks are assumed to
            have the same shape as the largest one.
        borders: If true, also outline each mask's external contours in blue.
            This requires OpenCV, which is imported only in that case.
        ax: Axes to draw on. Defaults to ``plt.gca()``.

    Returns:
        None. Returns immediately, drawing nothing, when ``anns`` is empty.
    """
    if not anns:
        return
    if borders:
        # Imported lazily (once, not once per mask) so OpenCV is only needed
        # when contours are requested.
        import cv2

    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    if ax is None:
        ax = plt.gca()
    ax.set_autoscale_on(False)

    height, width = sorted_anns[0]['segmentation'].shape[:2]
    # RGBA canvas: white RGB, fully transparent until a mask paints over it.
    img = np.ones((height, width, 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        # Colours come from NumPy's global RNG, so np.random.seed() makes
        # them reproducible.
        color_mask = np.concatenate([np.random.random(3), [0.5]])
        img[m] = color_mask
        if borders:
            contours, _ = cv2.findContours(
                m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
            )
            # approxPolyDP simplifies each polygon (Douglas-Peucker); with an
            # epsilon of 0.01 px it removes almost no points.
            contours = [
                cv2.approxPolyDP(contour, epsilon=0.01, closed=True)
                for contour in contours
            ]
            cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1)
    ax.imshow(img)
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
# Checkpoint and Hydra config for the SAM 2.1 Hiera-Large model. Both can be
# overridden through environment variables so the script is not tied to one
# machine; the defaults are the original hard-coded values.
sam2_checkpoint = os.environ.get(
    "SAM2_CHECKPOINT",
    "/home/matt/sam2_test/sam2/checkpoints/sam2.1_hiera_large.pt",
)
# NOTE(review): upstream SAM 2 examples refer to this config as
# "configs/sam2.1/sam2.1_hiera_l.yaml"; confirm the bare file name resolves
# with your sam2 install, or set SAM2_MODEL_CFG.
model_cfg = os.environ.get("SAM2_MODEL_CFG", "sam2.1_hiera_l.yaml")
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)

mask_generator = SAM2AutomaticMaskGenerator(
    model=sam2,
    points_per_side=16,  # prompt grid of 16 x 16 points over the image
    points_per_batch=32,  # points run through the model at once (memory vs speed)
    # NOTE(review): 0.1 is a very permissive predicted-quality filter, so
    # low-confidence masks are kept; confirm this is intended.
    pred_iou_thresh=0.1,
    stability_score_thresh=0.92,  # drop masks unstable under cutoff changes
    stability_score_offset=0.7,  # shift applied to the cutoff for that score
    crop_n_layers=1,  # also run the model on one layer of image crops
    # NOTE(review): NMS with IoU 0.1 discards any mask whose box overlaps a
    # kept one even slightly; nested/overlapping objects may be lost.
    box_nms_thresh=0.1,
    crop_n_points_downscale_factor=2,  # fewer points per side in crop layers
    min_mask_region_area=99,  # clean up small holes/islands (needs OpenCV)
    use_m2m=True,  # extra refinement step using previous mask predictions
)
# Segment every PNG in the current working directory. Sorted so the
# processing order is reproducible (raw glob order depends on the filesystem).
for file in sorted(glob.glob('*.png')):
    # File name without directory or extension. splitext strips only the last
    # suffix, so "plate.01.png" -> "plate.01" (split(".")[0] gave "plate").
    fileid = os.path.splitext(os.path.basename(file))[0]
    # Context manager closes the file handle once the pixels are copied.
    with Image.open(file) as pil_image:
        image = np.array(pil_image.convert("RGB"))
    masks = mask_generator.generate(image)
    print(len(masks))
    if not masks:
        # Nothing segmented for this image: skip it rather than crash on masks[0].
        continue
    print(masks[0].keys())
    # Length of the first row of the first mask, i.e. its width in pixels
    # (assumes a 2-D H x W segmentation array).
    print(len(masks[0]['segmentation'][0]))

    # To preview the masks, uncomment this:
    # plt.figure(figsize=(10, 10))
    # plt.imshow(image)
    # show_anns(masks)
    # plt.axis('off')
    # plt.show()

    # To save each mask as a transparent, bbox-cropped PNG cut-out, uncomment this.
    # NOTE(review): the cut-outs are taken from a .jpg in another folder, not
    # from the .png that was segmented; this assumes both have identical pixel
    # dimensions (a mismatch will raise an IndexError).
    # out_dir = f"/mnt/f/pbot_masks/{fileid}"
    # os.makedirs(out_dir, exist_ok=True)  # exist_ok so re-runs don't crash
    # with Image.open(f'/mnt/f/woodblocks/{fileid}.jpg') as source:
    #     rgba = np.array(source.convert("RGBA"))
    # for counter, mask in enumerate(masks):
    #     print(f"{fileid}-{counter}")
    #     cutout = rgba.copy()
    #     # Vectorised replacement for a per-pixel loop: every pixel outside
    #     # the mask becomes transparent white.
    #     cutout[~mask['segmentation']] = (255, 255, 255, 0)
    #     # bbox is (x, y, width, height); PIL's crop() wants (left, upper, right, lower).
    #     x, y, w, h = mask['bbox']
    #     Image.fromarray(cutout).crop((x, y, x + w, y + h)).save(f"{out_dir}/{counter}.png", "PNG")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment