import os
import re
import streamlit as st
from dotenv import load_dotenv
from langsmith import traceable
from langsmith.wrappers import wrap_openai
import openai
from PIL import Image
import io
import json
import queue
import logging
import time
# Load environment variables
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")
LANGSMITH_API_KEY = os.getenv("LANGSMITH_API_KEY")
ASSISTANT_ID = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")
if not all([openai.api_key, LANGSMITH_API_KEY, ASSISTANT_ID]):
    raise ValueError("Please set OPENAI_API_KEY, LANGSMITH_API_KEY, and ASSISTANT_ID_SOLUTION_SPECIFIER_A in your .env file.")
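# A minimal .env sketch (placeholder values, not real keys):
#   OPENAI_API_KEY=sk-...
#   LANGSMITH_API_KEY=lsv2_...
#   ASSISTANT_ID_SOLUTION_SPECIFIER_A=asst_...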
# Initialize logging
logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize LangSmith's traceable OpenAI client; wrap_openai traces every call made
# through this client. LangSmith itself reads LANGSMITH_API_KEY from the environment,
# so no separate LangSmith client setup is needed here.
wrapped_openai = wrap_openai(openai.Client(api_key=openai.api_key))
# Define a traceable function to start Assistant runs
@traceable
def create_run(thread_id: str, assistant_id: str):
    """
    Create a streaming run with the Assistant and return the raw event stream.
    """
    return wrapped_openai.beta.threads.runs.create(
        thread_id=thread_id,
        assistant_id=assistant_id,
        model="gpt-4o",  # Replace with your desired model
        stream=True,
    )
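# The stream yields AssistantStreamEvent objects; data_streamer below consumes the
# event names "thread.message.delta", "thread.run.requires_action", and
# "thread.run.failed" and ignores everything else.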
# Function to replace citation markers (e.g. from file-search results) with a book emoji
def remove_citation(text: str) -> str:
    pattern = r"【\d+†\w+】"
    return re.sub(pattern, "📚", text)
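# Example: remove_citation("See the spec【12†source】") -> "See the spec📚"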
# Initialize session state for messages, thread_id, and tool_requests
if "messages" not in st.session_state:
    st.session_state["messages"] = []
if "thread_id" not in st.session_state:
    st.session_state["thread_id"] = None
if "tool_requests" not in st.session_state:
    st.session_state["tool_requests"] = queue.Queue()
tool_requests = st.session_state["tool_requests"]
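# The queue carries requires_action events out of the streaming generator so the
# main flow can submit tool outputs after the visible part of the stream finishes.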
# Initialize Streamlit page
st.set_page_config(page_title="Solution Specifier A", layout="centered")
st.title("Solution Specifier A")
# Display existing messages
for msg in st.session_state["messages"]:
    if msg["role"] == "user":
        with st.chat_message("user"):
            st.write(msg["content"])
    else:
        with st.chat_message("assistant"):
            if isinstance(msg["content"], Image.Image):
                st.image(msg["content"])
            else:
                st.write(msg["content"])
# Chat input widget
user_input = st.chat_input("Type your message here...")
# Function to handle tool requests (function calls)
def handle_requires_action(tool_request):
    st.toast("Running a function", icon=":material/function:")
    tool_outputs = []
    data = tool_request.data
    for tool in data.required_action.submit_tool_outputs.tool_calls:
        if tool.function.arguments:
            function_arguments = json.loads(tool.function.arguments)
        else:
            function_arguments = {}
        match tool.function.name:
            case "hello_world":
                logger.info("Calling hello_world function")
                answer = hello_world(**function_arguments)
                tool_outputs.append({"tool_call_id": tool.id, "output": answer})
            case _:
                logger.error(f"Unrecognized function name: {tool.function.name}. Tool: {tool}")
                ret_val = {
                    "status": "error",
                    "message": f"Function name {tool.function.name} is not recognized. Check the function name and try again.",
                }
                tool_outputs.append({"tool_call_id": tool.id, "output": json.dumps(ret_val)})
    st.toast("Function completed", icon=":material/function:")
    return tool_outputs, data.thread_id, data.id
# Example function that could be called by the Assistant
def hello_world(name: str) -> str:
    time.sleep(2)  # Simulate a long-running task
    return f"Hello {name}!"
# Function to add assistant messages (text or images) to session state
def add_message_to_state_session(message):
    if isinstance(message, Image.Image) or len(message) > 0:
        st.session_state["messages"].append({"role": "assistant", "content": message})
# Function to process streamed data
def data_streamer(stream):
    """
    Stream data from the assistant. Text deltas and images are yielded; tool
    requests are put on the queue for the caller to handle.
    """
    logger.info("Starting data streamer")
    st.toast("Thinking...", icon=":material/emoji_objects:")
    content_produced = False
    try:
        for response in stream:
            event = response.event
            if event == "thread.message.delta":
                content = response.data.delta.content[0]
                if content.type == "text":
                    # Strip citation markers before yielding the text
                    value = remove_citation(content.text.value)
                    content_produced = True
                    yield value
                elif content.type == "image_file":
                    logger.info("Image file received")
                    image_content = io.BytesIO(wrapped_openai.files.content(content.image_file.file_id).read())
                    img = Image.open(image_content)
                    content_produced = True
                    yield img
            elif event == "thread.run.requires_action":
                logger.info("Run requires action")
                tool_requests.put(response)
                if not content_produced:
                    yield "[LLM requires a function call]"
                break
            elif event == "thread.run.failed":
                logger.error("Run failed")
                yield "[Run failed]"
                break
    finally:
        st.toast("Completed", icon=":material/emoji_objects:")
        logger.info("Finished data streamer")
# Function to display the streamed response and persist it to session state
def display_stream(stream):
    with st.chat_message("assistant"):
        # st.write_stream assembles text deltas into a single message and renders
        # non-text chunks (images) with st.write
        response = st.write_stream(data_streamer(stream))
    if isinstance(response, list):
        for content in response:
            add_message_to_state_session(content)
    else:
        add_message_to_state_session(response)
# Main function to handle user input and assistant response
def main():
    if user_input:
        # Add user message to session state
        st.session_state["messages"].append({"role": "user", "content": user_input})
        # Display the user's message
        with st.chat_message("user"):
            st.write(user_input)
        # Create a new thread if it doesn't exist
        if st.session_state["thread_id"] is None:
            logger.info("Creating new thread")
            thread = wrapped_openai.beta.threads.create()
            st.session_state["thread_id"] = thread.id
        else:
            thread = wrapped_openai.beta.threads.retrieve(st.session_state["thread_id"])
        # Add user message to the thread
        wrapped_openai.beta.threads.messages.create(
            thread_id=thread.id,
            role="user",
            content=user_input,
        )
        # Create a new run with streaming. Streamlit calls are not thread-safe,
        # so the stream is consumed in this script run rather than on a worker thread.
        logger.info("Creating a new run with streaming")
        stream = create_run(thread.id, ASSISTANT_ID)
        display_stream(stream)
        # Handle any tool requests queued while streaming. Submitting tool outputs
        # resumes the same run, so stream the continuation instead of creating a new run.
        while not tool_requests.empty():
            logger.info("Handling tool requests")
            tool_request = tool_requests.get()
            tool_outputs, thread_id, run_id = handle_requires_action(tool_request)
            new_stream = wrapped_openai.beta.threads.runs.submit_tool_outputs(
                thread_id=thread_id,
                run_id=run_id,
                tool_outputs=tool_outputs,
                stream=True,
            )
            display_stream(new_stream)

# Run the main function
main()
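# To launch the app (assuming this file is saved as app.py):
#   streamlit run app.py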