Skip to content

Instantly share code, notes, and snippets.

@tomaarsen
Created September 24, 2025 16:08
Show Gist options
  • Save tomaarsen/f02b628162b8d49a9f93d40758af6ef3 to your computer and use it in GitHub Desktop.
Save tomaarsen/f02b628162b8d49a9f93d40758af6ef3 to your computer and use it in GitHub Desktop.
Script to update all E5-NL models to be nicely integrated with Sentence Transformers
import re
from huggingface_hub import get_collection, ModelCard
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Normalize
# Fetch the Hugging Face collection whose model entries this script updates.
_COLLECTION_SLUG = "clips/e5-nl-68be9d3760240ce5c7d9f831"
collection = get_collection(collection_slug=_COLLECTION_SLUG)
# Regex that locates the existing Sentence Transformers usage snippet inside a
# model card so that it can be replaced.
#
# Capture groups:
#   1    - the model name passed to SentenceTransformer(...)
#   2-5  - the four quoted entries of `input_texts`, in order
#
# NOTE: this is a raw string, so the backslash ending the first line is NOT a
# line continuation. The pattern therefore starts with an escaped newline,
# which `re` matches as a literal newline: the snippet must be preceded by a
# line break to match.
# NOTE(review): the `.` in `model.encode` is unescaped and matches any
# character; harmless here, but looser than intended.
ST_SNIPPET_PATTERN = r"""\
from sentence_transformers import SentenceTransformer
model = SentenceTransformer\((?:'|")([a-zA-Z0-9_\/\.-]+?)(?:'|")\)
input_texts = \[
(?:'|")(.+?)(?:'|"),
(?:'|")(.+?)(?:'|"),
(?:'|")(.+?)(?:'|"),
(?:'|")(.+?)(?:'|")
\]
embeddings = model.encode\(input_texts, normalize_embeddings=True\)
"""
# Replacement usage snippet, filled in with `str.format`.
#
# Placeholders:
#   {model_name}              - model id placed inside SentenceTransformer("...")
#   {query1}, {query2}        - entries of the `queries` list; inserted verbatim,
#                               so callers must pass already-quoted literals
#   {document1}, {document2}  - entries of the `documents` list; same rule
#   {shapes}                  - text inserted verbatim on the line after the
#                               shape print
#   {similarities}            - text inserted verbatim on the line(s) after the
#                               similarity call
#
# The template starts and ends with a newline.
ST_SNIPPET_TEMPLATE = """
from sentence_transformers import SentenceTransformer
# Load the model from Hugging Face
model = SentenceTransformer("{model_name}")
# Perform inference using encode_query/encode_document for retrieval,
# or encode_query for general purpose embeddings. Prompt prefixes
# are automatically added with these two methods.
queries = [
{query1},
{query2},
]
documents = [
{document1},
{document2},
]
query_embeddings = model.encode_query(queries)
document_embeddings = model.encode_document(documents)
print(query_embeddings.shape, document_embeddings.shape)
{shapes}
similarities = model.similarity(query_embeddings, document_embeddings)
{similarities}
"""
# Prompt prefixes of the E5-NL models. They are registered on each model as its
# "query"/"document" prompts and removed from the example texts taken from the
# old model card (the new snippet relies on the prompts instead).
QUERY_PREFIX = "query: "
DOCUMENT_PREFIX = "passage: "

# Model ids listed here are skipped by the loop below.
FINISHED_MODELS = []


def _strip_prefix(text, prefix):
    """Return *text* without a leading *prefix* and surrounding whitespace.

    Only a prefix at the very start of the text is removed, so an example
    sentence that happens to contain the prefix further on is kept intact.
    """
    text = text.strip()
    if text.startswith(prefix):
        text = text[len(prefix):]
    return text.strip()


# For every model in the collection: register the prompts, add a Normalize
# module, rewrite the model card metadata and usage snippet, and push the result.
for item in collection.items:
    # Only model entries of the collection are processed.
    if item.item_type != "model":
        continue

    model_id = item.item_id
    if model_id in FINISHED_MODELS:
        continue

    model = SentenceTransformer(
        model_id,
        prompts={
            "query": QUERY_PREFIX,
            "document": DOCUMENT_PREFIX,
        },
    )
    # Registered under the name "2"; presumably this places it right after the
    # existing modules "0" and "1" -- confirm for models with a different layout.
    model.add_module("2", Normalize())

    # Update the model card metadata.
    model_card = ModelCard.load(model_id)
    model_card.data.library_name = "sentence-transformers"
    model_card.data.language = "nl"
    tags = model_card.data.tags or []
    if "transformers" not in tags:
        tags.append("transformers")
    model_card.data.tags = tags

    # Replace the old usage snippet with one built from ST_SNIPPET_TEMPLATE.
    content = model_card.content
    match = re.search(ST_SNIPPET_PATTERN, content)
    if match:
        # Groups 2-3 are treated as queries, groups 4-5 as documents.
        queries = [match.group(2), match.group(3)]
        documents = [match.group(4), match.group(5)]

        # Pause for manual inspection if any example lacks its expected prefix.
        if not all(query.startswith(QUERY_PREFIX) for query in queries):
            print("Unexpected query format in model card for", model_id)
            breakpoint()
        if not all(doc.startswith(DOCUMENT_PREFIX) for doc in documents):
            print("Unexpected document format in model card for", model_id)
            breakpoint()

        queries = [_strip_prefix(query, QUERY_PREFIX) for query in queries]
        documents = [_strip_prefix(doc, DOCUMENT_PREFIX) for doc in documents]

        # Compute the values shown as comments in the new snippet.
        query_embeddings = model.encode_query(queries, normalize_embeddings=True)
        doc_embeddings = model.encode_document(documents, normalize_embeddings=True)
        shapes = f"# {query_embeddings.shape} {doc_embeddings.shape}"
        similarities = model.similarity(query_embeddings, doc_embeddings)
        # Prefix every line of the printed result so it stays a comment.
        similarities = "# " + str(similarities).replace("\n", "\n# ")

        new_snippet = ST_SNIPPET_TEMPLATE.format(
            model_name=model_id,
            query1=repr(queries[0]),
            query2=repr(queries[1]),
            document1=repr(documents[0]),
            document2=repr(documents[1]),
            shapes=shapes,
            similarities=similarities,
        )
        content = content[:match.start()] + new_snippet + content[match.end():]
    else:
        # No snippet to replace: pause so the card can be inspected. If execution
        # is resumed, the card text is pushed unchanged (metadata still updated).
        print("No match found in model card for", model_id)
        breakpoint()

    model_card.content = content
    model_card.validate("model")
    # NOTE(review): private attribute; presumably the text Sentence Transformers
    # writes out as the model card when saving/pushing -- confirm.
    model._model_card_text = str(model_card)

    # Alternative target for a dry run: a private copy under another namespace.
    # model.push_to_hub(model_id.replace("clips/", "tomaarsen/"), private=True)
    url = model.push_to_hub(model_id, create_pr=True)
    print("Pushed", model_id, "->", url)
    # Pause after every model so the result can be checked before continuing.
    breakpoint()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment