wt002 commited on
Commit
23ba2f5
·
verified ·
1 Parent(s): 4aededf

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +81 -32
agent.py CHANGED
@@ -4,7 +4,7 @@ import os
4
  from dotenv import load_dotenv
5
  from langgraph.graph import START, StateGraph, MessagesState
6
  from langgraph.prebuilt import tools_condition
7
- from langgraph.prebuilt import ToolNode
8
  from langchain_google_genai import ChatGoogleGenerativeAI
9
  from langchain_groq import ChatGroq
10
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
@@ -839,44 +839,74 @@ def process_question(question: str):
839
 
840
 
841
 
 
842
  def call_llm(state):
843
  messages = state["messages"]
844
  response = llm.invoke(messages)
845
  return {"messages": messages + [response]}
846
 
847
- tool_node = ToolNode(tools)
 
 
 
 
848
 
 
 
849
 
850
- from langgraph.graph import StateGraph
851
- from typing import TypedDict
852
 
853
- # Define the state schema
854
- class AgentState(TypedDict):
855
- input: str
856
- result: str
857
 
858
- # Create the state graph builder with state_schema
859
- builder = StateGraph(AgentState)
860
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
861
 
862
- builder.add_node("call_llm", call_llm)
863
- builder.add_node("call_tool", tool_node)
864
 
865
  # Decide what to do next: if tool call → call_tool, else → end
866
- def should_call_tool(state):
867
  last_msg = state["messages"][-1]
868
  if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
869
  return "call_tool"
870
  return "end"
871
 
872
- builder.set_entry_point("call_llm")
873
- builder.add_conditional_edges("call_llm", should_call_tool, {
874
- "call_tool": "call_tool",
875
- "end": None
876
- })
 
 
 
 
 
 
877
 
878
- # After tool runs, go back to the LLM
879
- builder.add_edge("call_tool", "call_llm")
880
 
881
 
882
 
@@ -1009,7 +1039,7 @@ model_config = {
1009
 
1010
  def build_graph(provider, model_config):
1011
  from langchain_core.messages import SystemMessage, HumanMessage
1012
- from langgraph.graph import StateGraph, ToolNode
1013
  from langchain_core.runnables import RunnableLambda
1014
  from some_module import vector_store # Make sure this is defined/imported
1015
 
@@ -1078,20 +1108,39 @@ def build_graph(provider, model_config):
1078
  else:
1079
  return "END"
1080
 
1081
- # Step 5: Define LangGraph StateGraph
1082
- builder = StateGraph(dict) # Using dict as state type here
1083
 
1084
- builder.add_node("retriever", retriever)
1085
- builder.add_node("assistant", assistant)
1086
- builder.add_node("tools", ToolNode(tools))
1087
 
1088
- builder.set_entry_point("retriever")
1089
- builder.add_edge("retriever", "assistant")
1090
- builder.add_conditional_edges("assistant", tools_condition)
1091
- builder.add_edge("tools", "assistant")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1092
 
1093
- graph = builder.compile()
1094
- return graph
1095
 
1096
 
1097
  # call build_graph AFTER it’s defined
 
4
  from dotenv import load_dotenv
5
  from langgraph.graph import START, StateGraph, MessagesState
6
  from langgraph.prebuilt import tools_condition
7
+ #from langgraph.prebuilt import ToolNode
8
  from langchain_google_genai import ChatGoogleGenerativeAI
9
  from langchain_groq import ChatGroq
10
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
 
839
 
840
 
841
 
842
+
843
  def call_llm(state):
844
  messages = state["messages"]
845
  response = llm.invoke(messages)
846
  return {"messages": messages + [response]}
847
 
848
+ builder.set_entry_point("call_llm")
849
+ builder.add_conditional_edges("call_llm", should_call_tool, {
850
+ "call_tool": "call_tool",
851
+ "end": None
852
+ })
853
 
854
+ # After tool runs, go back to the LLM
855
+ builder.add_edge("call_tool", "call_llm")
856
 
 
 
857
 
 
 
 
 
858
 
859
+ from langchain.schema import AIMessage
 
860
 
861
+ def tool_dispatcher(state: AgentState) -> AgentState:
862
+ last_msg = state["messages"][-1]
863
+
864
+ # Make sure it's an AI message with tool_calls
865
+ if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
866
+ tool_call = last_msg.tool_calls[0]
867
+ tool_name = tool_call["name"]
868
+ tool_input = tool_call["args"] # Adjust based on your actual schema
869
+
870
+ tool_func = tool_map.get(tool_name, default_tool)
871
+
872
+ # If args is a dict and your tool expects unpacked values:
873
+ if isinstance(tool_input, dict):
874
+ result = tool_func.invoke(tool_input) if hasattr(tool_func, "invoke") else tool_func(**tool_input)
875
+ else:
876
+ result = tool_func.invoke(tool_input) if hasattr(tool_func, "invoke") else tool_func(tool_input)
877
+
878
+ # You can choose to append this to messages, or just save result
879
+ return {
880
+ **state,
881
+ "result": result,
882
+ # Optionally add: "messages": state["messages"] + [ToolMessage(...)]
883
+ }
884
+
885
+ # No tool call detected, return state unchanged
886
+ return state
887
+
888
+
889
 
 
 
890
 
891
  # Decide what to do next: if tool call → call_tool, else → end
892
+ def call_tool(state):
893
  last_msg = state["messages"][-1]
894
  if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
895
  return "call_tool"
896
  return "end"
897
 
898
+ from typing import TypedDict, List, Optional, Union
899
+ from langchain.schema import BaseMessage
900
+
901
+ class AgentState(TypedDict):
902
+ messages: List[BaseMessage] # chat history
903
+ input: str # original input
904
+ intent: str # derived or predicted intent
905
+ result: Optional[str] # tool output, if any
906
+
907
+ builder.add_node("call_tool", tool_dispatcher)
908
+
909
 
 
 
910
 
911
 
912
 
 
1039
 
1040
  def build_graph(provider, model_config):
1041
  from langchain_core.messages import SystemMessage, HumanMessage
1042
+ #from langgraph.graph import StateGraph, ToolNode
1043
  from langchain_core.runnables import RunnableLambda
1044
  from some_module import vector_store # Make sure this is defined/imported
1045
 
 
1108
  else:
1109
  return "END"
1110
 
 
 
1111
 
 
 
 
1112
 
1113
+ from langgraph.graph import StateGraph
1114
+
1115
+ # Build graph using AgentState as the shared schema
1116
+ builder = StateGraph(AgentState)
1117
+
1118
+ # Add nodes
1119
+ builder.add_node("retriever", retriever)
1120
+ builder.add_node("assistant", assistant)
1121
+ builder.add_node("call_llm", call_llm)
1122
+ builder.add_node("call_tool", tool_dispatcher) # one name is enough
1123
+
1124
+ # Entry point
1125
+ builder.set_entry_point("retriever")
1126
+
1127
+ # Define the flow
1128
+ builder.add_edge("retriever", "assistant")
1129
+ builder.add_edge("assistant", "call_llm")
1130
+
1131
+ # Conditional edge from LLM to tool or end
1132
+ builder.add_conditional_edges("call_llm", should_call_tool, {
1133
+ "call_tool": "call_tool",
1134
+ "end": None
1135
+ })
1136
+
1137
+ # Loop back after tool execution
1138
+ builder.add_edge("call_tool", "call_llm")
1139
+
1140
+ # Compile
1141
+ graph = builder.compile()
1142
+ return graph
1143
 
 
 
1144
 
1145
 
1146
  # call build_graph AFTER it’s defined