An attempt at reproducing the interactive VLM demo from https://www.groundlight.ai/blog/how-vlm-works-tokens, which visualizes tokens on the image patch / predicted region they correspond to. The idea: run each CLIP patch embedding through LLaVA's multi-modal projector into the language model's embedding space, then look up the vocabulary tokens with the highest cosine similarity to each projected patch.
import torch
import torch.nn.functional as F
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image

MODEL_ID = "llava-hf/llava-1.5-7b-hf"
device = "cuda"

model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype=torch.float16, device_map="auto")
processor = AutoProcessor.from_pretrained(MODEL_ID)
tok = processor.tokenizer

# print(model.get_input_embeddings().weight)  # debug: inspect the raw embedding matrix

# L2-normalize the language model's token-embedding matrix, (|V|, 4096);
# with unit rows, the matmul below yields cosine similarities directly.
W_lang = F.normalize(model.get_input_embeddings().weight, dim=-1).to(device)
img = Image.open("1280px-Labrador_Retriever_portrait.jpg").convert("RGB")

with torch.no_grad():
    # Preprocess to CLIP's 336x336 input and cast to fp16 to match the model weights.
    px = (processor.image_processor(images=img, return_tensors="pt")
          .pixel_values.to(device, dtype=torch.float16))
    # LLaVA-1.5 takes its visual features from the second-to-last ViT layer
    # (config.vision_feature_layer == -2) and drops the [CLS] token.
    hidden = model.vision_tower(px, output_hidden_states=True).hidden_states
    patch_emb = hidden[model.config.vision_feature_layer][:, 1:]   # drop [CLS]
    # Project into the language-model embedding space and L2-normalize: (N, 4096)
    patch_lang = F.normalize(model.multi_modal_projector(patch_emb).squeeze(0), dim=-1)
# ---------- similarity & top-k ----------
sim = patch_lang @ W_lang.T          # (N, |V|) cosine similarities
topk = 8
vals, ids = sim.topk(topk, dim=-1)   # vals, ids: (N, 8)
# ---------- nice printing helper ----------
def clean_labels(vals, ids, thresh=0.001):
    """
    vals : (N, k) cosine-similarity scores, highest first
    ids  : (N, k) vocab indices that produced those scores
    Returns a list of N lists of human-readable tokens scoring above `thresh`.
    """
    keep = vals > thresh
    out = []
    for scores, idx_row, mask in zip(vals, ids, keep):
        if not mask.any():
            out.append([])                  # no confident words
            continue
        # convert *only* the ids we keep
        toks = tok.convert_ids_to_tokens(idx_row[mask].tolist())
        cleaned = []
        for t in toks:
            if t is None:                   # guard against unknown ids
                continue
            if t.startswith("<") and t.endswith(">"):
                continue                    # drop <s>, <unk>, <0xAB>, <image>, ...
            if t.startswith("▁"):
                t = t[1:]                   # drop the SentencePiece word-boundary marker
            if t:                           # skip empty strings
                cleaned.append(t)
        out.append(cleaned)
    return out
words_per_patch = clean_labels(vals, ids)
print(words_per_patch)   # 576 lists (24x24 patch grid), each with up to 8 candidate words
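
The script above only prints the per-patch words; to actually see them on the image like the blog post's demo, something along these lines should work. A minimal sketch, assuming LLaVA-1.5's 336-px CLIP tower (a 24x24 grid of 14-px patches, row-major) and reusing img, vals, and words_per_patch from above; the plain squash-resize here only approximates CLIP's resize-and-center-crop preprocessing, so alignment near the borders is approximate.

import matplotlib.pyplot as plt

GRID = 24                                   # 336 / 14 = 24 patches per side
# top-1 cosine similarity per patch, reshaped onto the patch grid
best = vals[:, 0].float().cpu().numpy().reshape(GRID, GRID)

fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(img.resize((336, 336)))
ax.imshow(best, alpha=0.4, cmap="viridis",
          extent=(0, 336, 336, 0))          # heat-map overlay aligned to the image
for i in range(GRID):
    for j in range(GRID):
        words = words_per_patch[i * GRID + j]
        if words:                           # label each patch with its best word
            ax.text(j * 14 + 7, i * 14 + 7, words[0],
                    ha="center", va="center", fontsize=4, color="white")
ax.axis("off")
plt.savefig("patch_tokens.png", dpi=300, bbox_inches="tight")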