@mikegc-aws
Created March 28, 2024 00:29
Simple example of how RAG works under the hood: https://www.youtube.com/watch?v=FQhksZ87Ncg
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import boto3\n",
"import json\n",
"import faiss \n",
"import numpy as np\n",
"from jinja2 import Template"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"bedrock_runtime_client = boto3.client(\"bedrock-runtime\", region_name=\"us-west-2\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Facts\n",
"- These facts represent a repository of information. In real systems these would represent a large collection of documents."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"facts = [\n",
" \"The Sun accounts for about 99.86% of the Solar System's mass.\",\n",
" \"Mercury is the smallest planet in our Solar System.\",\n",
" \"Venus is the hottest planet in our Solar System, even hotter than Mercury.\",\n",
" \"Earth's rotation is gradually slowing down, approximately 17 milliseconds per hundred years.\",\n",
" \"The Moon is moving away from Earth at a rate of about 3.8 cm per year.\",\n",
" \"Mars has the largest dust storms in the Solar System, which can last for months.\",\n",
" \"Jupiter has the shortest day in the Solar System, completing one rotation in just under 10 hours.\",\n",
" \"Saturn's rings are primarily composed of ice particles with a smaller amount of rocky debris and dust.\",\n",
" \"Uranus rotates on its side, making it unique among the planets in our Solar System.\",\n",
" \"Neptune has the strongest winds in the Solar System, reaching speeds of up to 1,500 miles per hour.\",\n",
" \"Pluto, now classified as a dwarf planet, was once considered the ninth planet of our Solar System.\",\n",
" \"The Milky Way galaxy is estimated to contain 100–400 billion stars.\",\n",
" \"The observable universe is estimated to contain more than 2 trillion galaxies.\",\n",
" \"The Hubble Space Telescope has been one of the most productive scientific instruments ever built, providing deep insights into the cosmos.\",\n",
" \"Black holes have such strong gravitational pull that not even light can escape from them.\"\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def embed(text):\n",
" \"\"\"\n",
" Embeds a fact using the Amazon titan-embed-text-v1 model.\n",
"\n",
" Args:\n",
" fact (str): The fact to embed.\n",
"\n",
" Returns:\n",
" list: The embedding of the fact.\n",
" \"\"\"\n",
"\n",
" kwargs = {\n",
" \"modelId\": \"amazon.titan-embed-text-v1\",\n",
" \"contentType\": \"application/json\",\n",
" \"accept\": \"*/*\",\n",
" \"body\": json.dumps({\n",
" \"inputText\": text\n",
" })\n",
" }\n",
"\n",
" resp = bedrock_runtime_client.invoke_model(**kwargs)\n",
"\n",
" resp_body = json.loads(resp.get('body').read())\n",
" return resp_body.get('embedding')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Create an array to store the embeddings.\n",
"embeddings_array = np.array([]).reshape(0, 1536)\n",
"\n",
"# Embed each fact.\n",
"for fact in facts:\n",
" embeddings = embed(fact)\n",
" embeddings_array = np.append(embeddings_array, np.array(embeddings).reshape(1, -1), axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Create a vector store (this is just an in memory vector store for this sample)\n",
"index = faiss.IndexFlatL2(1536)"
]
},
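{
"cell_type": "markdown",
"metadata": {},
"source": [
"`IndexFlatL2` does exact, brute-force nearest-neighbour search on squared L2 distance. As an illustrative sketch (not part of the original notebook), the same search can be written in plain NumPy:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch of what IndexFlatL2 computes for a single query.\n",
"def flat_l2_search(vectors, query, k):\n",
"    # Squared Euclidean distance from the query to every stored vector.\n",
"    dists = np.sum((vectors - query) ** 2, axis=1)\n",
"    # Indices of the k smallest distances, in ascending order.\n",
"    idx = np.argsort(dists)[:k]\n",
"    return dists[idx], idx"
]
},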
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"15\n"
]
}
],
"source": [
"# Add out vectors to the vector store. And print the number of vectors in the vector store.\n",
"index.add(embeddings_array)\n",
"print(index.ntotal)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Set a question to ask. We also need to get an embedded version of the question. \n",
"query = \"What is the Milky Way?\"\n",
"embedded_query = embed(query)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[11 13 12 14]]\n",
"[[261.49744 396.79645 409.55994 490.46497]]\n"
]
}
],
"source": [
"# Using the vector store, find facts that that are the most similar to our question. Find 4 facts.\n",
"k = 4 \n",
"D, I = index.search(np.array([embedded_query]), k) \n",
"print(I) # <- The indexes of the relevant facts\n",
"print(D) # <- The distances of the relevant facts (or how similar they are)"
]
},
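{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see which facts those indices point at, we can zip the distances and indices together (a small illustrative helper, not in the original):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: map the returned indices back to their facts.\n",
"# With L2 distance, smaller means more similar to the query.\n",
"for dist, idx in zip(D[0], I[0]):\n",
"    print(f\"{dist:10.2f}  {facts[idx]}\")"
]
},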
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"prompt_template = \"\"\"\n",
"\n",
"Human: Given the facts provided in the following list of facts between the <facts> tags, \n",
"find the answer to the question written between the <question> tags.\n",
"\n",
"<facts>\n",
"{%- for fact in facts %}\n",
" - `{{fact}}`{% endfor %}\n",
"</facts>\n",
"\n",
"<question>{{question}}</question>\n",
"\n",
"Provide an answer in full including parts of the question, so the answer can be understood in any context.\n",
"If the facts are not relevant to the question being asked, you must not provide any other opinion and you \n",
"must respond with \"Given the information I have access to, I am unable to answer that question at this time.\"\n",
"\n",
"Just provide the answer, nothing else.\n",
"\n",
"Assistant:\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Use Jinja to fill out our prompt template, adding in all the facts.\n",
"\n",
"data = {\n",
" 'facts': [facts[i] for i in I[0]],\n",
" 'question': query\n",
"}\n",
"\n",
"template = Template(prompt_template)\n",
"prompt = template.render(data)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Human: Given the facts provided in the following list of facts between the <facts> tags, \n",
"find the answer to the question written between the <question> tags.\n",
"\n",
"<facts>\n",
" - `The Milky Way galaxy is estimated to contain 100–400 billion stars.`\n",
" - `The Hubble Space Telescope has been one of the most productive scientific instruments ever built, providing deep insights into the cosmos.`\n",
" - `The observable universe is estimated to contain more than 2 trillion galaxies.`\n",
" - `Black holes have such strong gravitational pull that not even light can escape from them.`\n",
"</facts>\n",
"\n",
"<question>What is the Milky Way?</question>\n",
"\n",
"Provide an answer in full including parts of the question, so the answer can be understood in any context.\n",
"If the facts are not relevant to the question being asked, you must not provide any other opinion and you \n",
"must respond with \"Given the information I have access to, I am unable to answer that question at this time.\"\n",
"\n",
"Just provide the answer, nothing else.\n",
"\n",
"Assistant:\n"
]
}
],
"source": [
"# Preview the prompt\n",
"print(prompt)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Send the prompt to a LLM to figure out the answer to our question."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"kwargs = {\n",
" \"modelId\": \"anthropic.claude-v2\",\n",
" \"contentType\": \"application/json\",\n",
" \"accept\": \"*/*\",\n",
" \"body\": json.dumps(\n",
" {\n",
" \"prompt\": prompt, \n",
" \"max_tokens_to_sample\": 300,\n",
" \"temperature\": 0,\n",
" \"top_k\": 250,\n",
" \"top_p\": 0.999,\n",
" \"stop_sequences\": [\n",
" \"\\n\\nHuman:\"\n",
" ],\n",
" \"anthropic_version\": \"bedrock-2023-05-31\"\n",
" }\n",
" )\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" The Milky Way is a galaxy estimated to contain 100–400 billion stars.\n"
]
}
],
"source": [
"response = bedrock_runtime_client.invoke_model(**kwargs)\n",
"response_body = json.loads(response.get('body').read())\n",
"\n",
"generation = response_body['completion']\n",
"print(generation)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python-dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}