Skip to content

Instantly share code, notes, and snippets.

@secemp9
Created May 31, 2025 20:53
Show Gist options
  • Save secemp9/a53a2a967188d2721a4d21450871e7bf to your computer and use it in GitHub Desktop.
Save secemp9/a53a2a967188d2721a4d21450871e7bf to your computer and use it in GitHub Desktop.
Attempt at reproducing the interactive VLM demo from https://www.groundlight.ai/blog/how-vlm-works-tokens, which overlays on an image the text tokens associated with each image patch (its predicted region).
import torch, torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
from PIL import Image
REVISION = "2025-04-14"  # lock to a known good tag
MODEL_ID = "vikhyatk/moondream2"
# Use the GPU when one is present, otherwise fall back to CPU. A hard-coded
# "cuda" makes the later `.to(device)` calls crash on CPU-only machines even
# though device_map="auto" can still place the model on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU; some PyTorch releases lack float16 matmul
# kernels on CPU, so use float32 there.
dtype = torch.float16 if device == "cuda" else torch.float32

# 1️⃣ load model + text tokenizer
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    revision=REVISION,
    trust_remote_code=True,  # Moondream ships custom classes
    torch_dtype=dtype,
    device_map="auto",
)
tok = AutoTokenizer.from_pretrained(MODEL_ID, revision=REVISION, trust_remote_code=True)
from types import MethodType
import torch.nn as nn


def _get_input_embeddings(self):
    """Return Moondream's token-embedding table wrapped as a frozen nn.Embedding.

    The table lives on ``self.model.text.wte``. Hugging Face utilities that call
    ``get_input_embeddings()`` expect an ``nn.Embedding`` module, so a fresh,
    non-trainable wrapper around that tensor is built on every call.
    """
    return nn.Embedding.from_pretrained(embeddings=self.model.text.wte, freeze=True)


# Bind the helper as a method on this particular model instance.
model.get_input_embeddings = MethodType(_get_input_embeddings, model)

# Fetch the wrapped table once and reuse it for the debug print and for the
# row-wise L2-normalised copy used in the cosine-similarity step.
embedding_weight = model.get_input_embeddings().weight
print(embedding_weight)
W_lang = F.normalize(embedding_weight, dim=-1).to(device)
# 4️⃣ pick an image
IMAGE_PATH = "1280px-Labrador_Retriever_portrait.jpg"
TOP_K = 8  # candidate vocabulary tokens kept per image patch

# Open inside `with` so the underlying file handle is closed once the RGB
# copy has been made.
with Image.open(IMAGE_PATH) as raw_img:
    img = raw_img.convert("RGB")

with torch.no_grad():
    # Raw per-token vision embeddings. `_run_vision_encoder` is a private
    # Moondream method and may change between model revisions.
    # (The previous `model.encode_image(img)` call ran a second full vision
    # pass whose result was never used, so it has been removed.)
    img_emb = model.model._run_vision_encoder(img)

    # Both sides are L2-normalised row-wise, so the dot product is the cosine
    # similarity between every image token and every vocabulary embedding.
    patch_lang = F.normalize(img_emb, dim=-1)
    lang = W_lang.to(device=patch_lang.device, dtype=patch_lang.dtype)
    sim = patch_lang @ lang.T
    # Best TOP_K vocabulary ids per image token, highest score first.
    vals, ids = sim.topk(TOP_K, dim=-1)
# 8️⃣ helper ---------------------------------------------------------------
def labels_from_ids(vals, ids, thresh=0.15, tokenizer=None):
    """
    Convert the top-k token IDs for every patch into *readable* words.

    Each ID is decoded on its own. Decoding a whole row in a single call lets
    sub-word pieces that lack a leading-space marker fuse with their neighbour
    (e.g. "dog" + "fur" -> "dogfur"), which breaks the one-word-per-token
    mapping and means ``result[p][0]`` would no longer be the top-1 token.

    Args:
        vals: per-patch similarity scores, one row per patch — typically the
            values returned by ``Tensor.topk``, so each row is best-first.
        ids: vocabulary IDs aligned element-wise with ``vals``.
        thresh: only tokens whose score is strictly greater than this are kept.
        tokenizer: object with a Hugging Face-style
            ``decode(ids, skip_special_tokens=...)`` method; defaults to the
            module-level ``tok``.

    Returns:
        A list with one entry per patch: the surviving words in the same order
        as ``ids`` (best first). Tokens that decode to nothing (special or
        whitespace-only tokens) are dropped; patches with no surviving token
        get an empty list.
    """
    if tokenizer is None:
        tokenizer = tok  # the module-level Moondream tokenizer
    keep = vals > thresh  # mask of tokens we want to keep
    out = []
    for row_ids, row_keep in zip(ids, keep):
        words = []
        # Decode each surviving id separately so every token stays its own word.
        for token_id in row_ids[row_keep].tolist():
            word = tokenizer.decode([token_id], skip_special_tokens=True).strip()
            if word:  # special / whitespace-only tokens decode to ""
                words.append(word)
        out.append(words)
    return out
words_per_patch = labels_from_ids(vals, ids, thresh=0.0001)
print(words_per_patch)  # e.g. [['dog', 'black', 'fur'], ...]
# ---------------------------------------------------------------
# 9️⃣ visualise – draw the highest-probability word per patch
# ---------------------------------------------------------------
import math
from PIL import ImageDraw, ImageFont

OUT_PATH = "moondream_patch_labels.png"

# NOTE(review): this layout assumes the vision encoder returns exactly one
# embedding per cell of a square grid covering the *whole* image, with no
# extra CLS token. For the pinned Moondream revision this is believed to be a
# 27×27 = 729-token grid (its vision projection pools to that size) — confirm
# against REVISION. The check below fails loudly when the token count is not
# a perfect square instead of silently mis-placing labels. It is a real
# exception, not an `assert`, so it survives `python -O`.
n_tokens = len(words_per_patch)
grid_side = math.isqrt(n_tokens)
if n_tokens == 0 or grid_side * grid_side != n_tokens:
    raise ValueError(
        f"expected a non-empty square grid of image tokens, got {n_tokens}"
    )

# Draw on a copy of the original image so its aspect ratio is preserved; each
# grid cell spans (width / side) × (height / side) pixels.
canvas = img.copy()
cell_w = canvas.width / grid_side
cell_h = canvas.height / grid_side
draw = ImageDraw.Draw(canvas)
try:
    font = ImageFont.truetype("DejaVuSans.ttf", size=10)
except OSError:  # fallback if the font isn't available
    font = ImageFont.load_default()

for idx, words in enumerate(words_per_patch):
    if not words:  # nothing above the threshold for this patch
        continue
    row, col = divmod(idx, grid_side)
    centre = ((col + 0.5) * cell_w, (row + 0.5) * cell_h)
    draw.text(
        centre,
        words[0],  # best-scoring word for this patch
        fill="white",
        font=font,
        anchor="mm",
        stroke_width=1,
        stroke_fill="black",  # small outline for contrast
    )

# `Image.show(title)` only opens a viewer; it never writes a file. Save
# explicitly so the message below is true.
canvas.save(OUT_PATH)
print(f"✅ overlay saved as {OUT_PATH}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment