Spaces:
Sleeping
Sleeping
import os | |
from typing import TypedDict, List, Dict, Any, Optional | |
from langgraph.graph import StateGraph, START, END | |
from langchain_openai import ChatOpenAI | |
from langchain_core.messages import HumanMessage, AIMessage | |
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint | |
from langgraph.prebuilt import ToolNode, tools_condition | |
from langchain_core.messages import HumanMessage, SystemMessage | |
from langchain_core.utils.function_calling import convert_to_openai_tool | |
from langchain.tools import Tool | |
from serpapi import GoogleSearch | |
import requests | |
from bs4 import BeautifulSoup | |
# SerpAPI credential, read once at import time; a missing SERPAPI_TOKEN
# environment variable raises KeyError here rather than on the first search.
SERPAPI_API_KEY = os.environ["SERPAPI_TOKEN"]


def serpapi_search(query: str) -> str:
    """Search Google through SerpAPI and return the organic hits as text.

    Args:
        query: Search query, forwarded to SerpAPI as the ``q`` parameter.

    Returns:
        One entry per organic result (the snippet, a newline, then
        ``URL: <link>``), entries separated by a blank line. Returns
        ``"No results found."`` when the response carries no organic
        results, whether the key is missing or the list is empty.
    """
    print(f"Running SerpAPI search for: {query}")
    params = {
        "engine": "google",
        "q": query,
        "api_key": SERPAPI_API_KEY,
        "num": 3,
    }
    results = GoogleSearch(params).get_dict()
    # ``or []`` folds "key missing" and "key present but empty" into one case,
    # so an empty result list no longer produces an empty string.
    organic_results = results.get("organic_results") or []
    if not organic_results:
        return "No results found."
    return "\n\n".join(
        f"{item.get('snippet', '')}\nURL: {item.get('link', '')}"
        for item in organic_results
    )
# LangChain Tool wrapper that exposes serpapi_search to the model.
serpapi_tool = Tool(
    func=serpapi_search,
    name="serpapi_search",
    description=(
        "A tool that allows you to search the web using Google via SerpAPI. "
        "Input should be a search query."
    ),
)
def fetch_website_content(url: str) -> str:
    """Download a page and return the start of its visible text.

    Args:
        url: Address of the page to fetch.

    Returns:
        The first 1000 characters of the page text as extracted by
        BeautifulSoup's ``get_text``. If anything goes wrong, returns a
        string starting with ``"Error fetching website: "`` instead of
        raising.
    """
    print(f"Fetching website content from: {url}")
    try:
        page = requests.get(url, timeout=5)
        page.raise_for_status()
        # Very basic extraction: every text node, one per line.
        plain_text = BeautifulSoup(page.text, "html.parser").get_text(
            separator="\n", strip=True
        )
    except Exception as e:
        # Best effort: the failure is reported as the tool's text output.
        failure = f"Error fetching website: {e}"
        print(failure)
        return failure
    # Truncated for brevity.
    return plain_text[:1000]
# LangChain Tool wrapper around fetch_website_content.
fetch_website_tool = Tool(
    func=fetch_website_content,
    name="fetch_website_content",
    description=(
        "Fetches and returns the main text content "
        "of a given website URL."
    ),
)
# Chat model used by the assistant node.
# Alternatives tried earlier and left disabled: gpt-4o-mini, a separate
# gpt-4o vision model, and DuckDuckGoSearchRun as the search tool.
model = ChatOpenAI(temperature=0, model="gpt-4o")

# Only the SerpAPI search tool is bound to the model; fetch_website_tool is
# currently left out of this list.
tools = [serpapi_tool]
llm_with_tools = model.bind_tools(tools, parallel_tool_calls=False)
class AgentState(TypedDict):
    """State dictionary handed from node to node in the answering graph."""

    # The task given to the agent; the nodes read its "question" entry.
    question: Dict[str, Any]
    # Running trace of messages appended by the nodes.
    messages: List[Any]
    # Final answer text: None initially, "" while a tool call is pending.
    answer: Optional[str]
    # Tool calls requested by the model that have not been executed yet.
    tool_calls: Optional[list]
    # Output of the latest tool run, stored as a dict of lists under the keys
    # "search_results" and "website_contents" (None before any tool has run).
    tool_outputs: Optional[Dict[str, List[Any]]]
# System prompt sent ahead of every question.
_ASSISTANT_SYSTEM_PROMPT = """
Your answers are tested. Try to answer the question as accurately as possible. Give only the minimum necessary information to answer the question.
If you use a tool, answer the question using the tool results provided below.
Tool results will be provided as context after your question. If you receive a tool output, then use this information and come to the final answer if possible.
Only call another tool if you cannot answer the question with the information provided.
If you formulate your final answer, analyze it if it really ONLY answers the question. Don't provide additional information. One word, number or name is enough if it answers the question.
"""


def _format_tool_outputs(tool_results: Any) -> str:
    """Render stored tool outputs as plain text for the prompt."""
    if not isinstance(tool_results, dict):
        return str(tool_results)
    parts = []
    search_results = tool_results.get("search_results")
    if search_results:
        parts.append("Search Results:\n")
        parts.append("\n".join(str(r) for r in search_results))
    website_contents = tool_results.get("website_contents")
    if website_contents:
        parts.append("\nWebsite Contents:\n")
        parts.extend(
            f"\nURL: {wc['url']}\nContent: {wc['content']}\n"
            for wc in website_contents
        )
    return "".join(parts)


def assistant(state: AgentState):
    """Ask the model for an answer or a tool call and record the outcome.

    The prompt is the system prompt, the question, and, when
    ``state["tool_outputs"]`` is set, those outputs rendered as text.

    Args:
        state: Current graph state. It is updated in place.

    Returns:
        The same ``state`` object. If the model requested tool calls,
        ``tool_calls`` holds them and ``answer`` is ``""``. Otherwise
        ``answer`` holds the stripped response text and ``tool_calls`` is
        ``None``. In both cases an ``AIMessage`` is appended to ``messages``.
    """
    print("\n--- ASSISTANT NODE ---")
    print(f"State received: {state}")
    question = state["question"]
    print(f"Question dict: {question}")

    # Tolerate a bare string question as well as the usual dict form.
    if isinstance(question, dict):
        question_text = question.get("question", question)
    else:
        question_text = question

    messages = [
        SystemMessage(content=_ASSISTANT_SYSTEM_PROMPT),
        HumanMessage(content=f"Question: {question_text}"),
    ]
    tool_results = state.get("tool_outputs")
    if tool_results:
        tool_text = _format_tool_outputs(tool_results)
        messages.append(HumanMessage(content=f"Tool results:\n{tool_text}"))

    print(f"Messages sent to LLM: {messages}")
    response = llm_with_tools.invoke(messages)
    print(f"Raw LLM response: {response}")

    history = state.setdefault("messages", [])
    tool_calls = getattr(response, "tool_calls", None)
    if tool_calls:
        print(f"Tool calls requested: {tool_calls}")
        state["tool_calls"] = tool_calls
        state["answer"] = ""  # Not final yet
        history.append(AIMessage(content="Calling tool: " + str(tool_calls)))
    else:
        content = response.content
        # Message content may be a list of content blocks rather than a str;
        # stringify it instead of failing on ``.strip()``.
        if not isinstance(content, str):
            content = str(content)
        state["answer"] = content.strip()
        print(f"Model response: {state['answer']}")
        history.append(AIMessage(content=state["answer"]))
        state["tool_calls"] = None
    return state
def tool_node(state: AgentState):
    """Run the tool calls stored in the state and record their outputs.

    Args:
        state: Current graph state. ``state["tool_calls"]`` is read as a list
            of dicts with a ``"name"`` and an ``"args"`` entry.

    Returns:
        The same ``state`` object, updated in place:
        ``tool_outputs`` is replaced by a dict with ``"search_results"`` and
        ``"website_contents"`` lists, ``tool_calls`` is reset to ``None``,
        and a ``HumanMessage`` with the outputs is appended to ``messages``.
        Failures and unusable calls are recorded as ``"Error: ..."`` text
        rather than raised.
    """
    print("\n--- TOOL NODE ---")
    print(f"State received: {state}")
    search_results = []
    website_contents = []
    for call in state.get("tool_calls") or []:
        print(f"Tool call: {call}")
        name = call.get("name")
        args = call.get("args") or {}
        # Accept both {"query": ...} and {"__arg1": ...}, falling back to the
        # first argument value whatever its key is.
        query = (
            args.get("query")
            or args.get("__arg1")
            or next(iter(args.values()), None)
        )
        print(f"Query to use: {query}")
        if name == "serpapi_search":
            print("--- SERPAPI SEARCH ---")
            if not query:
                # Do not send a search for None / an empty string.
                search_results.append("Error: no search query was provided.")
                continue
            try:
                search_results.append(serpapi_search(query))
            except Exception as e:
                print(f"Error running SerpAPI search: {e}")
                search_results.append(f"Error: {e}")
        elif name == "fetch_website_content":
            print("--- FETCH WEBSITE CONTENT ---")
            if not query:
                website_contents.append(
                    {"url": query, "content": "Error: no URL was provided."}
                )
                continue
            try:
                content = fetch_website_content(query)
                website_contents.append({"url": query, "content": content})
            except Exception as e:
                print(f"Error fetching website: {e}")
                website_contents.append({"url": query, "content": f"Error: {e}"})
        else:
            # Report the unusable call so the assistant is not re-invoked
            # with empty results and no explanation.
            print(f"Unknown tool requested: {name}")
            search_results.append(f"Error: unknown tool '{name}'")

    # Store tool outputs in state for the assistant
    state["tool_outputs"] = {
        "search_results": search_results,
        "website_contents": website_contents,
    }
    state["tool_calls"] = None  # Clear tool calls
    # Add tool results to conversation history for traceability
    state.setdefault("messages", []).append(
        HumanMessage(content=f"Tool results: {state['tool_outputs']}")
    )
    return state
class BasicAgent:
    """Question-answering agent that loops between an assistant and a tool node.

    The graph starts at the ``assistant`` node. While the state holds pending
    ``tool_calls`` it goes to the ``tools`` node and back to ``assistant``;
    otherwise it ends.
    """

    # Runnable graph returned by StateGraph.compile() in __init__.
    compiled_graph: Any

    def __init__(self):
        """Build and compile the assistant/tools graph."""
        print("BasicAgent initialized.")
        answering_graph = StateGraph(AgentState)

        # The module's own tool_node is used here, not the prebuilt ToolNode.
        answering_graph.add_node("assistant", assistant)
        answering_graph.add_node("tools", tool_node)

        answering_graph.add_edge(START, "assistant")
        answering_graph.add_conditional_edges(
            "assistant",
            lambda state: "tools" if state.get("tool_calls") else END,
        )
        answering_graph.add_edge("tools", "assistant")

        self.compiled_graph = answering_graph.compile()

    def __call__(self, question: Any) -> Optional[str]:
        """Run the graph on one question and return its answer.

        Args:
            question: Either a dict whose ``"question"`` entry holds the
                question text, or the question text itself as a plain
                string (wrapped into ``{"question": text}``).

        Returns:
            The ``"answer"`` entry of the final graph state, or ``None`` if
            the final state has none.
        """
        if not isinstance(question, dict):
            question = {"question": question}
        # Same fallback the assistant node uses when the key is missing.
        question_text = str(question.get("question", question))
        print(f"Agent received question (first 50 chars): {question_text[:50]}...")
        initial_state = {
            "question": question,
            "messages": [],
            "answer": None,
            "tool_calls": None,
            "tool_outputs": None,
        }
        print(f"Initial state: {initial_state}")
        final_state = self.compiled_graph.invoke(initial_state)
        answer = final_state.get("answer")
        print(f"Agent returning answer: {answer}")
        return answer