Skip to content

Instantly share code, notes, and snippets.

@duarteocarmo
Created November 6, 2024 20:11
Show Gist options
  • Save duarteocarmo/85435ce9209d1824f68e11b6126ab5c9 to your computer and use it in GitHub Desktop.
import streamlit as st
from litellm import completion, stream_chunk_builder
from loguru import logger
import json
import plotly.express as px
from enum import Enum
# Model identifier passed to litellm.completion for every chat turn.
MODEL = "gpt-4o-mini"
# Chat-message roles understood by the OpenAI-style chat API.
Roles = Enum(
    "Roles",
    {
        "SYSTEM": "system",
        "USER": "user",
        "ASSISTANT": "assistant",
        "TOOL": "tool",
    },
)
# Streamlit material-icon avatars shown next to each chat role.
Avatar = Enum(
    "Avatar",
    {
        "USER": ":material/engineering:",
        "ASSISTANT": ":material/cognition:",
        "TOOL": ":material/construction:",
    },
)
# Streamlit reruns this whole script on every interaction, so only create
# the chat history once; later reruns must not clobber existing messages.
if "messages" not in st.session_state:
    st.session_state.messages = []
def create_plot(
    x: list,
    y: list,
    plot_type: str = "line",
    title: str = "Sample Plot",
):
    """Build a Plotly Express figure for the given data.

    Args:
        x: x-axis values (numbers or category labels).
        y: y-axis values.
        plot_type: one of "line", "scatter", or "bar".
        title: figure title.

    Returns:
        The constructed Plotly figure.

    Raises:
        ValueError: if ``plot_type`` is not a supported kind.
    """
    # Dispatch table keeps the supported kinds in one place.
    plotters = {
        "line": px.line,
        "scatter": px.scatter,
        "bar": px.bar,
    }
    if plot_type not in plotters:
        raise ValueError("Invalid plot type. Choose from 'line', 'scatter', or 'bar'.")
    return plotters[plot_type](x=x, y=y, title=title)
# Tool definition advertised to the LLM (OpenAI function-calling schema).
# NOTE: must stay in sync with the create_plot() signature above — the model's
# arguments are splatted directly into that function.
tools = [
    {
        "type": "function",
        "function": {
            "name": "create_plot",
            "description": "Create a plot to visualize data",
            "parameters": {
                "type": "object",
                "properties": {
                    "plot_type": {
                        "type": "string",
                        "enum": ["line", "scatter", "bar"],
                        "description": "The type of plot to create",
                    },
                    "title": {"type": "string", "description": "The title of the plot"},
                    "x": {
                        # x may hold numbers or categorical string labels.
                        "type": "array",
                        "description": "The x-axis data",
                        "items": {"type": ["number", "string"]},
                    },
                    "y": {
                        "type": "array",
                        "description": "The y-axis data",
                        "items": {"type": "number"},
                    },
                },
                # plot_type and title are optional; create_plot supplies defaults.
                "required": ["x", "y"],
            },
        },
    }
]
def stream_response(response):
    """Stream the LLM response and handle function calls.

    Yields text deltas for st.write_stream, buffers any tool-call chunks,
    and — once the stream ends — executes a create_plot tool call if one
    was requested. The final text and plot arguments are stashed in
    st.session_state.current_response for the caller.
    """
    tool_chunks = []
    streamed_text = ""
    for piece in response:
        delta = piece.choices[0].delta
        if delta.tool_calls:
            # Tool-call arguments arrive fragmented; collect chunks to rebuild.
            tool_chunks.append(piece)
        if delta.content:
            streamed_text += delta.content
            yield delta.content
    # Plain text answer: nothing else to do.
    if not tool_chunks:
        st.session_state.current_response = {"content": streamed_text, "plot_data": None}
        return
    plot_args = None
    with st.status("Creating visualization...", expanded=True) as status:
        # Reassemble the fragmented tool call into a complete message.
        rebuilt = stream_chunk_builder(tool_chunks)
        call = rebuilt.choices[0].message.tool_calls[0]
        function_name = call.function.name
        function_args = json.loads(call.function.arguments)
        logger.info(f"Function call: {function_name} with args: {function_args}")
        if function_name == "create_plot":
            st.write(f"Generating {function_args.get('plot_type', 'line')} plot...")
            st.plotly_chart(create_plot(**function_args))
            status.update(label="Visualization created!", state="complete")
            # Keep only the raw arguments so the plot can be recreated later.
            plot_args = function_args
    st.session_state.current_response = {
        "content": streamed_text,
        "plot_data": plot_args,
    }
def generate_response(messages):
    """Request a streaming completion and render it into the UI.

    Args:
        messages: full chat history in OpenAI message-dict format.

    Returns:
        The dict stored by stream_response in
        st.session_state.current_response ("content" and "plot_data").
    """
    logger.info("Generating completion...")
    llm_stream = completion(
        model=MODEL,
        messages=messages,
        tools=tools,
        tool_choice="auto",
        temperature=0.7,
        stream=True,
    )
    # write_stream drains the generator; the side effects in stream_response
    # populate session_state with the final result.
    st.write_stream(stream_response(llm_stream))
    logger.info("Generated completion.")
    return st.session_state.current_response
# Page-level configuration; Streamlit requires this before any other st.* UI call.
st.set_page_config(
    page_title="Chat Assistant with Visualization",
    page_icon="📊",
    layout="wide",
)
# Seed the conversation with the system prompt on first load (and after a reset).
# FIX: the previous prompt told the model to "use x_labels", a parameter that
# exists neither in create_plot() nor in the advertised tool schema; the prompt
# now references only real parameters so the model cannot emit invalid arguments.
if not st.session_state.messages:
    st.session_state.messages.append(
        {
            "role": Roles.SYSTEM.value,
            "content": (
                "You are a helpful assistant that can create visualizations. "
                "When users ask for data visualization, use the create_plot "
                "function to generate appropriate plots.\n"
                "Make sure to use a meaningful title, pick an appropriate "
                "plot_type ('line', 'scatter', or 'bar'), and pass descriptive "
                "x values (numbers or category labels) alongside the numeric "
                "y values."
            ),
        }
    )
# Wipe the conversation and rerun the script so the UI clears immediately.
if st.button("Reset Chat"):
    st.session_state.messages = []
    st.rerun()
# Render the conversation history; the system prompt stays hidden.
ROLE_AVATARS = {
    "assistant": Avatar.ASSISTANT.value,
    "user": Avatar.USER.value,
}
for message in st.session_state.messages:
    if message["role"] == Roles.SYSTEM.value:
        continue
    # Unknown roles (e.g. tool messages) fall back to the tool avatar.
    avatar = ROLE_AVATARS.get(message["role"], Avatar.TOOL.value)
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(message.get("content", ""))
        # Recreate any plot attached to this message from its stored arguments.
        if message.get("plot_data") is not None:
            st.plotly_chart(create_plot(**message["plot_data"]))
if prompt := st.chat_input("Ask me to create a visualization or chat!"):
    # Record the user's turn, then echo it into the chat pane.
    st.session_state.messages.append({"role": Roles.USER.value, "content": prompt})
    with st.chat_message(Roles.USER.value, avatar=Avatar.USER.value):
        st.markdown(prompt)
    # Stream the assistant's reply into its own chat bubble.
    with st.chat_message(Roles.ASSISTANT.value, avatar=Avatar.ASSISTANT.value):
        reply = generate_response(st.session_state.messages)
    # Persist the reply, carrying plot arguments so reruns can redraw the figure.
    st.session_state.messages.append(
        {
            "role": Roles.ASSISTANT.value,
            "content": reply["content"],
            "plot_data": reply.get("plot_data"),
        }
    )
st.markdown(
"""
<style>
.stChatMessage {
padding: 1rem;
}
</style>
""",
unsafe_allow_html=True,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment