@secemp9 · Created May 31, 2025
Attempt at reproducing the interactive VLM demo from https://www.groundlight.ai/blog/how-vlm-works-tokens, which visualizes tokens on the image patch/region they are predicted for. The script projects each CLIP patch embedding through LLaVA's multimodal projector into the language-embedding space, then labels every patch with the vocabulary tokens whose input embeddings are most cosine-similar.
import torch
import torch.nn.functional as F
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image

MODEL_ID = "llava-hf/llava-1.5-7b-hf"
device = "cuda"

model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype=torch.float16, device_map="auto")
processor = AutoProcessor.from_pretrained(MODEL_ID)
tok = processor.tokenizer

# Unit-normalise the language input-embedding matrix so that a plain matmul
# against (also normalised) projected patch embeddings yields cosine similarity.
W_lang = F.normalize(model.get_input_embeddings().weight, dim=-1).to(device)  # (|V|, 4096)
img = Image.open("1280px-Labrador_Retriever_portrait.jpg").convert("RGB")

with torch.no_grad():
    # Use the image processor directly; the full processor expects text containing an <image> token.
    px = processor.image_processor(img, return_tensors="pt").pixel_values.to(device, torch.float16)
    # LLaVA-1.5 feeds the projector the *penultimate* vision layer
    # (vision_feature_layer=-2) with the [CLS] token dropped, so mirror that here.
    vis_out = model.vision_tower(px, output_hidden_states=True)
    patch_emb = vis_out.hidden_states[-2][:, 1:]                     # (1, N, 1024), [CLS] dropped
    patch_lang = F.normalize(model.multi_modal_projector(patch_emb)
                             .squeeze(0), dim=-1)                    # (N, 4096)
# ---------- similarity & top-k ----------
sim = patch_lang @ W_lang.T # (N, |V|)
topk = 8
vals, ids = sim.topk(topk, dim=-1) # vals, ids: (N, 8)
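# Note: this nearest-neighbour reading compares projector outputs against the
# *input* embedding matrix. It is a heuristic "what word does this patch look
# like" probe, not the tokens the model would actually decode for the image.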
# ---------- nice printing helper ----------
def clean_labels(vals, ids, thresh=0.001):
    """
    vals   : (N, k) cosine-similarity scores, highest first
    ids    : (N, k) vocab indices that produced those scores
    thresh : minimum similarity for a token to count as a label
    """
    keep = vals > thresh
    out = []
    for scores, idx_row, mask in zip(vals, ids, keep):
        if not mask.any():
            out.append([])                     # no confident words for this patch
            continue
        # convert *only* the ids we keep
        toks = tok.convert_ids_to_tokens(idx_row[mask].tolist())
        cleaned = []
        for t in toks:
            if t is None:                      # convert_ids_to_tokens can return None
                continue
            if t.startswith("<") and t.endswith(">"):
                continue                       # drop <s>, <unk>, <0xAB>, <image>, …
            if t.startswith("▁"):
                t = t[1:]                      # strip the sentencepiece word-boundary mark
            if t:                              # skip empty strings
                cleaned.append(t)
        out.append(cleaned)
    return out
words_per_patch = clean_labels(vals, ids)
print(words_per_patch)
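# ---------- optional: overlay words on the image ----------
# The gist stops at printing the word lists. Below is a minimal sketch of the
# actual overlay, assuming the LLaVA-1.5 tower (CLIP ViT-L/14 @ 336px), whose
# 576 patches form a 24x24 grid; the grid size, font, and colour here are
# assumptions for illustration, not part of the original gist.
import math
import matplotlib.pyplot as plt

n_patches = len(words_per_patch)       # 576 for llava-1.5-7b-hf
grid = math.isqrt(n_patches)           # 24 if the assumption above holds
assert grid * grid == n_patches, "expected a square patch grid"

fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(img.resize((336, 336)))      # match the vision tower's input resolution
cell = 336 / grid
for i, words in enumerate(words_per_patch):
    if not words:
        continue                       # patch had no confident word
    r, c = divmod(i, grid)
    ax.text(c * cell + cell / 2, r * cell + cell / 2, words[0],
            ha="center", va="center", fontsize=5, color="yellow")
ax.set_axis_off()
plt.savefig("patch_words.png", dpi=200, bbox_inches="tight")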