import os
import sys
import json
import logging
import logging.config
from pathlib import Path
from typing import Any
from uuid import uuid4, UUID

import gradio as gr
from dotenv import load_dotenv
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel

load_dotenv()

# Check Gradio version
print(f"Gradio version: {gr.__version__}")

# The graph's tools depend on environment variables, so this import must come
# after load_dotenv().
from graph import graph, weak_model, search_enabled  # noqa

FOLLOWUP_QUESTION_NUMBER = 3
TRIM_MESSAGE_LENGTH = 16  # Includes tool messages
USER_INPUT_MAX_LENGTH = 10000  # Characters

# The same secret must be reused across sessions for data persistence.
# If you store sensitive data, keep your secret in .env instead.
BROWSER_STORAGE_SECRET = "itsnosecret"

try:
    with open("logging-config.json", "r") as fh:
        config = json.load(fh)
    logging.config.dictConfig(config)
except (FileNotFoundError, json.JSONDecodeError):
    # Fallback logging configuration
    logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)


def load_initial_greeting(filepath="greeting_prompt.txt") -> str:
    """Load the initial greeting message from a specified text file."""
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            return f.read().strip()
    except FileNotFoundError:
        logger.warning(f"Prompt file '{filepath}' not found; using default greeting.")
        return "Welcome to DIYO! I'm here to help you create amazing DIY projects. What would you like to build today?"
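
# For orientation: chat_fn below is an async generator wired into Gradio.
# A minimal sketch of the yield contract it follows (values are illustrative,
# not produced by this file):
#
#   yield "partial text...", gr.skip(), gr.skip()   # streaming chunk; state unchanged
#   yield final_text, final_graph_state, True       # end of turn; trigger follow-ups
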
async def chat_fn(user_input: str, history: list, input_graph_state: dict, uuid: UUID,
                  prompt: str, search_enabled: bool, download_website_text_enabled: bool):
    """Run one chat turn through the graph, streaming output to Gradio.

    Args:
        user_input (str): The user's input message
        history (list): The conversation history kept by Gradio
        input_graph_state (dict): The current state of the graph, including tool call history
        uuid (UUID): The unique identifier for the current conversation; used as the
            langgraph thread id and for memory
        prompt (str): The system prompt
        search_enabled (bool): Whether the web search tool is enabled
        download_website_text_enabled (bool): Whether the website download tool is enabled

    Yields:
        str: The output message
        dict|Any: The final state of the graph
        bool|Any: Whether to trigger follow-up questions
    """
    try:
        logger.info(f"Processing user input: {user_input[:100]}...")

        # Initialize input_graph_state if None
        if input_graph_state is None:
            input_graph_state = {}

        input_graph_state["tools_enabled"] = {
            "download_website_text": download_website_text_enabled,
            "tavily_search_results_json": search_enabled,
        }
        if prompt:
            input_graph_state["prompt"] = prompt

        if input_graph_state.get("awaiting_human_input"):
            # Resume a paused human_assistance tool call with the user's reply
            input_graph_state["messages"].append(
                ToolMessage(
                    tool_call_id=input_graph_state.pop("human_assistance_tool_id"),
                    content=user_input,
                )
            )
            input_graph_state["awaiting_human_input"] = False
        else:
            # New user message
            if "messages" not in input_graph_state:
                input_graph_state["messages"] = []
            input_graph_state["messages"].append(
                HumanMessage(user_input[:USER_INPUT_MAX_LENGTH])
            )
            input_graph_state["messages"] = input_graph_state["messages"][-TRIM_MESSAGE_LENGTH:]

        config = RunnableConfig(
            recursion_limit=20,
            run_name="user_chat",
            configurable={"thread_id": str(uuid)},
        )

        output: str = ""
        final_state: dict | Any = {}
        waiting_output_seq: list[str] = []

        async for stream_mode, chunk in graph.astream(
            input_graph_state,
            config=config,
            stream_mode=["values", "messages"],
        ):
            if stream_mode == "values":
                final_state = chunk
                if chunk.get("messages") and len(chunk["messages"]) > 0:
                    last_message = chunk["messages"][-1]
                    if hasattr(last_message, "tool_calls") and last_message.tool_calls:
                        for msg_tool_call in last_message.tool_calls:
                            tool_name: str = msg_tool_call["name"]
                            if tool_name == "tavily_search_results_json":
                                query = msg_tool_call["args"]["query"]
                                waiting_output_seq.append(f"🔍 Searching for '{query}'...")
                                yield "\n".join(waiting_output_seq), gr.skip(), gr.skip()
                            elif tool_name == "download_website_text":
                                url = msg_tool_call["args"]["url"]
                                waiting_output_seq.append(f"📥 Downloading text from '{url}'...")
                                yield "\n".join(waiting_output_seq), gr.skip(), gr.skip()
                            elif tool_name == "human_assistance":
                                query = msg_tool_call["args"]["query"]
                                waiting_output_seq.append(f"🤖: {query}")
                                # Save state so the next call can resume this tool call
                                final_state["awaiting_human_input"] = True
                                final_state["human_assistance_tool_id"] = msg_tool_call["id"]
                                # Indicate that human input is needed
                                yield "\n".join(waiting_output_seq), final_state, True
                                return  # Pause execution; resume in the next call
                            else:
                                waiting_output_seq.append(f"🔧 Running {tool_name}...")
                                yield "\n".join(waiting_output_seq), gr.skip(), gr.skip()
            elif stream_mode == "messages":
                msg, metadata = chunk
                # Only stream text emitted by the graph's response-producing nodes
                node_name = metadata.get("langgraph_node", "")
                if node_name in ["brainstorming_node", "prompt_planning_node", "generate_3d_node", "assistant_node"]:
                    current_chunk_text = ""
                    if isinstance(msg.content, str):
                        current_chunk_text = msg.content
                    elif isinstance(msg.content, list):
                        # Content may arrive as a list of text blocks
                        for block in msg.content:
                            if isinstance(block, dict) and block.get("type") == "text":
                                current_chunk_text += block.get("text", "")
                            elif isinstance(block, str):
                                current_chunk_text += block
                    if current_chunk_text:
                        output += current_chunk_text
                        yield output, gr.skip(), gr.skip()

        # Final yield with the complete response
        yield output + " ", dict(final_state), True
    except Exception:
        logger.exception("Exception occurred in chat_fn")
        user_error_message = "There was an error processing your request. Please try again."
        yield user_error_message, gr.skip(), False
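
# Round-trip sketch of the human_assistance pause/resume handled in chat_fn
# (illustrative; the tool_call_id is whatever id the model generated):
#
#   call 1: the model emits a human_assistance tool call
#           -> chat_fn yields (question, state with awaiting_human_input=True, True)
#              and returns, pausing the graph
#   call 2: the user's next message arrives as user_input
#           -> chat_fn appends ToolMessage(tool_call_id=..., content=user_input)
#              and graph.astream resumes from the saved state
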
def clear():
    """Clear the current conversation state."""
    return dict(), uuid4()


class FollowupQuestions(BaseModel):
    """Structured-output model for follow-up questions."""
    questions: list[str]


async def populate_followup_questions(end_of_chat_response: bool, messages: list[dict], uuid: UUID):
    """Generate follow-up question buttons.

    This function gets called a lot due to the asynchronous nature of streaming,
    so only populate follow-up questions once streaming has completed and the
    last message is from the assistant.
    """
    if not end_of_chat_response or not messages or len(messages) == 0:
        return *[gr.skip() for _ in range(FOLLOWUP_QUESTION_NUMBER)], False
    # Check if the last message is from the assistant
    if messages[-1]["role"] != "assistant":
        return *[gr.skip() for _ in range(FOLLOWUP_QUESTION_NUMBER)], False
    try:
        config = RunnableConfig(
            run_name="populate_followup_questions",
            configurable={"thread_id": str(uuid)},
        )
        weak_model_with_config = weak_model.with_config(config)
        follow_up_questions = await weak_model_with_config.with_structured_output(FollowupQuestions).ainvoke([
            ("system", f"suggest {FOLLOWUP_QUESTION_NUMBER} followup questions for the user to ask the assistant. Refrain from asking personal questions."),
            *messages,
        ])
        if len(follow_up_questions.questions) != FOLLOWUP_QUESTION_NUMBER:
            logger.warning("Invalid number of followup questions generated")
            return *[gr.Button(visible=False) for _ in range(FOLLOWUP_QUESTION_NUMBER)], False
        buttons = []
        for i in range(FOLLOWUP_QUESTION_NUMBER):
            buttons.append(
                gr.Button(follow_up_questions.questions[i], visible=True, elem_classes="chat-tab"),
            )
        return *buttons, False
    except Exception as e:
        logger.error(f"Error generating followup questions: {e}")
        return *[gr.Button(visible=False) for _ in range(FOLLOWUP_QUESTION_NUMBER)], False
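
# For orientation, with_structured_output(FollowupQuestions) above is expected
# to return a FollowupQuestions instance parsed from a payload shaped like this
# (hypothetical content):
#
#   {"questions": ["What materials do I need?",
#                  "How long will this project take?",
#                  "What tools are required?"]}
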
async def summarize_chat(end_of_chat_response: bool, messages: list[dict], sidebar_summaries: dict, uuid: UUID):
    """Summarize the chat to produce a tab name."""
    should_return = (
        not end_of_chat_response
        or not messages
        or len(messages) == 0
        or messages[-1]["role"] != "assistant"
        # Guard against Gradio passing a function instead of state
        or callable(sidebar_summaries)
        or uuid in sidebar_summaries
    )
    if should_return:
        return gr.skip(), gr.skip()

    # Filter valid messages
    filtered_messages = []
    for msg in messages:
        if isinstance(msg, dict) and msg.get("content") and msg["content"].strip():
            filtered_messages.append(msg)

    # If no valid messages remain after filtering, provide a default summary
    if not filtered_messages:
        if uuid not in sidebar_summaries:
            sidebar_summaries[uuid] = "New Chat"
        return sidebar_summaries, False

    try:
        config = RunnableConfig(
            run_name="summarize_chat",
            configurable={"thread_id": str(uuid)},
        )
        weak_model_with_config = weak_model.with_config(config)
        summary_response = await weak_model_with_config.ainvoke([
            ("system", "summarize this chat in 7 tokens or less. Refrain from using periods"),
            *filtered_messages,
        ])
        if uuid not in sidebar_summaries:
            sidebar_summaries[uuid] = summary_response.content[:50]  # Limit length
    except Exception as e:
        logger.error(f"Error summarizing chat: {e}")
        if uuid not in sidebar_summaries:
            sidebar_summaries[uuid] = "Chat Session"
    return sidebar_summaries, False


async def new_tab(uuid, gradio_graph, messages, tabs, prompt, sidebar_summaries):
    """Create a new chat tab."""
    new_uuid = uuid4()
    new_graph = {}
    # Save the current tab if it has content
    if messages and len(messages) > 0:
        if uuid not in sidebar_summaries:
            sidebar_summaries, _ = await summarize_chat(True, messages, sidebar_summaries, uuid)
        tabs[uuid] = {
            "graph": gradio_graph,
            "messages": messages,
            "prompt": prompt,
        }
    # Clear suggestion buttons
    suggestion_buttons = [gr.Button(visible=False) for _ in range(FOLLOWUP_QUESTION_NUMBER)]
    # Load the initial greeting for the new chat
    greeting_text = load_initial_greeting()
    new_chat_messages_for_display = [{"role": "assistant", "content": greeting_text}]
    new_prompt = prompt if prompt else "You are a helpful DIY assistant."
    return new_uuid, new_graph, new_chat_messages_for_display, tabs, new_prompt, sidebar_summaries, *suggestion_buttons


def switch_tab(selected_uuid, tabs, gradio_graph, uuid, messages, prompt):
    """Switch to a different chat tab."""
    try:
        # Save the current state if there are messages
        if messages and len(messages) > 0:
            tabs[uuid] = {
                "graph": gradio_graph if gradio_graph else {},
                "messages": messages,
                "prompt": prompt,
            }
        if selected_uuid not in tabs:
            logger.error(f"Could not find the selected tab in tabs storage: {selected_uuid}")
            return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), *[gr.Button(visible=False) for _ in range(FOLLOWUP_QUESTION_NUMBER)]
        selected_tab_state = tabs[selected_uuid]
        selected_graph = selected_tab_state.get("graph", {})
        selected_messages = selected_tab_state.get("messages", [])
        selected_prompt = selected_tab_state.get("prompt", "You are a helpful DIY assistant.")
        suggestion_buttons = [gr.Button(visible=False) for _ in range(FOLLOWUP_QUESTION_NUMBER)]
        return selected_graph, selected_uuid, selected_messages, tabs, selected_prompt, *suggestion_buttons
    except Exception as e:
        logger.error(f"Error switching tabs: {e}")
        return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), *[gr.Button(visible=False) for _ in range(FOLLOWUP_QUESTION_NUMBER)]


def delete_tab(current_chat_uuid, selected_uuid, sidebar_summaries, tabs):
    """Delete a chat tab."""
    output_messages = gr.skip()
    # If deleting the current tab, clear the chatbot
    if current_chat_uuid == selected_uuid:
        output_messages = []
    # Remove from storage
    if selected_uuid in tabs:
        del tabs[selected_uuid]
    if selected_uuid in sidebar_summaries:
        del sidebar_summaries[selected_uuid]
    return sidebar_summaries, tabs, output_messages


def submit_edit_tab(selected_uuid, sidebar_summaries, text):
    """Submit an edited tab name."""
    if text.strip():
        sidebar_summaries[selected_uuid] = text.strip()[:50]  # Limit length
    return sidebar_summaries, ""


def load_mesh(mesh_file_name):
    """Load a 3D mesh file."""
    return mesh_file_name
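
# get_sorted_3d_model_examples below returns a nested list because that is the
# row format gr.Examples expects, e.g. (hypothetical paths):
#
#   [["generated_3d_models/chair.glb"], ["generated_3d_models/shelf.obj"]]
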
def get_sorted_3d_model_examples():
    """Get a sorted list of 3D model examples."""
    examples_dir = Path("./generated_3d_models")
    # Create the directory if it doesn't exist
    examples_dir.mkdir(exist_ok=True)
    # Get all 3D model files with the desired extensions
    model_files = [
        file for file in examples_dir.glob("*")
        if file.suffix.lower() in {".obj", ".glb", ".gltf"}
    ]
    # Sort files by creation time (latest first)
    try:
        sorted_files = sorted(
            model_files,
            key=lambda x: x.stat().st_ctime,
            reverse=True,
        )
    except (OSError, AttributeError):
        # Fall back to name sorting if stat fails
        sorted_files = sorted(model_files, key=lambda x: x.name, reverse=True)
    # Convert to the format [[path1], [path2], ...]
    return [[str(file)] for file in sorted_files]


CSS = """
footer {visibility: hidden}
.followup-question-button {font-size: 12px}
.chat-tab {
    font-size: 12px;
    padding-inline: 0;
}
.chat-tab.active {
    background-color: #654343;
}
#new-chat-button {
    background-color: #0f0f11;
    color: white;
}
.tab-button-control {
    min-width: 0;
    padding-left: 0;
    padding-right: 0;
}
.sidebar-collapsed {
    display: none !important;
}
.wrap.sidebar-parent {
    min-height: 2400px !important;
    height: 2400px !important;
}
#main-app {
    height: 4600px;
    overflow-y: auto;
    padding-top: 20px;
}
"""

TRIGGER_CHATINTERFACE_BUTTON = """
function triggerChatButtonClick() {
    const chatTextbox = document.getElementById("chat-textbox");
    if (!chatTextbox) {
        console.error("Error: Could not find element with id 'chat-textbox'");
        return;
    }
    const button = chatTextbox.querySelector("button");
    if (!button) {
        console.error("Error: No button found inside the chat-textbox element");
        return;
    }
    button.click();
}"""

if __name__ == "__main__":
    logger.info("Starting the DIYO interface")

    # Check whether BrowserState is available
    has_browser_state = hasattr(gr, "BrowserState")
    logger.info(f"BrowserState available: {has_browser_state}")
    if not has_browser_state:
        logger.warning("BrowserState not available in this Gradio version; using regular State instead.")
        logger.warning("To use BrowserState, upgrade Gradio: pip install --upgrade gradio")

    with gr.Blocks(title="DIYO - DIY Assistant", fill_height=True, css=CSS, elem_id="main-app") as demo:
        # State management: use BrowserState if available, otherwise regular State
        is_new_user_for_greeting = gr.State(True)
        if has_browser_state:
            current_prompt_state = gr.BrowserState(
                value="You are a helpful DIY assistant.",
                storage_key="current_prompt_state",
                secret=BROWSER_STORAGE_SECRET,
            )
            current_uuid_state = gr.BrowserState(
                value=uuid4,
                storage_key="current_uuid_state",
                secret=BROWSER_STORAGE_SECRET,
            )
            current_langgraph_state = gr.BrowserState(
                value=dict,
                storage_key="current_langgraph_state",
                secret=BROWSER_STORAGE_SECRET,
            )
            sidebar_names_state = gr.BrowserState(
                value=dict,
                storage_key="sidebar_names_state",
                secret=BROWSER_STORAGE_SECRET,
            )
            offloaded_tabs_data_storage = gr.BrowserState(
                value=dict,
                storage_key="offloaded_tabs_data_storage",
                secret=BROWSER_STORAGE_SECRET,
            )
            chatbot_message_storage = gr.BrowserState(
                value=list,
                storage_key="chatbot_message_storage",
                secret=BROWSER_STORAGE_SECRET,
            )
        else:
            # Fall back to regular State
            current_prompt_state = gr.State("You are a helpful DIY assistant.")
            current_uuid_state = gr.State(uuid4())
            current_langgraph_state = gr.State({})
            sidebar_names_state = gr.State({})
            offloaded_tabs_data_storage = gr.State({})
            chatbot_message_storage = gr.State([])

        end_of_assistant_response_state = gr.State(False)

        # Header
        with gr.Row(elem_classes="header-margin"):
            gr.Markdown("""