Last active
February 18, 2026 12:34
-
-
Save Erol444/4cbc33c6ac52d83c63f6f9d86ca8a7a4 to your computer and use it in GitHub Desktop.
SAM3 PVS on Roboflow Serverless interactive OpenCV demo
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
SAM3 PVS on Roboflow Serverless -- interactive OpenCV demo.

Note: This gist requires from_sam3() in supervision, currently in PR:
https://github.com/roboflow/supervision/pull/2152

Controls:
    Left click  = positive point (green)
    Right click = negative point (red)
    1-9 keyboard numbers = class selection
"""
import os
import requests
import base64
import cv2
import numpy as np
import supervision as sv

# From "https://media.roboflow.com/notebooks/examples/dog.jpeg"
image_path = "./dog.jpeg"
image = cv2.imread(image_path)
if image is None:
    print(f"Error: Could not load image from {image_path}")
    exit(1)

# Encode image as base64 once; the same encoded payload is reused for every request.
_, buffer = cv2.imencode('.jpg', image)
image_base64 = base64.b64encode(buffer).decode('utf-8')

# Interactive points collection
points_by_class = {}  # {class_id: [{'x': int, 'y': int, 'positive': bool}]}
current_class_id = 1
vis_image = image.copy()

# Fail fast with a clear message when API_KEY is unset: os.getenv would
# return None and the string concatenation below would raise a cryptic
# TypeError instead.
_api_key = os.getenv("API_KEY")
if not _api_key:
    print("Error: API_KEY environment variable is not set.")
    exit(1)
url = "https://serverless.roboflow.com/sam3/visual_segment?api_key=" + _api_key
def run_inference():
    """Send all collected points to the SAM3 endpoint and refresh the display.

    Builds one prompt per class that has at least one point, POSTs the
    base64-encoded image plus prompts to the serverless endpoint, converts
    the JSON response into supervision Detections, maps class ids back onto
    the detections, and redraws the annotated result into the module-level
    ``vis_image`` canvas.
    """
    global vis_image  # declared up front: this function replaces the shared canvas
    print("Running inference...")
    active_classes = []
    api_prompts = []
    # Sort by class_id to ensure deterministic prompt order in the payload.
    for class_id in sorted(points_by_class.keys()):
        class_points = points_by_class[class_id]
        if class_points:
            active_classes.append(class_id)
            # Use points as originally intended for Sam2SegmentationRequest
            # defined in inference/core/entities/requests/sam2.py
            api_prompts.append({
                "points": class_points
            })
    if not api_prompts:
        print("No points to run inference on.")
        return
    print("api_prompts", api_prompts)
    payload = {
        "image": {
            "type": "base64",
            "value": image_base64
        },
        "prompts": api_prompts,
        # "model_id": "sam3/sam3_interactive",  # Try this if needed, or rely on default
        "format": "json",
        "multimask_output": True,
    }
    try:
        # A timeout keeps the UI from hanging forever if the service stalls.
        response = requests.post(url, json=payload, timeout=60)
        if response.status_code != 200:
            print(f"Error: {response.text}")
            return
        data = response.json()
        height, width = image.shape[:2]
        detections = sv.Detections.from_sam3(data, (width, height))
        if detections.is_empty():
            print("No detections found.")
            # Just show points on the original image.
            annotated_image = image.copy()
        else:
            # Infer masks-per-prompt from the response structure if possible.
            # data['masks'] is expected to be [[mask1, mask2, mask3], ...],
            # one inner list per prompt.
            if 'masks' in data and len(data['masks']) > 0:
                masks_per_prompt = len(data['masks'][0])
            else:
                # Fallback if the structure is different or flattened
                # (unlikely for SAM2/3 responses).
                masks_per_prompt = len(detections) // len(active_classes) if len(active_classes) > 0 else 1
            class_ids_list = []
            for cid in active_classes:
                class_ids_list.extend([cid] * masks_per_prompt)
            # Ensure length matches (truncate if the assumption failed, to avoid a crash).
            if len(class_ids_list) == len(detections):
                detections.class_id = np.array(class_ids_list)
            else:
                print(f"Warning: Mismatch in masks count. Detections: {len(detections)}, Expected: {len(class_ids_list)}")
                # Fallback: fit as many class ids as we have detections.
                detections.class_id = np.array(class_ids_list[:len(detections)])
            # Annotate masks and per-class labels on a fresh copy of the source image.
            mask_annotator = sv.MaskAnnotator()
            label_annotator = sv.LabelAnnotator()
            labels = [f"Class {cid}" for cid in detections.class_id]
            annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
            annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
        # Re-draw the points on top for clarity (green = positive, red = negative).
        for class_id, class_points in points_by_class.items():
            for p in class_points:
                color = (0, 255, 0) if p['positive'] else (0, 0, 255)
                cv2.circle(annotated_image, (p['x'], p['y']), 5, color, -1)
        # Update the shared visualization image and refresh the window.
        vis_image = annotated_image
        cv2.imshow("SAM3 Demo", vis_image)
    except Exception as e:
        # Broad catch is deliberate for an interactive demo: keep the UI
        # alive on any network/parse failure rather than crashing.
        print(f"Inference failed: {e}")
def mouse_callback(event, x, y, flags, param):
    """OpenCV mouse handler: record a point for the active class, then re-run inference.

    Left button adds a positive (green) point, right button a negative (red)
    one; any other event is ignored. After recording, every stored point is
    drawn immediately on a snapshot of the current canvas for instant
    feedback, then the segmentation request is fired.
    """
    global points_by_class, current_class_id
    bucket = points_by_class.setdefault(current_class_id, [])
    if event == cv2.EVENT_LBUTTONDOWN:
        is_positive = True
    elif event == cv2.EVENT_RBUTTONDOWN:
        is_positive = False
    else:
        return  # moves, wheel, etc. — nothing to do
    bucket.append({"positive": is_positive, "x": x, "y": y})
    kind = "positive" if is_positive else "negative"
    print(f"Added {kind} point to class {current_class_id}: ({x}, {y})")
    # Instant visual feedback before the (slow) network round-trip.
    preview = vis_image.copy()
    for pts in points_by_class.values():
        for pt in pts:
            dot = (0, 255, 0) if pt['positive'] else (0, 0, 255)
            cv2.circle(preview, (pt['x'], pt['y']), 5, dot, -1)
    cv2.putText(preview, f"Class: {current_class_id}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
    cv2.imshow("SAM3 Demo", preview)
    cv2.waitKey(1)  # force the window to repaint now
    run_inference()
# --- UI bootstrap and main event loop ---------------------------------------
cv2.namedWindow("SAM3 Demo")
cv2.setMouseCallback("SAM3 Demo", mouse_callback)

print("Left click: Positive point (Green)")
print("Right click: Negative point (Red)")
print("Press 'q' or ESC to quit")

while True:
    # Redraw from the latest annotated canvas each frame so the class
    # banner always reflects the current selection.
    frame = vis_image.copy()
    cv2.putText(frame, f"Class: {current_class_id}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
    cv2.imshow("SAM3 Demo", frame)

    key = cv2.waitKey(1) & 0xFF
    if ord('1') <= key <= ord('9'):
        # Number keys 1-9 switch the active class; drawn on the next iteration.
        current_class_id = key - ord('0')
        print(f"Switched to Class {current_class_id}")
    if key in (ord('q'), 27):  # 'q' or ESC quits
        break

cv2.destroyAllWindows()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment