Skip to content

Instantly share code, notes, and snippets.

@rgstephens
Created September 5, 2024 22:35
Show Gist options
  • Save rgstephens/e28eba07b9f638e25a205957cae346bf to your computer and use it in GitHub Desktop.
Save rgstephens/e28eba07b9f638e25a205957cae346bf to your computer and use it in GitHub Desktop.
Rasa Custom Information Retriever for Kapa
from typing import TYPE_CHECKING, Any, List, Text, Dict, Optional
from dataclasses import dataclass
import aiohttp
import urllib.parse
import structlog
from rasa.utils.endpoints import EndpointConfig
from rasa.core.information_retrieval import (
SearchResultList,
InformationRetrieval,
InformationRetrievalException,
)
if TYPE_CHECKING:
# from langchain.schema import Document
from langchain.schema.embeddings import Embeddings
structlogger = structlog.get_logger()
@dataclass
class SearchResult:
    """A single retrieved answer together with its source metadata.

    Declared as a dataclass so that instances get a readable ``repr`` (they
    are passed to debug logging in this file) and field-wise equality, and so
    that the class is declared the same way as ``SearchResultList``.

    Note: like any dataclass with the default ``eq=True``, instances are not
    hashable.
    """

    # Answer text returned for the query.
    text: str
    # Free-form information about where the answer came from.
    metadata: Dict
    # Optional relevance score; None when no score is available.
    score: Optional[float] = None
@dataclass
class SearchResultList:
    """Container for the results of one search plus search-level metadata.

    NOTE(review): this definition rebinds the ``SearchResultList`` name that
    is imported from ``rasa.core.information_retrieval`` at the top of the
    file, so code below this point uses this local class - confirm that is
    intended.
    """

    # Individual search results.
    results: List[SearchResult]
    # Metadata that applies to the search as a whole.
    metadata: Dict
class KapaInformationRetrievalException(InformationRetrievalException):
    """Exception raised for errors in the Kapa retriever.

    Args:
        message: Kapa-specific description of what went wrong.
    """

    def __init__(self, message: str) -> None:
        self.message = message
        super().__init__()

    def __str__(self) -> str:
        """Return the base message, the Kapa message and, if chained, the cause."""
        # ``base_message`` is not defined in this class; presumably it is set
        # by InformationRetrievalException.__init__ - verify against rasa.
        text = self.base_message + self.message
        # Only mention the cause when the exception was raised with ``from``;
        # unconditionally formatting ``__cause__`` appended the literal text
        # "None" to every error raised without a chained exception.
        if self.__cause__ is not None:
            text += f" (caused by: {self.__cause__})"
        return text
