Created
September 5, 2024 22:35
-
-
Save rgstephens/e28eba07b9f638e25a205957cae346bf to your computer and use it in GitHub Desktop.
Rasa Custom Information Retriever for Kapa
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import TYPE_CHECKING, Any, List, Text, Dict, Optional | |
from dataclasses import dataclass | |
import aiohttp | |
import urllib.parse | |
import structlog | |
from rasa.utils.endpoints import EndpointConfig | |
from rasa.core.information_retrieval import ( | |
SearchResultList, | |
InformationRetrieval, | |
InformationRetrievalException, | |
) | |
if TYPE_CHECKING: | |
# from langchain.schema import Document | |
from langchain.schema.embeddings import Embeddings | |
structlogger = structlog.get_logger() | |
class SearchResult:
    """One retrieved passage: its text, a metadata mapping and an optional score."""

    def __init__(self, text: str, metadata: Dict, score: Optional[float] = None):
        """Store the given values on the instance without copying or validating them.

        Args:
            text: The passage or answer text.
            metadata: Arbitrary mapping describing where the text came from.
            score: Optional relevance score; ``None`` when no score is available.
        """
        self.text, self.metadata, self.score = text, metadata, score
@dataclass
class SearchResultList:
    """Bundle of search results plus metadata that applies to the whole list.

    NOTE(review): this definition rebinds the name ``SearchResultList`` that the
    module also imports from ``rasa.core.information_retrieval``; from this point
    on the local dataclass is the one in scope. Confirm that is intentional.
    """

    # Individual hits, in the order they should be presented.
    results: List[SearchResult]
    # Free-form information about the result set as a whole.
    metadata: Dict
class KapaInformationRetrievalException(InformationRetrievalException):
    """Error raised when configuring or querying Kapa fails."""

    def __init__(self, message: str) -> None:
        """Remember the Kapa-specific message and initialise the base exception.

        Args:
            message: Human-readable description of what went wrong.
        """
        self.message = message
        super().__init__()

    def __str__(self) -> str:
        """Return the base message followed by the Kapa message and, if any, the cause.

        ``base_message`` is not set in this class; it is presumably provided by
        the base class initialiser (verify against the Rasa version in use).
        """
        text = self.base_message + self.message
        # Only append the chained exception when one exists; otherwise errors
        # raised without ``from`` would end in the literal text "None".
        if self.__cause__ is not None:
            text += f"{self.__cause__}"
        return text
class Kapa(InformationRetrieval):
    """Information retriever that forwards queries to a Kapa HTTP endpoint.

    The endpoint URL and API token come from the ``EndpointConfig`` passed to
    ``connect``. Each search issues one GET request and wraps the returned
    answer in a single-item ``SearchResultList``.
    """

    def __init__(
        self,
        embeddings: "Embeddings",
    ):
        """Store the embeddings object; URL and token stay unset until ``connect``.

        Args:
            embeddings: Embeddings instance supplied by the caller. It is only
                stored here; no method of this class reads it.
        """
        structlogger.debug("kapa.__init__")
        self.embeddings = embeddings
        self.token = None
        self.url = None

    @staticmethod
    def _mask_token(token: Any) -> str:
        """Return a log-safe rendering of ``token`` that reveals at most its last 4 characters."""
        if not token:
            return "<not set>"
        text = str(token)
        # Show a short tail only when the token is long enough that the tail
        # does not give away a meaningful share of the secret.
        tail = text[-4:] if len(text) > 8 else ""
        return f"*****{tail}"

    def connect(self, config: EndpointConfig) -> None:
        """Read the Kapa URL and API token from ``config`` and validate them.

        Args:
            config: Endpoint configuration; its ``url`` and ``token`` attributes
                are used.

        Raises:
            KapaInformationRetrievalException: If the URL or the token is missing.
        """
        masked_token = self._mask_token(config.token)
        # Log a copy of the config with the token masked so the secret never
        # reaches the logs. NOTE(review): other attributes of the config may be
        # sensitive as well - confirm before enabling debug logs in production.
        config_dict = dict(vars(config))
        if "token" in config_dict:
            config_dict["token"] = masked_token
        structlogger.debug(
            "kapa.connect", config=config_dict, url=config.url, token=masked_token
        )

        self.token = config.token
        self.url = config.url

        if not self.url:
            raise KapaInformationRetrievalException(
                f"Kapa URL not set. config: {config}"
            )
        if not self.token:
            raise KapaInformationRetrievalException(
                "Kapa token not set."
            )
        structlogger.debug("kapa.connect", url=self.url)

    def kapa_to_searchresult(self, kapa_json: dict) -> SearchResultList:
        """Convert a Kapa JSON response into a one-element ``SearchResultList``.

        Args:
            kapa_json: Parsed response body. The ``"answer"`` value becomes the
                result text and the first entry of ``"relevant_sources"`` (if
                any) becomes the result metadata.

        Returns:
            A ``SearchResultList`` holding exactly one ``SearchResult`` and an
            empty list-level metadata dict.
        """
        structlogger.debug(
            "kapa.kapa_to_searchresult", kapa_json=kapa_json, type=type(kapa_json)
        )
        answer_text = kapa_json.get("answer")

        # A missing or empty source list yields empty metadata instead of a KeyError.
        sources = kapa_json.get("relevant_sources")
        metadata = sources[0] if sources else {}
        structlogger.debug(
            "kapa.kapa_to_searchresult", answer_text=answer_text, metadata=metadata
        )

        search_result = SearchResult(text=answer_text, metadata=metadata)
        return SearchResultList(results=[search_result], metadata={})

    async def kapa_query(self, query_string: str):
        """Send ``query_string`` to the Kapa endpoint and return the parsed JSON body.

        Args:
            query_string: The user query; it is URL-quoted and sent as the
                ``query`` parameter.

        Returns:
            The decoded JSON response on HTTP 200. For any other status, a dict
            with a single ``"error"`` key describing the failure.

        Raises:
            KapaInformationRetrievalException: If the request or the JSON
                decoding fails.
        """
        structlogger.debug("kapa.kapa_query", url=self.url, query_string=query_string)
        url = f"{self.url}?query={urllib.parse.quote(query_string)}"
        headers = {
            'X-API-TOKEN': self.token
        }
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, headers=headers) as response:
                    # Check the status before touching the body: error pages
                    # are not guaranteed to be JSON.
                    if response.status != 200:
                        # The request headers are deliberately left out of the
                        # message because they carry the API token.
                        return {
                            "error": (
                                "HTTP Error when calling the knowledge base: "
                                f"{response.status}, url: {self.url}, "
                                f"reason: {response.reason}"
                            )
                        }
                    response_json = await response.json()
                    structlogger.debug(
                        "kapa.kapa_query_test",
                        response_json=response_json,
                        response_type=type(response_json),
                    )
                    return response_json
        except Exception as e:
            # Headers are omitted from the message so the token is not leaked.
            raise KapaInformationRetrievalException(
                f"Kapa search failed. Encountered error: {str(e)}, url: {self.url}"
            ) from e

    async def search(
        self, query: Text, tracker_state: dict[Text, Any], threshold: float = 0.0
    ) -> SearchResultList:
        """Ask Kapa for an answer to ``query`` and return it as a search result list.

        Args:
            query: The query to send to Kapa.
            tracker_state: Accepted as part of the signature; not used here.
            threshold: Accepted as part of the signature; not used here because
                the Kapa response handled by this class carries no score.

        Returns:
            SearchResultList: A list with one ``SearchResult`` whose text is the
            Kapa answer and whose metadata is the first relevant source.

        Raises:
            KapaInformationRetrievalException: If the request fails, Kapa
                answers with an HTTP error, or the response is not a JSON object.
        """
        structlogger.debug("kapa.search", query=query)
        response = await self.kapa_query(query)
        structlogger.debug("kapa.search", response=response, type=type(response))

        if not isinstance(response, dict):
            raise KapaInformationRetrievalException(
                f"Unexpected Kapa response type: {type(response).__name__}"
            )
        # ``kapa_query`` reports HTTP failures as {"error": ...}; surface them as
        # a retrieval exception instead of failing later on missing keys.
        if "error" in response and "answer" not in response:
            raise KapaInformationRetrievalException(str(response["error"]))

        result = self.kapa_to_searchresult(response)
        structlogger.debug("kapa.search", result=result)
        return result
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment