Skip to content

Instantly share code, notes, and snippets.

@Erol444
Last active February 18, 2026 12:34
Show Gist options
  • Select an option

  • Save Erol444/4cbc33c6ac52d83c63f6f9d86ca8a7a4 to your computer and use it in GitHub Desktop.

Select an option

Save Erol444/4cbc33c6ac52d83c63f6f9d86ca8a7a4 to your computer and use it in GitHub Desktop.
SAM3 PVS (Promptable Visual Segmentation) on Roboflow Serverless — an interactive OpenCV demo
"""
SAM3 PVS on Roboflow Serverless -- interactive OpenCV demo.

Note: This gist requires from_sam3() in supervision, currently in PR:
https://github.com/roboflow/supervision/pull/2152

Controls:
    Left click  = positive point (green)
    Right click = negative point (red)
    1-9 keyboard numbers = class selection
"""
import base64
import os
import sys

import cv2
import numpy as np
import requests
import supervision as sv

# From "https://media.roboflow.com/notebooks/examples/dog.jpeg"
image_path = "./dog.jpeg"
image = cv2.imread(image_path)
if image is None:
    print(f"Error: Could not load image from {image_path}")
    sys.exit(1)

# Fail fast with a clear message: a missing API_KEY would otherwise crash
# with an opaque TypeError when concatenated into the URL below.
api_key = os.getenv("API_KEY")
if not api_key:
    print("Error: API_KEY environment variable is not set.")
    sys.exit(1)

# Encode the image once as base64 JPEG; it is reused for every request.
_, buffer = cv2.imencode('.jpg', image)
image_base64 = base64.b64encode(buffer).decode('utf-8')

# Interactive points collection: {class_id: [{'x': int, 'y': int, 'positive': bool}]}
points_by_class = {}
current_class_id = 1
vis_image = image.copy()  # latest annotated frame shown in the window

url = "https://serverless.roboflow.com/sam3/visual_segment?api_key=" + api_key
def run_inference():
    """Send all collected points to the SAM3 endpoint and update the display.

    Reads the module-level ``points_by_class``, ``image_base64``, ``image``
    and ``url``; writes the annotated result into the module-level
    ``vis_image`` and refreshes the OpenCV window. Best-effort: any failure
    is printed and the previous visualization is left on screen.
    """
    global vis_image

    print("Running inference...")
    active_classes = []
    api_prompts = []
    # Sort by class_id so prompt order (and thus mask order) is deterministic.
    for class_id in sorted(points_by_class.keys()):
        class_points = points_by_class[class_id]
        if class_points:
            active_classes.append(class_id)
            # Use points as originally intended for Sam2SegmentationRequest
            # defined in inference/core/entities/requests/sam2.py
            api_prompts.append({"points": class_points})

    if not api_prompts:
        print("No points to run inference on.")
        return
    print("api_prompts", api_prompts)

    payload = {
        "image": {
            "type": "base64",
            "value": image_base64,
        },
        "prompts": api_prompts,
        # "model_id": "sam3/sam3_interactive",  # Try this if needed, or rely on default
        "format": "json",
        "multimask_output": True,
    }
    try:
        # Timeout keeps the UI from hanging forever on a stalled request.
        response = requests.post(url, json=payload, timeout=60)
        if response.status_code != 200:
            print(f"Error: {response.text}")
            return
        data = response.json()
        height, width = image.shape[:2]
        detections = sv.Detections.from_sam3(data, (width, height))
        if detections.is_empty():
            print("No detections found.")
            # Just show points on the original image.
            annotated_image = image.copy()
        else:
            # Infer masks-per-prompt from the response structure if possible.
            # data['masks'] is expected to be [[mask1, mask2, mask3], ...]
            if 'masks' in data and len(data['masks']) > 0:
                masks_per_prompt = len(data['masks'][0])
            else:
                # Fallback if structure is different or flattened (unlikely for SAM2/3).
                masks_per_prompt = len(detections) // len(active_classes) if len(active_classes) > 0 else 1
            class_ids_list = []
            for cid in active_classes:
                class_ids_list.extend([cid] * masks_per_prompt)
            # Ensure length matches (truncate if the assumption failed, to avoid a crash).
            if len(class_ids_list) == len(detections):
                detections.class_id = np.array(class_ids_list)
            else:
                print(f"Warning: Mismatch in masks count. Detections: {len(detections)}, Expected: {len(class_ids_list)}")
                detections.class_id = np.array(class_ids_list[:len(detections)])

            # Annotate masks and per-class labels.
            mask_annotator = sv.MaskAnnotator()
            label_annotator = sv.LabelAnnotator()
            labels = [f"Class {cid}" for cid in detections.class_id]
            annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
            annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)

        # Re-draw the points on top for clarity (masks may cover them).
        for class_id, class_points in points_by_class.items():
            for p in class_points:
                color = (0, 255, 0) if p['positive'] else (0, 0, 255)
                cv2.circle(annotated_image, (p['x'], p['y']), 5, color, -1)

        # Update the visualization image and refresh the window.
        vis_image = annotated_image
        cv2.imshow("SAM3 Demo", vis_image)
    except Exception as e:
        # Best-effort UI: report and keep the previous visualization.
        print(f"Inference failed: {e}")
def mouse_callback(event, x, y, flags, param):
    """OpenCV mouse handler: record a click for the current class, then re-run inference.

    Left button adds a positive (green) point, right button a negative (red)
    one; every other event only ensures the current class has a point bucket.
    """
    global points_by_class, current_class_id

    # Map button events to the 'positive' flag of the recorded point.
    click_kinds = {
        cv2.EVENT_LBUTTONDOWN: True,   # positive point - green
        cv2.EVENT_RBUTTONDOWN: False,  # negative point - red
    }

    bucket = points_by_class.setdefault(current_class_id, [])
    if event not in click_kinds:
        return

    is_positive = click_kinds[event]
    bucket.append({"positive": is_positive, "x": x, "y": y})
    label = "positive" if is_positive else "negative"
    print(f"Added {label} point to class {current_class_id}: ({x}, {y})")

    # Give instant feedback before the (slow) network round-trip: draw every
    # recorded point from every class on a copy of the current view.
    preview = vis_image.copy()
    for pts in points_by_class.values():
        for pt in pts:
            dot = (0, 255, 0) if pt['positive'] else (0, 0, 255)
            cv2.circle(preview, (pt['x'], pt['y']), 5, dot, -1)
    cv2.putText(preview, f"Class: {current_class_id}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
    cv2.imshow("SAM3 Demo", preview)
    cv2.waitKey(1)  # force the window to update now

    run_inference()
cv2.namedWindow("SAM3 Demo")
cv2.setMouseCallback("SAM3 Demo", mouse_callback)

print("Left click: Positive point (Green)")
print("Right click: Negative point (Red)")
print("Press 'q' or ESC to quit")

# Event loop: redraw the latest visualization with the active-class overlay,
# then poll the keyboard for class switches (1-9) or quit (q / ESC).
while True:
    frame = vis_image.copy()
    cv2.putText(frame, f"Class: {current_class_id}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
    cv2.imshow("SAM3 Demo", frame)

    key = cv2.waitKey(1) & 0xFF
    if key in range(ord('1'), ord('9') + 1):
        current_class_id = key - ord('0')
        print(f"Switched to Class {current_class_id}")
        # The next loop iteration draws the new class ID.
    if key in (ord('q'), 27):  # q or ESC
        break

cv2.destroyAllWindows()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment