Spaces:
Sleeping
Sleeping
| #pip install langchain_google_genai langgraph gradio | |
| import os | |
| import sys | |
| import typing | |
| from typing import Annotated, Literal, Iterable | |
| from typing_extensions import TypedDict | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langgraph.graph import StateGraph, START, END | |
| from langgraph.graph.message import add_messages | |
| from langgraph.prebuilt import ToolNode | |
| from langchain_core.tools import tool | |
| from langchain_core.messages import AIMessage, ToolMessage, HumanMessage, BaseMessage, SystemMessage | |
| from random import randint | |
| import requests | |
| from bs4 import BeautifulSoup | |
| import openpyxl | |
| import wikipedia | |
| import pandas as pd | |
| import gradio as gr | |
| import logging | |
class OrderState(TypedDict):
    """Graph state shared by every node of the agent conversation."""

    # Conversation history. The add_messages reducer merges each node's
    # returned messages into this list rather than overwriting it.
    messages: Annotated[list, add_messages]
    # Order items carried through the graph (nodes initialise it to []).
    order: list[str]
    # Set by the human node when the user asks to quit; routing then ends the run.
    finished: bool
# System instruction for the agent, stored as a ("system", text) pair.
# Adjacent string literals are concatenated with no separator, so every
# fragment that ends a sentence carries a trailing space. Without it the
# sentences run together in the prompt ("question.The question ...").
SYSINT = (
    "system",
    "You are a general AI assistant. I will ask you a question. "
    "The question requires a tool to solve. You must attempt to use at least one of the available tools before returning an answer. "
    "Report your thoughts, and finish your answer with the following template: "
    "FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. "
    "If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. "
    "If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. "
    "If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. "
    "If a tool required for task completion is not functioning, return 0."
)
# Greeting text for the agent.
WELCOME_MSG = "Welcome to my general-purpose AI agent. Type `q` to quit. How shall I fail to serve you today?"

# Google Gemini chat model; the tools are bound to it where the tools are prepared.
_GEMINI_MODEL_NAME = "gemini-1.5-flash-latest"
llm = ChatGoogleGenerativeAI(model=_GEMINI_MODEL_NAME)
def wikipedia_search_tool(title: str) -> str:
    """Provides an excerpt from a Wikipedia article with the given title.

    Returns at most the first 3000 characters of the article text. If the
    title is ambiguous, returns the candidate article titles instead so a
    more specific title can be tried; if no such article exists, says so.
    """
    try:
        page = wikipedia.page(title, auto_suggest=False)
    except wikipedia.exceptions.DisambiguationError as err:
        # Ambiguous titles raise instead of returning a page; list a bounded
        # number of the suggested alternatives.
        options = ", ".join(err.options[:20])
        return f"'{title}' is ambiguous. Possible articles: {options}"
    except wikipedia.exceptions.PageError:
        return f"No Wikipedia article titled '{title}' was found."
    return page.content[:3000]
def media_tool(file_path: str) -> str:
    """Used for deciphering video and audio files."""
    # Placeholder implementation: the file is never opened. The fixed reply
    # tells the model to answer 0 when the file's contents are required.
    not_implemented_notice = (
        "This tool hasn't been implemented yet. "
        "Please return 0 if the task cannot be solved without knowing the contents of this file."
    )
    return not_implemented_notice
def internet_search_tool(search_query: str) -> str:
    """Does a web search (via DuckDuckGo) using the input as the search query. Returns a long batch of textual information related to the query."""
    # NOTE(review): DuckDuckGoSearchTool is not imported anywhere in this file;
    # it presumably comes from smolagents. Add that import before relying on
    # this tool, otherwise the call below raises NameError.
    search_tool = DuckDuckGoSearchTool()
    # Search with the caller's query (the parameter is ``search_query``).
    return search_tool(search_query)
def webscraper_tool(url: str) -> str:
    """Returns the text content of the web page at the input url, with HTML tags removed."""
    # The timeout keeps an unresponsive server from hanging the agent. The
    # context manager releases the connection on both the success and the
    # error path.
    with requests.get(url, timeout=30) as response:
        if response.status_code != 200:
            raise Exception(f"Failed to retrieve the webpage. Status code: {response.status_code}")
        soup = BeautifulSoup(response.content, 'html.parser')
    return soup.get_text()
def read_excel_tool(file_path: str) -> str:
    """Returns the contents of an Excel file (first sheet) as a plain-text table."""
    df = pd.read_excel(file_path, engine="openpyxl")
    # Tool results are passed back to the model as message text, so render the
    # DataFrame as a text table instead of returning the DataFrame object.
    return df.to_string()
def agent_node(state: OrderState) -> OrderState:
    """Run one agent turn: send the conversation to the tool-bound LLM.

    The system instruction is prepended to the conversation on every call and
    is not stored in the state. If the state has no messages yet, the welcome
    message is returned instead of calling the model. If the model call
    raises, the error is reported back as an AIMessage rather than re-raised.

    Returns:
        ``defaults | state`` with ``messages`` set to the single new message
        (the ``add_messages`` reducer on OrderState appends it to the history).
    """
    messages = state.get("messages", [])
    logging.info(f"Messagelist sent to agent node: {[msg.content for msg in messages]}")
    defaults = {"order": [], "finished": False}

    if not messages:
        # Nothing to answer yet: greet the user instead of calling the model.
        return defaults | state | {"messages": [AIMessage(content=WELCOME_MSG)]}

    # SYSINT is a ("system", text) pair; SystemMessage needs only the text,
    # otherwise the role label would be sent to the model as prompt content.
    system_text = SYSINT[1] if isinstance(SYSINT, tuple) else SYSINT
    try:
        new_output = llm_with_tools.invoke([SystemMessage(content=system_text), *messages])
    except Exception as e:
        # Best-effort fallback: surface the failure to the user as a reply.
        logging.exception("LLM call failed in agent node")
        return defaults | state | {"messages": [AIMessage(content=f"I'm having trouble processing that. {str(e)}")]}
    return defaults | state | {"messages": [new_output]}
def interactive_tools_node(state: OrderState) -> OrderState:
    """Execute the tool calls requested by the last message in the state.

    Each call is dispatched to the matching ``*_tool`` function defined in
    this module with the model-supplied arguments. The result is returned to
    the model as a ToolMessage answering that call. A tool that raises is
    reported as error text, so every call still gets a reply. An unknown tool
    name raises NotImplementedError.

    Returns:
        A state update holding the new ToolMessages, the unchanged ``order``
        and ``finished=False``.
    """
    logging.info("interactive tools node")
    # Tool name -> implementation. Each function's parameter names are the
    # argument keys the model sends for that tool.
    tool_functions = {
        "wikipedia_search_tool": wikipedia_search_tool,
        "media_tool": media_tool,
        "internet_search_tool": internet_search_tool,
        "webscraper_tool": webscraper_tool,
        "read_excel_tool": read_excel_tool,
    }
    # Per-tool cap on the length of the text handed back to the model.
    max_chars = {"wikipedia_search_tool": 3000, "internet_search_tool": 3000}

    tool_msg = state.get("messages", [])[-1]
    order = state.get("order", [])
    outbound_msgs = []
    for tool_call in tool_msg.tool_calls:
        tool_name = tool_call["name"]
        tool_args = tool_call.get("args") or {}
        logging.info(f"calling {tool_name} with {tool_args}")
        tool_fn = tool_functions.get(tool_name)
        if tool_fn is None:
            raise NotImplementedError(f'Unknown tool call: {tool_name}')
        try:
            # ToolMessage content must be text, so stringify whatever the tool returns.
            response = str(tool_fn(**tool_args))
        except Exception as e:
            # Report the failure to the model (the system prompt tells it to
            # answer 0 when a required tool is not functioning) instead of
            # aborting the whole graph run.
            logging.exception(f"{tool_name} failed")
            response = f"Tool {tool_name} failed: {e}"
        else:
            limit = max_chars.get(tool_name)
            if limit is not None:
                response = response[:limit]
        outbound_msgs.append(
            ToolMessage(
                content=response,
                name=tool_name,
                tool_call_id=tool_call["id"],
            )
        )
    return {"messages": outbound_msgs, "order": order, "finished": False}
def maybe_route_to_tools(state: OrderState) -> str:
    """Route between chat and tool nodes.

    Returns END when the conversation is finished. Returns "tools" when the
    last message requests at least one tool registered with ``tool_node``,
    and "interactive_tools" for any other tool request. Returns "human" when
    the model replied without tool calls.

    Raises:
        ValueError: if the state contains no messages.
    """
    if not (msgs := state.get("messages", [])):
        raise ValueError(f"No messages found when parsing state: {state}")
    if state.get("finished", False):
        logging.info("from agent GOTO End node")
        return END
    # Guard against messages without a tool_calls attribute or with None there.
    tool_calls = getattr(msgs[-1], "tool_calls", None) or []
    if not tool_calls:
        # A plain reply (e.g. the final answer) goes back to the human node.
        logging.info("no tool calls, from agent GOTO human node")
        return "human"
    if any(call["name"] in tool_node.tools_by_name for call in tool_calls):
        logging.info("from agent GOTO tools node")
        return "tools"
    logging.info("from chatbot GOTO interactive tools node")
    return "interactive_tools"
def human_node(state: OrderState) -> OrderState:
    """Handle user input.

    Flags the conversation as finished when the latest message is a quit
    command: "q", "quit", "exit" or "goodbye", case-insensitive, with
    surrounding whitespace ignored. Returns a partial state update instead
    of mutating and returning the incoming state.
    """
    messages = state.get("messages", [])
    logging.info(f"Messagelist sent to human node: {[msg.content for msg in messages]}")
    last_content = messages[-1].content if messages else ""
    # LangChain message content can also be a list of content parts; only a
    # plain string can be a quit command.
    is_quit = isinstance(last_content, str) and last_content.strip().lower() in {"q", "quit", "exit", "goodbye"}
    return {"finished": state.get("finished", False) or is_quit}
def maybe_exit_human_node(state: OrderState) -> Literal["agent", "__end__"]:
    """Choose the next step after the human node.

    The run ends when the conversation is flagged as finished, or when the
    newest message is already an AIMessage (the agent has answered).
    Otherwise the conversation is handed to the agent.
    """
    if state.get("finished", False):
        logging.info("from human GOTO End node")
        return END

    newest = state["messages"][-1]
    if not isinstance(newest, AIMessage):
        logging.info("from human GOTO agent node")
        return "agent"

    logging.info("Chatbot response obtained, ending conversation")
    return END
# Prepare tools.
# Tools executed automatically by LangGraph's ToolNode (none at the moment).
auto_tools = []
tool_node = ToolNode(auto_tools)
# Tools executed by interactive_tools_node. bind_tools derives each tool's
# schema from the function's name, signature and docstring.
interactive_tools = [
    wikipedia_search_tool,
    media_tool,
    internet_search_tool,
    webscraper_tool,
    read_excel_tool,
]
# Bind all tools to the LLM
llm_with_tools = llm.bind_tools(auto_tools + interactive_tools)
# Assemble the conversation graph. It starts at the human node; the agent may
# detour through either tool node, which always hands control back to it.
graph_builder = StateGraph(OrderState)

for _node_name, _node_action in (
    ("agent", agent_node),
    ("human", human_node),
    ("tools", tool_node),
    ("interactive_tools", interactive_tools_node),
):
    graph_builder.add_node(_node_name, _node_action)

# Routing out of the agent and human nodes is decided by the maybe_* functions.
graph_builder.add_conditional_edges("agent", maybe_route_to_tools)
graph_builder.add_conditional_edges("human", maybe_exit_human_node)

# Both tool nodes return their results to the agent.
for _tool_source in ("tools", "interactive_tools"):
    graph_builder.add_edge(_tool_source, "agent")

graph_builder.add_edge(START, "human")

chat_graph = graph_builder.compile()
def convert_history_to_messages(history: list) -> list[BaseMessage]:
    """
    Convert Gradio chat history to a list of Langchain messages.
    Both Gradio history formats are accepted:
    - pairs ``[user_text, bot_text]``, where either side may be empty/None;
    - dicts with ``"role"`` ("user" or "assistant") and ``"content"``, the
      format used by ``gr.ChatInterface(type="messages")``.
    Empty entries are skipped; in the dict format, non-string content is
    skipped as well.
    Args:
    - history: Gradio's chat history format
    Returns:
    - List of Langchain BaseMessage objects
    """
    messages = []
    for entry in history:
        if isinstance(entry, dict):
            content = entry.get("content")
            if not isinstance(content, str) or not content.strip():
                continue
            if entry.get("role") == "user":
                messages.append(HumanMessage(content=content))
            elif entry.get("role") == "assistant":
                messages.append(AIMessage(content=content))
        else:
            human, ai = entry
            if human:
                messages.append(HumanMessage(content=human))
            if ai:
                messages.append(AIMessage(content=ai))
    return messages
def gradio_chat(message: str, history: list) -> str:
    """
    Gradio-compatible chat function that manages the conversation state.
    Each call rebuilds the conversation from Gradio's "messages"-format
    history (dicts with "role" and "content"), appends the new user message,
    runs it through ``chat_graph`` and returns the final message's text.
    Args:
    - message: User's input message
    - history: Gradio's chat history
    Returns:
    - Bot's response as a string, or an error description if processing fails
    """
    logging.info(f"{len(history)} history so far: {history}")
    # Ensure non-empty message
    if not message or message.strip() == "":
        message = "Hello, how can I help you today?"
    # Convert history to Langchain messages. Empty turns and non-string
    # content are skipped: non-string content has no .strip() and cannot
    # become a plain text message.
    conversation_messages = []
    for old_message in history:
        content = old_message.get("content")
        if not isinstance(content, str) or not content.strip():
            continue
        if old_message.get("role") == "user":
            conversation_messages.append(HumanMessage(content=content))
        elif old_message.get("role") == "assistant":
            conversation_messages.append(AIMessage(content=content))
    # Add current message
    conversation_messages.append(HumanMessage(content=message))
    # Create initial state with conversation history
    conversation_state = {
        "messages": conversation_messages,
        "order": [],
        "finished": False,
    }
    logging.info(f"Conversation so far: {str(conversation_state)}")
    try:
        # Process the conversation through the graph
        conversation_state = chat_graph.invoke(conversation_state, {"recursion_limit": 10})
        latest_content = conversation_state["messages"][-1].content
        if isinstance(latest_content, list):
            # Message content may be a list of parts; keep only the text.
            latest_content = "".join(
                part if isinstance(part, str) else part.get("text", "")
                for part in latest_content
            )
        logging.info(f"return: {latest_content}")
        return latest_content
    except Exception as e:
        logging.exception("Conversation processing failed")
        return f"An error occurred: {str(e)}"
# Gradio interface
def launch_agent():
    """Create the Gradio chat UI backed by ``gradio_chat`` and start serving it."""
    chat_ui = gr.ChatInterface(
        gradio_chat,
        type="messages",
        title="Agent",
        description="An AI agent (work in progress)",
        theme="ocean",
    )
    chat_ui.launch()
if __name__ == "__main__":
    # Send INFO-level log records (emitted by the graph nodes) to stdout.
    logging.basicConfig(
        level=logging.INFO,
        stream=sys.stdout,
        format='%(asctime)s - %(levelname)s - %(message)s',
    )
    launch_agent()