wt002 commited on
Commit
a1d5445
Β·
verified Β·
1 Parent(s): 1d71e5f

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +32 -1
agent.py CHANGED
@@ -4,7 +4,7 @@ import os
4
  from dotenv import load_dotenv
5
  from langgraph.graph import START, StateGraph, MessagesState
6
  from langgraph.prebuilt import tools_condition
7
- #from langgraph.prebuilt import ToolNode
8
  from langchain_google_genai import ChatGoogleGenerativeAI
9
  from langchain_groq import ChatGroq
10
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
@@ -839,6 +839,37 @@ def process_question(question: str):
839
 
840
 
841
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
842
 
843
  # To store previously asked questions and timestamps (simulating state persistence)
844
  recent_questions = {}
 
4
  from dotenv import load_dotenv
5
  from langgraph.graph import START, StateGraph, MessagesState
6
  from langgraph.prebuilt import tools_condition
7
+ from langgraph.prebuilt import ToolNode
8
  from langchain_google_genai import ChatGoogleGenerativeAI
9
  from langchain_groq import ChatGroq
10
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
 
839
 
840
 
841
 
842
+ def call_llm(state):
843
+ messages = state["messages"]
844
+ response = llm.invoke(messages)
845
+ return {"messages": messages + [response]}
846
+
847
+ tool_node = ToolNode([search])
848
+
849
+
850
+ builder = StateGraph()
851
+
852
+ builder.add_node("call_llm", call_llm)
853
+ builder.add_node("call_tool", tool_node)
854
+
855
+ # Decide what to do next: if tool call β†’ call_tool, else β†’ end
856
+ def should_call_tool(state):
857
+ last_msg = state["messages"][-1]
858
+ if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
859
+ return "call_tool"
860
+ return "end"
861
+
862
+ builder.set_entry_point("call_llm")
863
+ builder.add_conditional_edges("call_llm", should_call_tool, {
864
+ "call_tool": "call_tool",
865
+ "end": None
866
+ })
867
+
868
+ # After tool runs, go back to the LLM
869
+ builder.add_edge("call_tool", "call_llm")
870
+
871
+
872
+
873
 
874
  # To store previously asked questions and timestamps (simulating state persistence)
875
  recent_questions = {}