Attempt at reproducing the interactive VLM demo from https://www.groundlight.ai/blog/how-vlm-works-tokens: visualize each image patch's most similar text tokens directly on the patch/region where they are predicted.
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image

REVISION = "2025-04-14"  # lock to a known-good tag
MODEL_ID = "vikhyatk/moondream2"
device = "cuda"  # or "cpu" / bitsandbytes / GGUF, etc.

# 1️⃣ load model + text tokenizer
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    revision=REVISION,
    trust_remote_code=True,  # Moondream ships custom classes
    torch_dtype=torch.float16,
    device_map="auto",
)
tok = AutoTokenizer.from_pretrained(MODEL_ID, revision=REVISION, trust_remote_code=True)
from types import MethodType
import torch.nn as nn

def _get_input_embeddings(self):
    # wrap the parameter so HF utilities that expect an nn.Embedding still work
    return nn.Embedding.from_pretrained(self.model.text.wte, freeze=True)

model.get_input_embeddings = MethodType(_get_input_embeddings, model)
print(model.get_input_embeddings().weight)

# unit-normalise the text-embedding matrix so dot products below are cosine similarities
W_lang = F.normalize(model.get_input_embeddings().weight, dim=-1).to(device)
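# Optional sanity check (my addition, not in the original gist): after
# F.normalize every row of W_lang should have unit L2 norm.
print(W_lang.shape, W_lang.norm(dim=-1).mean().item())  # mean should be ~1.0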

# 4️⃣ pick an image
img = Image.open("1280px-Labrador_Retriever_portrait.jpg").convert("RGB")

with torch.no_grad():
    # Use Moondream's encode_image method
    encoded_image = model.encode_image(img)
    # If you need the raw vision embeddings for analysis:
    img_emb = model.model._run_vision_encoder(img)  # this gives you the vision embeddings
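# Note: `_run_vision_encoder` is a private Moondream helper (leading
# underscore), so it can change between releases; another reason the
# REVISION pin above matters.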

patch_lang = F.normalize(img_emb, dim=-1)
sim = patch_lang @ W_lang.T
vals, ids = sim.topk(8, dim=-1)
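# For orientation (shapes inferred from how the code below uses them, not
# stated in the gist): with P patches and a vocab of V tokens, `sim` is
# (P, V), and `vals` / `ids` are (P, 8): the cosine scores and IDs of the
# 8 most text-like tokens for every patch.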

# 8️⃣ helper ---------------------------------------------------------------
def labels_from_ids(vals, ids, thresh=0.15):
    """
    Convert the top-k token IDs for every patch into *readable* words.
    Uses the tokenizer's .decode, which automatically strips the Ġ / ▁
    "new-word" markers for you.
    """
    out = []
    keep = vals > thresh  # mask of tokens we want to keep
    for v, i, m in zip(vals, ids, keep):
        if not m.any():  # no token passed the threshold
            out.append([])
            continue
        # decode the surviving ids in one go → a clean utf-8 string
        text = tok.decode(i[m].tolist(), skip_special_tokens=True).strip()
        # break the string back into individual words
        out.append(text.split())
    return out

words_per_patch = labels_from_ids(vals, ids, thresh=0.0001)
print(words_per_patch)  # e.g. ['dog', 'black', 'fur'], ...
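# Quick way to eyeball the single strongest match per patch (a sketch I
# added, not part of the original gist):
for p in range(min(5, vals.shape[0])):
    best = tok.decode([ids[p, 0].item()], skip_special_tokens=True).strip()
    print(f"patch {p}: {best!r}  (cos={vals[p, 0].item():.3f})")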

# ---------------------------------------------------------------
# 9️⃣ visualise – draw the highest-similarity word per patch
# ---------------------------------------------------------------
from moondream.torch.vision import prepare_crops  # part of the repo
from PIL import Image, ImageDraw, ImageFont

# ⓵ Skip the global CLS token – everything afterwards is one patch
patch_words = words_per_patch[1:]

# ⓶ Get the tiling that Moondream decided to use
_, tiling = prepare_crops(img, model.config.vision, device=device)
ps = model.config.vision.enc_patch_size  # 14
crop_patches = model.config.vision.crop_size // ps  # 27
win_patches = crop_patches - 2 * model.config.vision.overlap_margin  # 19
grid_w = tiling[1] * win_patches
grid_h = tiling[0] * win_patches
assert len(patch_words) == grid_w * grid_h
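# Diagnostic print (my addition; assumes `tiling` is (rows, cols), matching
# how grid_h / grid_w are computed above). Handy when the assert fires
# because the CLS offset or the grid maths is off:
print(f"tiling={tiling}, grid={grid_w}x{grid_h} patches")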

# ⓷ Draw on a canvas that matches the (resized) image Moondream worked with
canvas = img.resize((grid_w * ps, grid_h * ps), Image.LANCZOS)
draw = ImageDraw.Draw(canvas)
try:
    font = ImageFont.truetype("DejaVuSans.ttf", size=10)
except IOError:  # fallback if the font isn't available
    font = ImageFont.load_default()

for idx, toks in enumerate(patch_words):
    if not toks:  # nothing above the threshold for this patch
        continue
    row, col = divmod(idx, grid_w)
    x = col * ps + ps // 2  # patch centre
    y = row * ps + ps // 2
    draw.text((x, y), toks[0], fill="white",
              font=font, anchor="mm",
              stroke_width=1, stroke_fill="black")  # small outline for contrast

canvas.save("moondream_patch_labels.png")  # Image.show() takes a title, not a path
canvas.show()
print("✅ overlay saved as moondream_patch_labels.png")