Created
September 24, 2025 16:08
-
-
Save tomaarsen/f02b628162b8d49a9f93d40758af6ef3 to your computer and use it in GitHub Desktop.
Script to update all E5-NL models to be nicely integrated with Sentence Transformers
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Update every model in the CLiPS E5-NL Hugging Face collection so that it
# integrates with Sentence Transformers: register the "query: "/"passage: "
# prompts, append a Normalize module, rewrite the usage snippet in the model
# card, and open a pull request on each model repository.
import re

from huggingface_hub import get_collection, ModelCard
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Normalize

collection = get_collection(collection_slug="clips/e5-nl-68be9d3760240ce5c7d9f831")

# Regex for the usage snippet that is searched for in each model card.
# Groups: 1 = model name, 2-3 = the two query texts, 4-5 = the two document
# texts. The leading backslash + newline inside the raw string is a regex
# escape that matches a literal newline.
# NOTE(review): the pattern is whitespace-sensitive; confirm the line breaks
# and indentation below against the snippet in the actual model cards.
ST_SNIPPET_PATTERN = r"""\
from sentence_transformers import SentenceTransformer
model = SentenceTransformer\((?:'|")([a-zA-Z0-9_\/\.-]+?)(?:'|")\)
input_texts = \[
(?:'|")(.+?)(?:'|"),
(?:'|")(.+?)(?:'|"),
(?:'|")(.+?)(?:'|"),
(?:'|")(.+?)(?:'|")
\]
embeddings = model.encode\(input_texts, normalize_embeddings=True\)
"""

# Replacement snippet; the placeholders are filled in per model with
# str.format() in the loop below.
ST_SNIPPET_TEMPLATE = """
from sentence_transformers import SentenceTransformer
# Load the model from Hugging Face
model = SentenceTransformer("{model_name}")
# Perform inference using encode_query/encode_document for retrieval,
# or encode_query for general purpose embeddings. Prompt prefixes
# are automatically added with these two methods.
queries = [
{query1},
{query2},
]
documents = [
{document1},
{document2},
]
query_embeddings = model.encode_query(queries)
document_embeddings = model.encode_document(documents)
print(query_embeddings.shape, document_embeddings.shape)
{shapes}
similarities = model.similarity(query_embeddings, document_embeddings)
{similarities}
"""

# Prompt prefixes registered on each model and stripped from the example
# texts taken out of the existing model card.
QUERY_PREFIX = "query: "
DOCUMENT_PREFIX = "passage: "

# Model ids listed here are skipped, so an interrupted run can be resumed
# without pushing to the same repository twice.
FINISHED_MODELS = []

for item in collection.items:
    # The loop only handles collection entries whose type is "model".
    if item.item_type != "model":
        continue
    model_id = item.item_id
    if model_id in FINISHED_MODELS:
        continue

    model = SentenceTransformer(
        model_id,
        prompts={
            "query": QUERY_PREFIX,
            "document": DOCUMENT_PREFIX,
        },
    )
    # Register a Normalize module under the name "2".
    # NOTE(review): presumably modules "0" and "1" already exist, making this
    # the last pipeline step - confirm for every model in the collection.
    model.add_module("2", Normalize())

    # Update the model card metadata.
    model_card = ModelCard.load(model_id)
    model_card.data.library_name = "sentence-transformers"
    model_card.data.language = "nl"
    tags = model_card.data.tags or []
    if "transformers" not in tags:
        tags.append("transformers")
    model_card.data.tags = tags

    content = model_card.content
    match = re.search(ST_SNIPPET_PATTERN, content)
    if match:
        queries = [match.group(2), match.group(3)]
        documents = [match.group(4), match.group(5)]
        # Check every query (not only the first) for the expected prefix and
        # pause for manual inspection when one is missing.
        if not all(query.startswith(QUERY_PREFIX) for query in queries):
            print("Unexpected query format in model card for", model_id)
            breakpoint()

        # Strip only a *leading* prefix. Splitting on the prefix and keeping
        # the last piece would truncate any text that happens to contain
        # "query: " or "passage: " somewhere in the middle.
        queries = [query.removeprefix(QUERY_PREFIX).strip() for query in queries]
        documents = [doc.removeprefix(DOCUMENT_PREFIX).strip() for doc in documents]

        # Run the new snippet's calls so that the printed shapes and
        # similarity scores shown in the model card are real outputs.
        query_embeddings = model.encode_query(queries, normalize_embeddings=True)
        doc_embeddings = model.encode_document(documents, normalize_embeddings=True)
        shapes = f"# {query_embeddings.shape} {doc_embeddings.shape}"
        similarities = model.similarity(query_embeddings, doc_embeddings)
        # Prefix every line of the printed result with "# " so it reads as a
        # comment inside the snippet.
        similarities = "# " + str(similarities).replace("\n", "\n# ")

        # Splice the rendered template over the matched span only.
        new_snippet = ST_SNIPPET_TEMPLATE.format(
            model_name=model_id,
            query1=repr(queries[0]),
            query2=repr(queries[1]),
            document1=repr(documents[0]),
            document2=repr(documents[1]),
            shapes=shapes,
            similarities=similarities,
        )
        content = content[:match.start()] + new_snippet + content[match.end():]
    else:
        # No snippet found: pause for manual inspection. If execution is
        # resumed, the card text is pushed unchanged (metadata still updated).
        print("No match found in model card for", model_id)
        breakpoint()

    model_card.content = content
    model_card.validate("model")
    # NOTE(review): relies on a private SentenceTransformer attribute;
    # presumably this text is written out as the README on push - confirm.
    model._model_card_text = str(model_card)

    # Alternative for a dry run against a private copy instead of a PR:
    # model.push_to_hub(model_id.replace("clips/", "tomaarsen/"), private=True)
    url = model.push_to_hub(model_id, create_pr=True)
    print("Pushed", model_id, "->", url)
    # Pause after every push so the result can be reviewed before moving on.
    breakpoint()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment