Last active
December 13, 2024 20:42
-
-
Save cpsievert/0a0830b80a7bb4f24dcc1d6033ab387c to your computer and use it in GitHub Desktop.
Simple RAG with chatlas
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Minimal retrieval-augmented generation (RAG) demo using chatlas.

Embeds a small document corpus with sentence-transformers, ranks the
documents against a user query by cosine similarity, and hands the
top-k matches to an Anthropic chat model as grounding context.
"""
import numpy as np
from chatlas import ChatAnthropic
from sentence_transformers import SentenceTransformer

# Initialize the SentenceTransformer model and chat client
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
chat = ChatAnthropic(
    system_prompt="""
    You are a helpful AI assistant. Using the provided context,
    answer the user's question. If you cannot answer the question based on the
    context, say so.
    """
)

# Our list of documents (one document per list element)
documents = [
    "The Python programming language was created by Guido van Rossum.",
    "Python is known for its simple, readable syntax.",
    "Python supports multiple programming paradigms.",
]

# Compute embeddings for all documents in one batched call (encoding the
# whole corpus at once is faster than one encode() call per document) and
# stack them into a 2-D array so the similarity math below is explicit.
embeddings = np.asarray(embed_model.encode(documents))

# Compute embedding for the user query
user_query = "Who created Python?"
query_embedding = embed_model.encode([user_query])[0]

# Cosine similarity between the query and each document:
# dot(doc, query) / (||doc|| * ||query||), vectorized over all rows.
similarities = embeddings @ query_embedding / (
    np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query_embedding)
)

# Indices of the top-k most similar documents, best match first.
top_k = 3
top_indices = np.argsort(similarities)[-top_k:][::-1]

# BUG FIX: the original passed only documents[top_indices[0]] — the single
# best match — even though it retrieved top_k documents. Join every
# retrieved document into the context so the model sees all of them.
context = "\n".join(documents[i] for i in top_indices)

# Give the model both the context and the question
chat.chat(f"Context:\n{context}\n\nQuestion: {user_query}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment