Created
October 22, 2024 09:01
-
-
Save nan-wang/f0bdd6ed3ef3177ce524c0f6cc615be8 to your computer and use it in GitHub Desktop.
A demo of using the classify API from Jina AI for semantic routing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
from typing import Any, Optional, Dict, List | |
import requests | |
from langchain.chains.router.llm_router import RouterChain | |
from langchain_core.callbacks import CallbackManagerForChainRun | |
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env | |
from pydantic import SecretStr, model_validator | |
# Endpoint of the Jina AI classify API. HTTPS is used so that the Bearer
# token placed in the Authorization header is never sent in clear text.
JINA_API_URL: str = "https://api.jina.ai/v1/classify"
class JinaClassifierRouterChain(RouterChain):
    """Router chain that picks a destination via the Jina AI classify API.

    The input values found under ``routing_keys`` are joined into a single
    text, posted to ``JINA_API_URL`` together with the candidate labels, and
    the predicted label for that text is returned as the destination.
    """

    # HTTP session carrying the auth headers; populated by ``validate_model``.
    session: Any
    # Model name sent in the ``"model"`` field of the request payload.
    model_name: str
    # Input keys whose values are joined to form the text to classify.
    routing_keys: List[str] = ["query"]
    # API key; may also be resolved from the environment by ``validate_model``.
    jina_api_key: Optional[SecretStr] = None
    # Default candidate labels, used when the inputs carry no "labels" entry.
    labels: Optional[List[str]] = None

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys this chain reads, i.e. ``routing_keys``.

        :meta private:
        """
        return self.routing_keys

    @model_validator(mode="before")
    @classmethod
    def validate_model(cls, values: Dict) -> Any:
        """Resolve the API key and attach an authenticated HTTP session.

        The key is looked up as ``jina_api_key`` / ``JINA_API_KEY`` first and
        as ``jina_auth_token`` / ``JINA_AUTH_TOKEN`` second.

        Raises:
            ValueError: if neither lookup succeeds; the error of the first
                lookup is the one re-raised.
        """
        try:
            jina_api_key = convert_to_secret_str(
                get_from_dict_or_env(values, "jina_api_key", "JINA_API_KEY")
            )
        except ValueError as original_exc:
            try:
                jina_api_key = convert_to_secret_str(
                    get_from_dict_or_env(values, "jina_auth_token", "JINA_AUTH_TOKEN")
                )
            except ValueError:
                # Report the failure for the primary key name, not the fallback.
                raise original_exc
        session = requests.Session()
        session.headers.update(
            {
                "Authorization": f"Bearer {jina_api_key.get_secret_value()}",
                "Accept-Encoding": "identity",
                "Content-type": "application/json",
            }
        )
        values["session"] = session
        return values

    def _classify(self, texts: List[str], labels: List[str]) -> List[str]:
        """Classify ``texts`` against ``labels`` and return one label per text.

        Returns:
            The ``"prediction"`` of every result entry, ordered by the entry's
            ``"index"``.

        Raises:
            RuntimeError: if the response body has no ``"data"`` entry.
        """
        resp = self.session.post(  # type: ignore
            JINA_API_URL,
            json={"input": texts, "labels": labels, "model": self.model_name},
        ).json()
        if "data" not in resp:
            # Prefer the "detail" message, but fall back to the whole body so
            # that a missing key cannot mask the real failure with a KeyError.
            detail = resp.get("detail", resp) if isinstance(resp, dict) else resp
            raise RuntimeError(detail)
        # Sort the classification results by index.
        sorted_results = sorted(resp["data"], key=lambda e: e["index"])  # type: ignore
        # Return just the predicted labels.
        return [r["prediction"] for r in sorted_results]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Route ``inputs`` to the label predicted for the joined routing text.

        Labels are taken from ``inputs["labels"]`` when present, otherwise
        from ``self.labels``.

        Returns:
            A dict with ``"next_inputs"`` (the unchanged ``inputs``) and
            ``"destination"`` (the predicted label).

        Raises:
            ValueError: if no labels are available from either source.
        """
        _input = ", ".join(inputs[k] for k in self.routing_keys)
        _labels = inputs.get("labels", self.labels)
        if not _labels:
            # Fail early instead of posting a request without candidate labels.
            raise ValueError(
                "No labels to classify against: set `labels` on the chain "
                "or pass a 'labels' entry in the inputs."
            )
        results = self._classify([_input], _labels)
        return {"next_inputs": inputs, "destination": results[0]}
class PromptFactory:
    """Holds the prompt templates and the routing metadata for each topic."""

    # Each template is assembled from adjacent string literals; the text is
    # the same as a single multi-line literal with escaped line breaks.
    physics_template = (
        "You are a very smart physics professor. "
        "You are great at answering questions about physics in a concise and easy to understand manner. "
        "When you don't know the answer to a question you admit that you don't know.\n"
        "Here is a question:\n"
        "{input}"
    )
    math_template = (
        "You are a very good mathematician. You are great at answering math questions. "
        "You are so good because you are able to break down hard problems into their component parts, "
        "answer the component parts, and then put them together to answer the broader question.\n"
        "Here is a question:\n"
        "{input}"
    )

    # One entry per destination: its name, a short description, its template.
    prompt_infos = [
        dict(
            name="physics",
            description="Good for questions about physics",
            prompt_template=physics_template,
        ),
        dict(
            name="math",
            description="Good for questions about math",
            prompt_template=math_template,
        ),
    ]
if __name__ == "__main__":
    # Demo: route a handful of physics and math questions and print the
    # destination predicted for each one, followed by the total elapsed time.
    #
    # Put here your API key or define it in your environment
    # (the line below additionally needs `import os`):
    # os.environ["JINA_API_KEY"] = '<key>'
    # prompts from https://github.com/langchain-ai/rag-from-scratch/blob/main/rag_from_scratch_10_and_11.ipynb
    questions = [
        # Physics Questions
        "How do potential energy and kinetic energy relate to the total mechanical energy of an object?",
        "What is the principle of conservation of energy, and how does it apply to a roller coaster ride?",
        "How do you calculate the work done on an object when a force is applied over a distance?",
        "What are the differences between transverse and longitudinal waves, and what are some examples of each?",
        "How does the Doppler effect explain the change in pitch of a sound as its source moves toward or away from you?",
        # Math Questions
        "How do you find the slope of a line given two points on the line?",
        "What is the Pythagorean Theorem, and how is it used to find the length of the sides of a right triangle?",
        "What is the difference between permutations and combinations, and when would you use each?",
        "How do you calculate the area and circumference of a circle, and what are the key formulas?",
        "What are the properties of similar triangles, and how can you use them to solve problems involving proportions?",
    ]
    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments the way a time.time() difference can.
    beg = time.perf_counter()
    prompt_factory = PromptFactory()
    chain = JinaClassifierRouterChain(
        model_name="jina-embeddings-v3",
        labels=[r["name"] for r in prompt_factory.prompt_infos],
    )
    for q in questions:
        result = chain.invoke(q)
        print(result["destination"])
    end = time.perf_counter()
    print(f"time elapsed: {end - beg}s")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment