import re

from dotenv import load_dotenv
from typing import TypedDict, List, Dict, Any, Optional, Annotated
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.graph.message import add_messages
from langchain_core.messages import SystemMessage, HumanMessage, AnyMessage, AIMessage
from langchain_core.messages.ai import subtract_usage
from langchain.tools import Tool
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.document_loaders import ArxivLoader
from langchain_community.retrievers import BM25Retriever
from langgraph.prebuilt import ToolNode, tools_condition
from prompts import system_prompt

# Load environment variables (API keys etc.) from a local .env file.
load_dotenv()

# Matches a trailing version marker such as "v1" or "v12" on an arXiv ID.
# Anchored at the end so a "v" inside the ID itself is never touched.
_ARXIV_VERSION_SUFFIX = re.compile(r"v\d+$")


def get_arxiv_url(content: str) -> str:
    """Extract an arXiv ID from text content and format it as an abs URL.

    The first line that starts with ``arXiv:`` (after stripping whitespace)
    and carries an identifier is used. A trailing version marker (``vN``) is
    removed so the URL points at the version-less abstract page.

    Args:
        content: Text to scan, one candidate per line.

    Returns:
        ``https://arxiv.org/abs/<id>`` for the first ID found, or the string
        ``"unknown"`` when no line carries an arXiv ID.
    """
    prefix = 'arXiv:'
    for raw_line in content.split('\n'):
        line = raw_line.strip()
        if not line.startswith(prefix):
            continue
        # Take the first token after the prefix; this accepts both
        # "arXiv:2302.00001v1 [cs.AI]" and "arXiv: 2302.00001v1".
        tokens = line[len(prefix):].split()
        if not tokens:
            continue
        # Strip only a trailing "vN"; splitting on every "v" would truncate
        # identifiers that contain the letter (e.g. "solv-int/9901001v2").
        base_arxiv_id = _ARXIV_VERSION_SUFFIX.sub('', tokens[0])
        if base_arxiv_id:
            return f"https://arxiv.org/abs/{base_arxiv_id}"
    return "unknown"  # Fallback if ID is not found


# Wikipedia search tool.
# Returns {"wiki_results": <text>} where <text> is the page contents joined by
# a "---" separator, or a "Search error: ..." message if loading fails.
# NOTE: the docstring below is what the @tool decorator exposes to the model.
@tool
def search_wiki(query: str) -> Dict[str, str]:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query."""
    print(f" executing search_wiki with query: {query}")
    try:
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        print(f"Found {len(search_docs)} documents for query '{query}'")
        formatted_search_docs = "\n\n---\n\n".join(
            f'\n{doc.page_content}\n' for doc in search_docs
        )
        if not formatted_search_docs:
            print("Empty search results")
        return {"wiki_results": formatted_search_docs}
    except Exception as e:
        # Best-effort tool: report the failure to the model instead of
        # aborting the whole graph run.
        print(f"Error in search_wiki: {e}")
        return {"wiki_results": f"Search error: {str(e)}"}


# Internet search tool.
# Returns {"web_results": <text>} where <text> is the "content" field of each
# hit joined by a "---" separator, or a "Search error: ..." message.
@tool
def search_web(query: str) -> Dict[str, str]:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query."""
    print(f" executing search_web with query: {query}")
    # Use run() instead of invoke() for tool execution
    try:
        search_docs = TavilySearchResults(max_results=3).run(query)
        print(f"DEBUG: search_docs type: {type(search_docs)}")
        print(f"DEBUG: search_docs content: {search_docs}")
        if isinstance(search_docs, str):
            # A plain string is not a list of hits (presumably the tool's own
            # error text); surface it instead of failing on `.get` below.
            print(f"Error in search_web: {search_docs}")
            return {"web_results": f"Search error: {search_docs}"}
        print(f"Found {len(search_docs)} documents for query '{query}'")
        # Formatted search results
        formatted_search_docs = "\n\n---\n\n".join(
            f'\n{doc.get("content", "")}\n' for doc in search_docs
        )
        if not formatted_search_docs:
            print("Empty search results")
        return {"web_results": formatted_search_docs}
    except Exception as e:
        print(f"Error in search_web: {e}")
        return {"web_results": f"Search error: {str(e)}"}


# ArXiv search tool.
# Returns {"arxiv_results": <text>} where <text> is the first 1000 characters
# of each paper joined by a "---" separator, or a "Search error: ..." message.
@tool
def search_arxiv(query: str) -> Dict[str, str]:
    """Search Arxiv for a query and return maximum 3 result.

    Args:
        query: The search query."""
    print(f" executing search_arxiv with query: {query}")
    try:
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
        print(f"DEBUG: search_docs type: {type(search_docs)}")
        print(f"DEBUG: search_docs content: {search_docs}")
        print(f"Found {len(search_docs)} documents for query '{query}'")
        formatted_search_docs = "\n\n---\n\n".join(
            # Truncate each paper so the tool output stays small.
            f'\n{doc.page_content[:1000]}\n' for doc in search_docs
        )
        if not formatted_search_docs:
            print("Empty search results")
        return {"arxiv_results": formatted_search_docs}
    except Exception as e:
        print(f"Error in search_arxiv: {e}")
        return {"arxiv_results": f"Search error: {str(e)}"}


# build retriever
# bm25_retriever = BM25Retriever.from_documents(docs)

# System message sent to the model ahead of the conversation.
sys_msg = SystemMessage(content=system_prompt)

# Tools exposed to the model and executed by the graph's tool node.
tools = [
    search_web,
    search_wiki,
    search_arxiv,
]


def build_graph():
    """Build and compile the assistant/tools agent graph.

    The graph starts at the ``assistant`` node, which calls the tool-enabled
    LLM. ``tools_condition`` then routes either to the ``tools`` node (whose
    output is fed back to ``assistant``) or to ``END``.

    Returns:
        The compiled graph; invoke it with ``{"messages": [...]}``.
    """
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.5-flash-preview-04-17",
        temperature=0,
    )
    print(f"DEBUG: llm object = {llm}")

    # bind tools to llm
    llm_with_tools = llm.bind_tools(tools)
    print(f"DEBUG: llm_with_tools object = {llm_with_tools}")

    class AgentState(TypedDict):
        # Conversation history; `add_messages` is the reducer used to merge
        # each node's returned messages into the state.
        messages: Annotated[list[AnyMessage], add_messages]

    def assistant(state: AgentState):
        """Call the tool-enabled LLM on the conversation so far."""
        history = state["messages"]
        # The system prompt was previously built but never sent to the model.
        # Prepend it per call (without storing it in state) unless the caller
        # already supplied a system message of their own.
        if not history or not isinstance(history[0], SystemMessage):
            history = [sys_msg] + list(history)
        result = llm_with_tools.invoke(history)
        print(f"DEBUG: LLM result = {result}")
        # Add usage information if it's not already present
        if isinstance(result, AIMessage) and result.usage_metadata is None:
            # Add dummy usage metadata if none exists
            result.usage_metadata = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
        # Always return the message wrapped in a list for the reducer.
        return {"messages": [result]}

    builder = StateGraph(AgentState)

    # define nodes
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    # define edges
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
        {
            # If the latest message requires a tool, route to tools
            "tools": "tools",
            # Otherwise, provide a direct response
            END: END,
        },
    )
    builder.add_edge("tools", "assistant")

    return builder.compile()


if __name__ == "__main__":
    # Test query for search_arxiv tool
    question = "latest research on quantum computing"
    graph = build_graph()
    messages = [HumanMessage(content=question)]
    print(f"Running graph with question: {question}")
    final_state = graph.invoke({"messages": messages})
    print("Graph execution finished. Messages:")
    for m in final_state["messages"]:
        m.pretty_print()