SynthCLIP support for open_clip & clip_benchmark: a patched state-dict conversion module, two loading/inference scripts, and a ViT-B/16 model config.
Patched src/open_clip/convert.py — open_clip's checkpoint-conversion module, extended with a SynthCLIP state-dict converter:
""" Conversion functions for 3rd part state-dicts and non-torch native checkpoint formats. | |
src/open_clip/convert.py | |
""" | |
from typing import Union | |
import torch | |
import numpy as np | |
from .model import CLIP, CustomTextCLIP | |
from .transformer import TextTransformer, Transformer | |
@torch.no_grad() | |
def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str): | |
""" Load weights from .npz checkpoints for official Google big_vision image-text models | |
Currently the SigLIP source models are supported and a CustomTextCLIP destination model | |
w/ timm image encoder. | |
""" | |
from timm.layers import resample_patch_embed, resample_abs_pos_embed | |
def _n2p(w, t=True, idx=None): | |
if idx is not None: | |
w = w[idx] | |
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: | |
w = w.flatten() | |
if t: | |
if w.ndim == 4: | |
w = w.transpose([3, 2, 0, 1]) | |
elif w.ndim == 3: | |
w = w.transpose([2, 0, 1]) | |
elif w.ndim == 2: | |
w = w.transpose([1, 0]) | |
return torch.from_numpy(w) | |
w = np.load(checkpoint_path) | |
interpolation = 'bilinear' | |
antialias = False | |
def _convert_timm_img(module, prefix): | |
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) | |
if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]: | |
embed_conv_w = resample_patch_embed( | |
embed_conv_w, | |
module.patch_embed.proj.weight.shape[-2:], | |
interpolation=interpolation, | |
antialias=antialias, | |
verbose=True, | |
) | |
module.patch_embed.proj.weight.copy_(embed_conv_w) | |
module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) | |
if module.cls_token is not None: | |
module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) | |
pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False) | |
if pos_embed_w.shape != module.pos_embed.shape: | |
assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}' | |
num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1) | |
pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights | |
pos_embed_w, | |
new_size=module.patch_embed.grid_size, | |
num_prefix_tokens=num_prefix_tokens, | |
interpolation=interpolation, | |
antialias=antialias, | |
verbose=True, | |
) | |
module.pos_embed.copy_(pos_embed_w) | |
mha_sub, b_sub, ln1_sub = (0, 0, 1) | |
for i, block in enumerate(module.blocks.children()): | |
if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w: | |
block_prefix = f'{prefix}Transformer/encoderblock/' | |
idx = i | |
else: | |
block_prefix = f'{prefix}Transformer/encoderblock_{i}/' | |
idx = None | |
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' | |
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx)) | |
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx)) | |
block.attn.qkv.weight.copy_(torch.cat([ | |
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')])) | |
block.attn.qkv.bias.copy_(torch.cat([ | |
_n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')])) | |
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1)) | |
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx)) | |
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx)) | |
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx)) | |
for r in range(2): | |
getattr(block.mlp, f'fc{r + 1}').weight.copy_( | |
_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx)) | |
getattr(block.mlp, f'fc{r + 1}').bias.copy_( | |
_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx)) | |
module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) | |
module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) | |
if module.attn_pool is not None: | |
block_prefix = f'{prefix}MAPHead_0/' | |
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' | |
module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False)) | |
module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T) | |
module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1)) | |
module.attn_pool.kv.weight.copy_(torch.cat([ | |
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')])) | |
module.attn_pool.kv.bias.copy_(torch.cat([ | |
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')])) | |
module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) | |
module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) | |
module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) | |
module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) | |
for r in range(2): | |
getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel'])) | |
getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias'])) | |
def _convert_openclip_transformer(module: Transformer, prefix): | |
for i, block in enumerate(module.resblocks.children()): | |
block_prefix = f'{prefix}encoderblock_{i}/' | |
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' | |
block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) | |
block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) | |
block.attn.in_proj_weight.copy_(torch.cat([ | |
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) | |
block.attn.in_proj_bias.copy_(torch.cat([ | |
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) | |
block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) | |
block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) | |
block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale'])) | |
block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias'])) | |
block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel'])) | |
block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias'])) | |
block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel'])) | |
block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias'])) | |
def _convert_openclip_txt(module: TextTransformer, prefix): | |
module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False)) | |
pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0) | |
module.positional_embedding.copy_(pos_embed_w) | |
_convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/') | |
module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale'])) | |
module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias'])) | |
if module.text_projection is not None: | |
module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) | |
module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias'])) | |
_convert_timm_img(model.visual.trunk, 'img/') | |
_convert_openclip_txt(model.text, 'txt/') | |
model.logit_bias.copy_(_n2p(w['b'])[0]) | |
model.logit_scale.copy_(_n2p(w['t'])[0]) | |
@torch.no_grad() | |
def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit = True): | |
def _convert_timm_img(state_dict): | |
if fastvit: | |
from timm.models.fastvit import checkpoint_filter_fn | |
else: | |
from timm.models.vision_transformer_hybrid import checkpoint_filter_fn | |
timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk) | |
timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()} | |
return timm_state_dict | |
def _convert_openclip_txt(state_dict, prefix='text_encoder.'): | |
text_dict = {} | |
for k, v in state_dict.items(): | |
if not k.startswith(prefix): | |
continue | |
k = k.replace(prefix, '') | |
k = k.replace('projection_layer', 'text_projection') | |
k = k.replace('embedding_layer', 'token_embedding') | |
if k.startswith('positional_embedding.pos_embed.pos_embed'): | |
k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding') | |
v = v.squeeze() | |
k = k.replace('final_layer_norm', 'ln_final') | |
k = k.replace('pre_norm_mha.0', 'ln_1') | |
k = k.replace('pre_norm_mha.1', 'attn') | |
k = k.replace('pre_norm_ffn.0', 'ln_2') | |
k = k.replace('pre_norm_ffn.1', 'mlp.c_fc') | |
k = k.replace('pre_norm_ffn.4', 'mlp.c_proj') | |
k = k.replace('qkv_proj.weight', 'in_proj_weight') | |
k = k.replace('qkv_proj.bias', 'in_proj_bias') | |
k = k.replace('transformer.', 'transformer.resblocks.') | |
text_dict['text.' + k] = v | |
return text_dict | |
image_dict = _convert_timm_img(state_dict) | |
text_dict = _convert_openclip_txt(state_dict) | |
out_dict = {**image_dict, **text_dict} | |
out_dict['logit_scale'] = state_dict['logit_scale'] | |
return out_dict | |
@torch.no_grad() | |
def convert_synthclip_state_dict(model: CustomTextCLIP, state_dict): | |
out_dict = {} | |
for k, v in state_dict.items(): | |
# print(k, v.shape) | |
if k.startswith('visual.blocks.'): | |
k = k.replace('visual.blocks.', 'visual.trunk.blocks.') | |
out_dict[k] = v | |
elif k.startswith('visual.patch_embed.'): | |
# visual.patch_embed.proj.weight, | |
# visual.patch_embed.proj.bias, | |
k = k.replace('visual.patch_embed.', 'visual.trunk.patch_embed.') | |
out_dict[k] = v | |
elif k.startswith('visual.norm.'): | |
# visual.norm.weight, | |
# visual.norm.bias, | |
k = k.replace('visual.norm.', 'visual.trunk.norm.') | |
out_dict[k] = v | |
elif k.startswith('visual.cls_token'): | |
k = k.replace('visual.cls_token', 'visual.trunk.cls_token') | |
out_dict[k] = v | |
elif k.startswith('visual.pos_embed'): | |
k = k.replace('visual.pos_embed', 'visual.trunk.pos_embed') | |
out_dict[k] = v | |
elif k.startswith('image_projection'): | |
k = k.replace('image_projection', 'visual.head.proj.weight') | |
out_dict[k] = v.transpose(0, 1) | |
else: | |
out_dict[k] = v | |
out_dict['logit_scale'] = state_dict['logit_scale'] | |
# print() | |
# for k, v in model.named_parameters(): | |
# print(k, v.shape) | |
return out_dict | |
def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict): | |
if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict: | |
# Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported) | |
state_dict = convert_mobile_clip_state_dict(model, state_dict) | |
if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict: | |
# convert b model | |
state_dict = convert_mobile_clip_state_dict(model, state_dict, | |
fastvit=False) | |
if 'image_projection' in state_dict: | |
state_dict = convert_synthclip_state_dict(model, state_dict) | |
return state_dict |
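
For reference, here is a minimal sketch (not part of the original gist) of exercising the SynthCLIP conversion path by hand, outside open_clip's loading machinery. It assumes the "ViT-B-16-trick" config at the end of this gist is registered with open_clip and that the checkpoint filename matches the scripts below; both are assumptions, and strict=False is used so any naming mismatch shows up in the printout rather than raising.

import torch
import open_clip
from open_clip.convert import convert_state_dict  # the patched module above

# Hypothetical paths/names; adjust to your setup.
model = open_clip.create_model("ViT-B-16-trick", pretrained=None)
ckpt = torch.load("checkpoint_best.pt", map_location="cpu")
state_dict = {k.replace("module.", ""): v for k, v in ckpt["state_dict"].items()}  # drop DataParallel prefix
state_dict = convert_state_dict(model, state_dict)  # 'image_projection' key triggers the SynthCLIP branch
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing:", missing)
print("unexpected:", unexpected)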
Sanity-check script: load the SynthCLIP checkpoint either as in the original repo (DataParallel wrapper) or via clip_benchmark's CLIP_VITB16, then run a zero-shot forward pass:
import os
from collections import OrderedDict

import torch
import torchvision.transforms as transforms
import open_clip
from PIL import Image

from clip_benchmark.models.synthclip import CLIP_VITB16

# checkpoint_path = "/home/itanh0b/Projects/CLIP_benchmark/cache/SynthCI30M-ViT-B-16"
checkpoint_path = "./logs/synthclip-30m"
img_path = "./open_clip/docs/CLIP.png"
use_clip_benchmark = True
device = "cuda"


def load_synthclip(model_name, pretrained, device):
    if model_name == "ViT-B-16":
        model = CLIP_VITB16()
    else:
        raise ValueError(f"Unsupported model: {model_name}")
    tokenizer = open_clip.get_tokenizer(model_name)
    if pretrained:
        state_dict = torch.load(os.path.join(pretrained, "checkpoint_best.pt"))["state_dict"]
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k.replace("module.", "")  # strip the DataParallel prefix
            new_state_dict[name] = v
        load_status = model.load_state_dict(new_state_dict)
        print(load_status)
    model.to(device)
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    transform = transforms.Compose(
        [
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]
    )
    return model, transform, tokenizer


if not use_clip_benchmark:
    print('Load synthclip as per example...')
    model = torch.nn.DataParallel(CLIP_VITB16())
    checkpoint = torch.load(os.path.join(checkpoint_path, "checkpoint_best.pt"), map_location=device)
    load_status = model.load_state_dict(checkpoint["state_dict"])
    print(load_status)
    model = model.module
    model.to(device)
    tokenizer = open_clip.get_tokenizer("ViT-B-16")
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    transform = transforms.Compose(
        [
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x,  # force RGB for grayscale inputs
            normalize,
        ]
    )
else:
    model, transform, tokenizer = load_synthclip("ViT-B-16", checkpoint_path, device)

print('Load & preprocess image...')
image = Image.open(img_path)
image = image.convert('RGB')
image = transform(image).unsqueeze(0).to(device)

print('Tokenize text...')
text = tokenizer(["a diagram", "a dog", "a cat"]).to(device)

model.eval()

print('Fwd-pass model...')
autocast = torch.cuda.amp.autocast
with torch.no_grad(), autocast():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    logit_scale = model.logit_scale.exp()
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    text_probs = (logit_scale * image_features @ text_features.T).softmax(dim=-1)

print("Label probs:", text_probs)  # prints: [[0.0046, 0.0878, 0.9076]]
Inference script: load the raw SynthCLIP checkpoint directly through open_clip using the custom "ViT-B-16-trick" config (the patched convert.py above remaps the keys on load):
import torch
import torchvision.transforms as transforms
import open_clip
from PIL import Image

# model_arch, model_path = 'ViT-B-16', 'datacomp_l_s1b_b8k'
model_arch = "ViT-B-16-trick"
model_path = "/fast_scratch/datasets/model-zoo/synthclip-30m/checkpoint_best.pt"
img_path = "./docs/CLIP.png"
device = 'cuda'

if model_arch != "ViT-B-16-trick":
    model, _, preprocess = open_clip.create_model_and_transforms(
        model_arch, pretrained=model_path, load_weights_only=False)
else:
    model, _ = open_clip.create_model_from_pretrained(
        model_arch, pretrained=model_path, load_weights_only=False)
    # SynthCLIP uses ImageNet-style normalization (see the loading script above),
    # so rebuild the transform instead of using the open_clip default.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    preprocess = transforms.Compose(
        [
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.Lambda(lambda img: img.convert('RGB')),
            transforms.ToTensor(),
            normalize,
        ]
    )
print(f"{preprocess=}")

tokenizer = open_clip.get_tokenizer(model_arch)
print("Model setup 🎉")

model = model.cuda()
model.eval()

print('Load & preprocess image...')
image = preprocess(Image.open(img_path).convert('RGB')).unsqueeze(0).cuda()
text = tokenizer(["a diagram", "a dog", "a cat"]).cuda()

print('Fwd-pass model...')
autocast = torch.amp.autocast
with torch.no_grad(), autocast(device, dtype=torch.float32):
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    logit_scale = model.logit_scale.exp()
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    text_probs = (logit_scale * image_features @ text_features.T).softmax(dim=-1)

print("Label probs:", text_probs)  # prints: [[1., 0., 0.]]
open_clip model config for "ViT-B-16-trick": a timm ViT-B/16 image tower paired with the standard CLIP text tower:
{
    "embed_dim": 512,
    "quick_gelu": true,
    "vision_cfg": {
        "image_size": 224,
        "timm_model_name": "vit_base_patch16_224",
        "timm_model_pretrained": false
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}
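
For open_clip to resolve the name "ViT-B-16-trick", this JSON has to be visible to its config registry: either drop it into open_clip's source tree as src/open_clip/model_configs/ViT-B-16-trick.json (the filename defines the model name), or register an external config directory at runtime. A sketch of the latter; add_model_config ships with recent open_clip releases (treat its availability as an assumption for older versions), and the directory path is a placeholder:

import open_clip

open_clip.add_model_config("/path/to/my_configs")  # directory containing ViT-B-16-trick.json
model, preprocess = open_clip.create_model_from_pretrained(
    "ViT-B-16-trick",
    pretrained="/fast_scratch/datasets/model-zoo/synthclip-30m/checkpoint_best.pt",
)
tokenizer = open_clip.get_tokenizer("ViT-B-16-trick")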