Skip to content

Instantly share code, notes, and snippets.

@nan-wang
Created October 22, 2024 09:01
Show Gist options
  • Save nan-wang/f0bdd6ed3ef3177ce524c0f6cc615be8 to your computer and use it in GitHub Desktop.
Save nan-wang/f0bdd6ed3ef3177ce524c0f6cc615be8 to your computer and use it in GitHub Desktop.
A demo of using the classify API from Jina AI for semantic routing
import time
from typing import Any, Optional, Dict, List
import requests
from langchain.chains.router.llm_router import RouterChain
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from pydantic import SecretStr, model_validator
# Jina AI classification endpoint. HTTPS is required: requests to it carry the
# API key as a Bearer token, which must never be sent over cleartext HTTP.
JINA_API_URL: str = "https://api.jina.ai/v1/classify"
class JinaClassifierRouterChain(RouterChain):
    """Router chain that picks a destination using Jina AI's classify API.

    The values of ``routing_keys`` in the chain input are joined into a single
    text. That text is sent to the classify endpoint together with the
    candidate labels, and the predicted label is returned as the routing
    ``destination``.
    """

    # requests.Session built by ``validate_model``; carries the auth headers.
    session: Any
    # Model identifier forwarded as the ``model`` field of the classify request.
    model_name: str
    # Input keys whose values are joined (", "-separated) into the text to classify.
    routing_keys: List[str] = ["query"]
    # API key; when unset, resolved from the environment in ``validate_model``.
    jina_api_key: Optional[SecretStr] = None
    # Default candidate labels; a ``labels`` entry in the call inputs overrides them.
    labels: Optional[List[str]] = None
    # Seconds to wait for the classify endpoint (None waits indefinitely).
    request_timeout: Optional[float] = 60.0

    @property
    def input_keys(self) -> List[str]:
        """Input keys whose values form the text sent for classification.

        :meta private:
        """
        return self.routing_keys

    @model_validator(mode="before")
    @classmethod
    def validate_model(cls, values: Dict) -> Any:
        """Resolve the API key and attach an authenticated HTTP session.

        The key is looked up via ``jina_api_key`` / ``JINA_API_KEY`` first,
        then via ``jina_auth_token`` / ``JINA_AUTH_TOKEN``.

        Raises:
            ValueError: If neither lookup finds a key; the error from the
                first lookup is re-raised.
        """
        try:
            jina_api_key = convert_to_secret_str(
                get_from_dict_or_env(values, "jina_api_key", "JINA_API_KEY")
            )
        except ValueError as original_exc:
            try:
                jina_api_key = convert_to_secret_str(
                    get_from_dict_or_env(values, "jina_auth_token", "JINA_AUTH_TOKEN")
                )
            except ValueError:
                # Report the primary lookup failure without the fallback's noise.
                raise original_exc from None
        session = requests.Session()
        session.headers.update(
            {
                "Authorization": f"Bearer {jina_api_key.get_secret_value()}",
                "Accept-Encoding": "identity",
                "Content-type": "application/json",
            }
        )
        values["session"] = session
        return values

    def _classify(self, texts: List[str], labels: List[str]) -> List[str]:
        """Classify each text against ``labels`` with the Jina classify API.

        Args:
            texts: Texts to classify.
            labels: Candidate labels to choose from.

        Returns:
            The predicted label for each text, in the same order as ``texts``.

        Raises:
            ValueError: If ``labels`` is None or empty.
            RuntimeError: If the API answers with a non-JSON body or with a
                payload that has no ``data`` field.
        """
        if not labels:
            # Fail locally with a clear message instead of an opaque API error.
            raise ValueError(
                "No labels to classify against: set `labels` on the chain "
                "or pass a `labels` entry in the inputs."
            )
        response = self.session.post(  # type: ignore
            JINA_API_URL,
            json={"input": texts, "labels": labels, "model": self.model_name},
            timeout=self.request_timeout,
        )
        try:
            payload = response.json()
        except ValueError as exc:
            # e.g. an HTML gateway error page; surface status and a snippet.
            raise RuntimeError(
                f"Jina classify API returned a non-JSON response "
                f"(HTTP {response.status_code}): {response.text[:200]}"
            ) from exc
        if "data" not in payload:
            # Error payloads are expected to carry "detail"; fall back to the
            # whole payload so the error is not masked by a KeyError.
            raise RuntimeError(payload.get("detail", payload))
        # Each result carries an "index"; sort by it so predictions line up
        # with ``texts``.
        ordered = sorted(payload["data"], key=lambda item: item["index"])  # type: ignore
        # Return just the predicted labels.
        return [item["prediction"] for item in ordered]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Route ``inputs`` to the label predicted for their routing-key text.

        A ``labels`` entry in ``inputs`` overrides the chain-level ``labels``.

        Returns:
            ``{"next_inputs": inputs, "destination": <predicted label>}``.
        """
        text = ", ".join(inputs[k] for k in self.routing_keys)
        labels = inputs.get("labels", self.labels)
        predictions = self._classify([text], labels)
        return {"next_inputs": inputs, "destination": predictions[0]}
class PromptFactory:
    """Container for the destination prompt templates used in routing.

    ``prompt_infos`` holds one entry per destination. Each entry has a
    ``name`` (used as a candidate label), a human-readable ``description``
    and the ``prompt_template`` text with an ``{input}`` placeholder.
    """

    physics_template = (
        "You are a very smart physics professor. "
        "You are great at answering questions about physics in a concise and easy to understand manner. "
        "When you don't know the answer to a question you admit that you don't know.\n"
        "Here is a question:\n"
        "{input}"
    )
    math_template = (
        "You are a very good mathematician. You are great at answering math questions. "
        "You are so good because you are able to break down hard problems into their component parts, "
        "answer the component parts, and then put them together to answer the broader question.\n"
        "Here is a question:\n"
        "{input}"
    )
    prompt_infos = [
        dict(
            name="physics",
            description="Good for questions about physics",
            prompt_template=physics_template,
        ),
        dict(
            name="math",
            description="Good for questions about math",
            prompt_template=math_template,
        ),
    ]
if __name__ == "__main__":
    # Put your API key here or define it in your environment, e.g.:
    # import os; os.environ["JINA_API_KEY"] = "<key>"
    # Questions from https://github.com/langchain-ai/rag-from-scratch/blob/main/rag_from_scratch_10_and_11.ipynb
    questions = [
        # Physics Questions
        "How do potential energy and kinetic energy relate to the total mechanical energy of an object?",
        "What is the principle of conservation of energy, and how does it apply to a roller coaster ride?",
        "How do you calculate the work done on an object when a force is applied over a distance?",
        "What are the differences between transverse and longitudinal waves, and what are some examples of each?",
        "How does the Doppler effect explain the change in pitch of a sound as its source moves toward or away from you?",
        # Math Questions
        "How do you find the slope of a line given two points on the line?",
        "What is the Pythagorean Theorem, and how is it used to find the length of the sides of a right triangle?",
        "What is the difference between permutations and combinations, and when would you use each?",
        "How do you calculate the area and circumference of a circle, and what are the key formulas?",
        "What are the properties of similar triangles, and how can you use them to solve problems involving proportions?",
    ]
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # with system clock adjustments and is unsuitable for measuring intervals.
    beg = time.perf_counter()
    prompt_factory = PromptFactory()
    chain = JinaClassifierRouterChain(
        model_name="jina-embeddings-v3",
        # Destination names double as the classifier's candidate labels.
        labels=[info["name"] for info in prompt_factory.prompt_infos],
    )
    for q in questions:
        # A bare string is mapped onto the chain's single routing key ("query").
        result = chain.invoke(q)
        print(result["destination"])
    end = time.perf_counter()
    print(f"time elapsed: {end - beg}s")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment