@escorciav
Last active December 9, 2024 06:27
SynthCLIP support for open_clip & clip_benchmark: a patched src/open_clip/convert.py, two loading/inference scripts, and the "ViT-B-16-trick" model config needed to run SynthCLIP ViT-B-16 checkpoints.
""" Conversion functions for 3rd part state-dicts and non-torch native checkpoint formats.
src/open_clip/convert.py
"""
from typing import Union
import torch
import numpy as np
from .model import CLIP, CustomTextCLIP
from .transformer import TextTransformer, Transformer
@torch.no_grad()
def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
    """ Load weights from .npz checkpoints for official Google big_vision image-text models

    Currently the SigLIP source models are supported and a CustomTextCLIP destination model
    w/ timm image encoder.
    """
    from timm.layers import resample_patch_embed, resample_abs_pos_embed

    def _n2p(w, t=True, idx=None):
        if idx is not None:
            w = w[idx]
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            w = w.flatten()
        if t:
            if w.ndim == 4:
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                w = w.transpose([1, 0])
        return torch.from_numpy(w)

    w = np.load(checkpoint_path)
    interpolation = 'bilinear'
    antialias = False

    def _convert_timm_img(module, prefix):
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
        if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
            embed_conv_w = resample_patch_embed(
                embed_conv_w,
                module.patch_embed.proj.weight.shape[-2:],
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.patch_embed.proj.weight.copy_(embed_conv_w)
        module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))

        if module.cls_token is not None:
            module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))

        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
        if pos_embed_w.shape != module.pos_embed.shape:
            assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}'
            num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
            pos_embed_w = resample_abs_pos_embed(  # resize pos embedding when different size from pretrained weights
                pos_embed_w,
                new_size=module.patch_embed.grid_size,
                num_prefix_tokens=num_prefix_tokens,
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.pos_embed.copy_(pos_embed_w)

        mha_sub, b_sub, ln1_sub = (0, 0, 1)
        for i, block in enumerate(module.blocks.children()):
            if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
                block_prefix = f'{prefix}Transformer/encoderblock/'
                idx = i
            else:
                block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
                idx = None
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
            block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
            block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
            block.attn.qkv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.qkv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
            block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
            block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
            block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
            for r in range(2):
                getattr(block.mlp, f'fc{r + 1}').weight.copy_(
                    _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
                getattr(block.mlp, f'fc{r + 1}').bias.copy_(
                    _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))

        module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
        module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))

        if module.attn_pool is not None:
            block_prefix = f'{prefix}MAPHead_0/'
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
            module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
            module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
            module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
            module.attn_pool.kv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
            module.attn_pool.kv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
            module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            for r in range(2):
                getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
                getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))

    def _convert_openclip_transformer(module: Transformer, prefix):
        for i, block in enumerate(module.resblocks.children()):
            block_prefix = f'{prefix}encoderblock_{i}/'
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
            block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            block.attn.in_proj_weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.in_proj_bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale']))
            block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias']))
            block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel']))
            block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias']))
            block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel']))
            block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias']))

    def _convert_openclip_txt(module: TextTransformer, prefix):
        module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
        module.positional_embedding.copy_(pos_embed_w)
        _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
        module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
        module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
        if module.text_projection is not None:
            module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
            module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))

    _convert_timm_img(model.visual.trunk, 'img/')
    _convert_openclip_txt(model.text, 'txt/')
    model.logit_bias.copy_(_n2p(w['b'])[0])
    model.logit_scale.copy_(_n2p(w['t'])[0])
@torch.no_grad()
def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit=True):

    def _convert_timm_img(state_dict):
        if fastvit:
            from timm.models.fastvit import checkpoint_filter_fn
        else:
            from timm.models.vision_transformer_hybrid import checkpoint_filter_fn
        timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk)
        timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()}
        return timm_state_dict

    def _convert_openclip_txt(state_dict, prefix='text_encoder.'):
        text_dict = {}
        for k, v in state_dict.items():
            if not k.startswith(prefix):
                continue
            k = k.replace(prefix, '')
            k = k.replace('projection_layer', 'text_projection')
            k = k.replace('embedding_layer', 'token_embedding')
            if k.startswith('positional_embedding.pos_embed.pos_embed'):
                k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')
                v = v.squeeze()
            k = k.replace('final_layer_norm', 'ln_final')
            k = k.replace('pre_norm_mha.0', 'ln_1')
            k = k.replace('pre_norm_mha.1', 'attn')
            k = k.replace('pre_norm_ffn.0', 'ln_2')
            k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')
            k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')
            k = k.replace('qkv_proj.weight', 'in_proj_weight')
            k = k.replace('qkv_proj.bias', 'in_proj_bias')
            k = k.replace('transformer.', 'transformer.resblocks.')
            text_dict['text.' + k] = v
        return text_dict

    image_dict = _convert_timm_img(state_dict)
    text_dict = _convert_openclip_txt(state_dict)
    out_dict = {**image_dict, **text_dict}
    out_dict['logit_scale'] = state_dict['logit_scale']
    return out_dict
@torch.no_grad()
def convert_synthclip_state_dict(model: CustomTextCLIP, state_dict):
    out_dict = {}
    for k, v in state_dict.items():
        # print(k, v.shape)
        if k.startswith('visual.blocks.'):
            k = k.replace('visual.blocks.', 'visual.trunk.blocks.')
            out_dict[k] = v
        elif k.startswith('visual.patch_embed.'):
            # visual.patch_embed.proj.weight,
            # visual.patch_embed.proj.bias,
            k = k.replace('visual.patch_embed.', 'visual.trunk.patch_embed.')
            out_dict[k] = v
        elif k.startswith('visual.norm.'):
            # visual.norm.weight,
            # visual.norm.bias,
            k = k.replace('visual.norm.', 'visual.trunk.norm.')
            out_dict[k] = v
        elif k.startswith('visual.cls_token'):
            k = k.replace('visual.cls_token', 'visual.trunk.cls_token')
            out_dict[k] = v
        elif k.startswith('visual.pos_embed'):
            k = k.replace('visual.pos_embed', 'visual.trunk.pos_embed')
            out_dict[k] = v
        elif k.startswith('image_projection'):
            # projection parameter (in_features, out_features) -> nn.Linear weight (out_features, in_features)
            k = k.replace('image_projection', 'visual.head.proj.weight')
            out_dict[k] = v.transpose(0, 1)
        else:
            out_dict[k] = v
    out_dict['logit_scale'] = state_dict['logit_scale']
    # print()
    # for k, v in model.named_parameters():
    #     print(k, v.shape)
    return out_dict
def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
    if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
        # Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported)
        state_dict = convert_mobile_clip_state_dict(model, state_dict)
    if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
        # convert b model
        state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False)
    if 'image_projection' in state_dict:
        # SynthCLIP checkpoints expose the image projection as a bare parameter
        state_dict = convert_synthclip_state_dict(model, state_dict)
    return state_dict
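
A minimal sketch for sanity-checking the new SynthCLIP branch by hand. The checkpoint layout (a "state_dict" entry with "module."-prefixed keys) is an assumption taken from the loading scripts below, and "ViT-B-16-trick" refers to the custom model config at the end of this gist, which must be registered for open_clip to find it.

# Sketch: convert a SynthCLIP checkpoint manually and inspect key coverage.
import torch
import open_clip
from open_clip.convert import convert_state_dict

# Requires the "ViT-B-16-trick" JSON config (see end of gist) to be registered.
model = open_clip.create_model("ViT-B-16-trick", pretrained=None)
raw = torch.load("checkpoint_best.pt", map_location="cpu")["state_dict"]
raw = {k.replace("module.", ""): v for k, v in raw.items()}  # strip DataParallel prefix

converted = convert_state_dict(model, raw)
missing, unexpected = model.load_state_dict(converted, strict=False)
print("missing:", missing)        # ideally empty if the mapping is complete
print("unexpected:", unexpected)  # ideally empty
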
# Example script: run SynthCLIP (ViT-B-16) zero-shot on CLIP.png, either via
# clip_benchmark's CLIP_VITB16 wrapper or by loading the checkpoint directly.
import os
import torch
import open_clip
from PIL import Image
import torchvision.transforms as transforms
from collections import OrderedDict

from clip_benchmark.models.synthclip import CLIP_VITB16

# checkpoint_path = "/home/itanh0b/Projects/CLIP_benchmark/cache/SynthCI30M-ViT-B-16"
checkpoint_path = "./logs/synthclip-30m"
img_path = "./open_clip/docs/CLIP.png"
use_clip_benchmark = True
device = "cuda"


def load_synthclip(model_name, pretrained, device):
    if model_name == "ViT-B-16":
        model = CLIP_VITB16()
    tokenizer = open_clip.get_tokenizer(model_name)
    if pretrained:
        state_dict = torch.load(os.path.join(pretrained, "checkpoint_best.pt"))["state_dict"]
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k.replace("module.", "")  # strip DataParallel prefix
            new_state_dict[name] = v
        load_status = model.load_state_dict(new_state_dict)
        print(load_status)
    model.to(device)
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    transform = transforms.Compose(
        [
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]
    )
    return model, transform, tokenizer


if not use_clip_benchmark:
    print('Load synthclip as per example...')
    model = torch.nn.DataParallel(CLIP_VITB16())
    checkpoint = torch.load(os.path.join(checkpoint_path, "checkpoint_best.pt"), map_location=device)
    load_status = model.load_state_dict(checkpoint["state_dict"])
    print(load_status)
    model = model.module
    model.to(device)
    tokenizer = open_clip.get_tokenizer("ViT-B-16")
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    transform = transforms.Compose(
        [
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # not sure why this is needed, but it is - Victor
            lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x,  # force RGB
            normalize,
        ]
    )
else:
    model, transform, tokenizer = load_synthclip("ViT-B-16", checkpoint_path, device)

print('Load & preprocess image...')
image = Image.open(img_path)
image = image.convert('RGB')
image = transform(image).unsqueeze(0).to(device)

print('Tokenize text...')
text = tokenizer(["a diagram", "a dog", "a cat"]).to(device)

#############
model.eval()
#############

print('Fwd-pass model...')
autocast = torch.cuda.amp.autocast
with torch.no_grad(), autocast():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    logit_scale = model.logit_scale.exp()

    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    text_probs = (logit_scale * image_features @ text_features.T).softmax(dim=-1)

print("Label probs:", text_probs)  # prints: [[0.0046, 0.0878, 0.9076]]
# Example script: load SynthCLIP through open_clip using the custom
# "ViT-B-16-trick" model config (JSON below) and the convert_synthclip_state_dict path.
import torch
import open_clip
from PIL import Image

# model_arch, model_path = 'ViT-B-16', 'datacomp_l_s1b_b8k'
model_arch = "ViT-B-16-trick"
model_path = "/fast_scratch/datasets/model-zoo/synthclip-30m/checkpoint_best.pt"
img_path = "./docs/CLIP.png"
device = 'cuda'

if model_arch != "ViT-B-16-trick":
    model, _, preprocess = open_clip.create_model_and_transforms(
        model_arch, pretrained=model_path, load_weights_only=False)
else:
    model, _ = open_clip.create_model_from_pretrained(
        model_arch, pretrained=model_path, load_weights_only=False)

    # override the default preprocessing with the ImageNet normalization
    # used by the SynthCLIP example above
    import torchvision.transforms as transforms
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    preprocess = transforms.Compose(
        [
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.Lambda(lambda img: img.convert('RGB')),
            transforms.ToTensor(),
            normalize,
        ]
    )
    print(f"{preprocess=}")

tokenizer = open_clip.get_tokenizer(model_arch)
print("Model setup 🎉")

model = model.cuda()
model.eval()

print('Load & preprocess image...')
image = preprocess(Image.open(img_path).convert('RGB')).unsqueeze(0).cuda()
text = tokenizer(["a diagram", "a dog", "a cat"]).cuda()

print('Fwd-pass model...')
autocast = torch.amp.autocast
with torch.no_grad(), autocast(device, dtype=torch.float32):
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    logit_scale = model.logit_scale.exp()

    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    text_probs = (logit_scale * image_features @ text_features.T).softmax(dim=-1)

print("Label probs:", text_probs)  # prints: [[1., 0., 0.]]
{
    "embed_dim": 512,
    "quick_gelu": true,
    "vision_cfg": {
        "image_size": 224,
        "timm_model_name": "vit_base_patch16_224",
        "timm_model_pretrained": false
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}
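
The JSON above appears to be the model config backing the "ViT-B-16-trick" name used in the previous script: a timm vit_base_patch16_224 image tower plus open_clip's standard text tower. A sketch of one way to make it visible to open_clip without editing the installed package; the file name and location are assumptions, and open_clip names the model after the JSON file's stem.

# Sketch: register the config at runtime instead of copying it into the
# open_clip source tree (src/open_clip/model_configs/).
import open_clip

open_clip.add_model_config("./ViT-B-16-trick.json")  # path where the JSON above was saved
assert "ViT-B-16-trick" in open_clip.list_models()
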