class Kapa(InformationRetrieval):
    """Information retriever that forwards queries to a Kapa HTTP endpoint.

    The endpoint URL and API token are taken from the ``EndpointConfig``
    passed to ``connect``. Each search issues one GET request and wraps the
    returned answer in a ``SearchResultList`` holding a single result.
    """

    def __init__(
        self,
        embeddings: "Embeddings",
    ):
        """Store the embeddings object; URL and token are set by ``connect``.

        Args:
            embeddings: Embeddings instance; stored but not otherwise used
                by this class.
        """
        structlogger.debug("kapa.__init__")
        self.embeddings = embeddings
        self.token: Optional[Text] = None
        self.url: Optional[Text] = None

    def connect(self, config: EndpointConfig) -> None:
        """Read the Kapa URL and API token from the endpoint configuration.

        Args:
            config: Endpoint configuration providing ``url`` and ``token``.

        Raises:
            KapaInformationRetrievalException: If the URL or the token is
                missing or empty.
        """
        token = config.token
        # Log only the last four characters of the secret. The previous slice
        # ([:-4]) logged everything *except* the last four characters, and it
        # raised TypeError on a missing token before the validation below ran.
        # The raw config is no longer logged because it contains the token.
        masked_token = f"*****{token[-4:]}" if token else None
        structlogger.debug("kapa.connect", url=config.url, token=masked_token)

        self.token = token
        self.url = config.url
        if not self.url:
            raise KapaInformationRetrievalException(
                f"Kapa URL not set. config: {config}"
            )
        if not self.token:
            raise KapaInformationRetrievalException(
                "Kapa token not set."
            )
        structlogger.debug("kapa.connect", url=self.url)

    def kapa_to_searchresult(self, kapa_json: dict) -> SearchResultList:
        """Convert a Kapa JSON response into a ``SearchResultList``.

        Args:
            kapa_json: Decoded response body. The ``answer`` and
                ``relevant_sources`` keys are read; both may be absent.

        Returns:
            A list with exactly one ``SearchResult`` whose text is the answer
            and whose metadata is the first relevant source (or ``{}`` when
            there is none).
        """
        structlogger.debug(
            "kapa.kapa_to_searchresult", kapa_json=kapa_json, type=type(kapa_json)
        )
        answer_text = kapa_json.get("answer")
        # A missing or empty ``relevant_sources`` yields empty metadata rather
        # than a KeyError.
        relevant_sources = kapa_json.get("relevant_sources") or []
        metadata = relevant_sources[0] if relevant_sources else {}
        structlogger.debug(
            "kapa.kapa_to_searchresult", answer_text=answer_text, metadata=metadata
        )
        search_result = SearchResult(text=answer_text, metadata=metadata)
        return SearchResultList(results=[search_result], metadata={})

    async def kapa_query(self, query_string: str) -> Dict[Text, Any]:
        """Send ``query_string`` to the Kapa endpoint and return the JSON body.

        Args:
            query_string: User query; URL-quoted and sent as the ``query``
                parameter.

        Returns:
            The decoded JSON response on HTTP 200. For any other status a
            dict with a single ``"error"`` key describing the failure.

        Raises:
            KapaInformationRetrievalException: If the request or the JSON
                decoding fails.
        """
        structlogger.debug("kapa.kapa_query", url=self.url, query_string=query_string)
        url = f"{self.url}?query={urllib.parse.quote(query_string)}"
        headers = {
            'X-API-TOKEN': self.token
        }
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, headers=headers) as response:
                    # Check the status before decoding: an error response is
                    # not guaranteed to carry a JSON body.
                    if response.status != 200:
                        # The request headers are deliberately left out of the
                        # message because they contain the API token.
                        return {
                            "error": (
                                "HTTP Error when calling the knowledge base: "
                                f"{response.status}, url: {self.url}"
                            )
                        }
                    response_json = await response.json()
                    structlogger.debug(
                        "kapa.kapa_query_test",
                        response_json=response_json,
                        response_type=type(response_json),
                    )
                    return response_json
        except Exception as e:
            # Boundary handler: wrap any transport or decoding failure in the
            # retriever's own exception type, without leaking the token.
            raise KapaInformationRetrievalException(
                f"Kapa search failed. Encountered error: {str(e)}, url: {self.url}"
            ) from e

    async def search(
        self, query: Text, tracker_state: dict[Text, Any], threshold: float = 0.0
    ) -> SearchResultList:
        """Search Kapa for an answer to ``query``.

        Args:
            query: The query to search for.
            tracker_state: Accepted for interface compatibility; not used here.
            threshold: Accepted for interface compatibility; not used here.

        Returns:
            SearchResultList: A list containing the single answer returned by
            Kapa.

        Raises:
            KapaInformationRetrievalException: If the request fails or Kapa
                responds with a non-200 status.
        """
        structlogger.debug("kapa.search", query=query)
        response = await self.kapa_query(query)
        structlogger.debug("kapa.search", response=response, type=type(response))
        # ``kapa_query`` reports HTTP errors as an ``{"error": ...}`` dict;
        # surface them as the retriever's exception instead of converting the
        # error payload into a bogus search result.
        if isinstance(response, dict) and "error" in response:
            raise KapaInformationRetrievalException(str(response["error"]))
        result = self.kapa_to_searchresult(response)
        structlogger.debug("kapa.search", result=result)
        return result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment