Last active
June 3, 2023 03:37
-
-
Save j40903272/51b0d4e79580aa2eb9f8dac24799dc77 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re | |
import json | |
import logging | |
import threading | |
from typing import Union | |
from abc import ABC, abstractmethod | |
from langchain.chains.conversation.memory import ( | |
ConversationBufferMemory, | |
# ConversationBufferWindowMemory, | |
# ConversationSummaryBufferMemory | |
) | |
from langchain.chat_models import ChatOpenAI | |
from langchain.agents.tools import Tool | |
from langchain.agents.conversational.output_parser import ConvoOutputParser | |
from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS | |
from langchain.agents.conversational_chat.prompt import ( | |
FORMAT_INSTRUCTIONS as CHAT_FORMAT_INSTRUCTIONS, | |
) | |
from langchain.schema import AgentAction, AgentFinish | |
from langchain.agents import AgentOutputParser | |
import functools | |
from langchain.chains import LLMChain | |
from langchain.chat_models import ChatOpenAI | |
from langchain.memory import ConversationBufferMemory | |
from langchain.prompts.chat import ( | |
PromptTemplate, | |
ChatPromptTemplate, | |
HumanMessagePromptTemplate, | |
) | |
import pandas as pd | |
class RobustRetryConvoOutputParser(AgentOutputParser):
    """Parse an LLM reply into an ``AgentAction`` or ``AgentFinish``.

    Several extraction strategies are tried in turn (trailing JSON, fenced
    markdown code, fence stripping, raw text) before the reply is treated as
    either a plain final answer or a request to retry.

    NOTE(review): the error path uses a module-level ``logger`` that is not
    defined in the visible part of this file -- confirm it exists.
    """

    def get_format_instructions(self) -> str:
        """Return the chat-agent format instructions."""
        return CHAT_FORMAT_INSTRUCTIONS

    def parse_json(self, text: str):
        """Return the JSON-looking tail of *text*, or ``None``.

        The tail starts at the first ``{`` or ``[`` that is preceded by
        whitespace (or sits at the very start of *text*) and must end with
        ``}`` or ``]``.  Trailing whitespace after the closing bracket is
        ignored.  The returned text is not validated as JSON.
        """
        # The leading space lets the ``\s`` in the pattern match a bracket at
        # the very start; rstrip() lets ``$`` line up with the closing bracket
        # even when the reply ends in spaces or blank lines.
        padded = f" {text}".rstrip()
        matches = re.findall(r"\s([{\[].*?[}\]])$", padded, flags=re.DOTALL)
        return next((m for m in matches if m), None)

    def parse_markdown_code(self, text: str):
        """Return the JSON tail of the first fenced code block that has one.

        Every triple-backtick block in *text* is inspected in order (not just
        the first one).  Returns ``None`` when no block yields a JSON tail or
        when *text* is not a string.
        """
        if not isinstance(text, str):
            return None
        fenced_blocks = re.findall(
            r"`{3}([\w]*)\n([\S\s]+?)\n`{3}", text, flags=re.DOTALL
        )
        # Each match is (language tag, body); only the body can hold JSON.
        for _lang, body in fenced_blocks:
            candidate = self.parse_json(body)
            if candidate:
                return candidate
        return None

    def parse_origin(self, text: str):
        """Strip markdown fences from *text* and return what is left.

        Keeps the text after the first ```` ```json ```` tag (or after a
        leading bare fence) up to the next fence.  Returns ``None`` when
        *text* is not a string.
        """
        if not isinstance(text, str):
            return None
        cleaned = text.strip()
        if "```json" in cleaned:
            # partition() never raises, unlike two-target unpacking of
            # split(), which fails as soon as the tag occurs twice.
            cleaned = cleaned.partition("```json")[2]
        elif cleaned.startswith("```"):
            cleaned = cleaned[len("```"):]
        # Cut at the next fence; the text is unchanged if there is none.
        cleaned = cleaned.partition("```")[0]
        return cleaned.strip()

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Convert the raw LLM reply *text* into an agent step.

        Returns:
            ``AgentFinish`` when the decoded ``action`` is ``"Final Answer"``
            or when the reply cannot be decoded and does not mention
            ``action`` at all (the whole reply becomes the output);
            otherwise ``AgentAction``.  An undecodable reply that does
            mention ``action`` yields ``AgentAction("RR", "", text)``.
        """
        # Bound before the try block so the error path can always log it.
        cleaned_output = text
        try:
            # First strategy that produces something wins; the raw stripped
            # text is the last resort.
            cleaned_output = (
                self.parse_json(text)
                or self.parse_markdown_code(text)
                or self.parse_origin(text)
                or text.strip()
            )
            response = json.loads(cleaned_output)
            action, action_input = response["action"], response["action_input"]
            if action == "Final Answer":
                return AgentFinish({"output": action_input}, text)
            return AgentAction(action, action_input, text)
        except Exception:
            # Deliberate best-effort boundary: any decoding failure is turned
            # into an agent step instead of propagating.
            # "action_input" contains "action", so one membership test covers
            # both keys.
            if "action" not in text:
                return AgentFinish({"output": text}, text)
            logger.warning("Not follow format instruction\n%s", cleaned_output)
            # NOTE(review): "RR" is presumably the name of the retry tool
            # registered by the bot -- confirm against the tool definition.
            return AgentAction("RR", "", text)
class ConversationBot(ABC):
    """Base class that wires an LLM, buffer memory, tools and an agent together.

    NOTE(review): ``seed_everything``, ``retryActionTool``, ``initialize_agent``,
    ``logger`` and the ``AUTOML_CHATGPT_*`` constants are not defined or
    imported in the visible part of this file -- confirm they are provided
    elsewhere.
    """

    def __init__(self):
        """Build the LLM, memory, tool list, output parser and agent."""
        seed_everything(0)
        # self.handler = LoggerCallbackHandler()
        # self.async_handler = AsyncLoggerCallbackHandler()
        self.llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
        self.memory = ConversationBufferMemory(
            memory_key="chat_history", output_key="output", return_messages=True
        )
        # List of (user text, agent response) tuples, one per completed turn.
        self.state = []
        self.tools = self.load_tools()
        # retry when llm output parsing error
        self.tools.append(self._as_tool(retryActionTool()))
        self.output_parser = RobustRetryConvoOutputParser()
        self.agent = self.init_agent()

    @staticmethod
    def _as_tool(provider):
        """Wrap ``provider.inference`` in a ``Tool``.

        The tool name and description are read from the ``name`` and
        ``description`` attributes of the ``inference`` callable itself.
        """
        func = provider.inference
        return Tool(
            name=func.name,
            description=func.description,
            # coroutine=to_async(func),
            func=func,
        )

    def _record_turn(self, text: str, res):
        """Store the finished turn in ``self.state``, log it, return the output."""
        response = res["output"]
        self.state.append((text, response))
        logger.info(
            "\nProcessed run_text, Input text: %s\nCurrent state: %s\n",
            text,
            self.state,
        )
        return response

    def run_text(self, text: str):
        """Run the agent synchronously on *text* and return its ``output``."""
        logger.info("User: %s", text)
        return self._record_turn(text, self.agent({"input": text}))

    async def arun_text(self, text: str):
        """Run the agent via ``acall`` on *text* and return its ``output``."""
        logger.info("User: %s", text)
        return self._record_turn(text, await self.agent.acall({"input": text}))

    def _clear(self):
        """Clear the conversation memory."""
        # The attribute was previously referenced without being called, so
        # the memory was never actually cleared.
        self.memory.clear()

    def init_agent(self):
        """Create the agent from the current tools, LLM, memory and parser."""
        input_variables = ["input", "agent_scratchpad", "chat_history", "dataset"]
        return initialize_agent(
            self.tools,
            self.llm,
            agent="chat-conversational-react-description",
            verbose=True,
            memory=self.memory,
            return_intermediate_steps=True,
            max_iterations=5,
            # max_execution_time=60,
            agent_kwargs={
                "prefix": AUTOML_CHATGPT_PREFIX,
                "format_instructions": AUTOML_CHATGPT_FORMAT_INSTRUCTIONS,
                "suffix": AUTOML_CHATGPT_SUFFIX,
                "input_variables": input_variables,
                "output_parser": self.output_parser,
            },
        )

    def load_tools(self):
        """Call ``load_artifacts`` and return the list of ``Tool`` objects.

        The provider list is empty here, so the returned list is empty.
        """
        self.load_artifacts()
        providers = []
        tools = [self._as_tool(provider) for provider in providers]
        logger.info("tools: %s", [tool.name for tool in tools])
        return tools

    def load_artifacts(self):
        """Do nothing and return ``None``; called at the start of ``load_tools``."""
        return None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment