Update agent.py

agent.py CHANGED
@@ -4,7 +4,7 @@ import os
 from dotenv import load_dotenv
 from langgraph.graph import START, StateGraph, MessagesState
 from langgraph.prebuilt import tools_condition
-from langgraph.prebuilt import ToolNode
+#from langgraph.prebuilt import ToolNode
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_groq import ChatGroq
 from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
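
For contrast, a minimal sketch of the prebuilt ToolNode wiring that this commit disables; it assumes `llm` (a chat model) and a `tools` list exist elsewhere in agent.py, and it is not the author's code:

from langgraph.graph import StateGraph, MessagesState, START
from langgraph.prebuilt import ToolNode, tools_condition

# Assumed to exist elsewhere in agent.py: `llm` and `tools`
llm_with_tools = llm.bind_tools(tools)

def call_llm(state: MessagesState):
    # MessagesState appends returned messages via its built-in reducer
    return {"messages": [llm_with_tools.invoke(state["messages"])]}

builder = StateGraph(MessagesState)
builder.add_node("call_llm", call_llm)
builder.add_node("tools", ToolNode(tools))   # executes whatever tool_calls the model emitted
builder.add_edge(START, "call_llm")
builder.add_conditional_edges("call_llm", tools_condition)  # routes to "tools" or END
builder.add_edge("tools", "call_llm")
graph = builder.compile()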

@@ -839,44 +839,74 @@ def process_question(question: str):
 
 
 
 def call_llm(state):
     messages = state["messages"]
     response = llm.invoke(messages)
     return {"messages": messages + [response]}
 
-from langgraph.graph import StateGraph
-from typing import TypedDict
-
-# Define the state schema
-class AgentState(TypedDict):
-    input: str
-    result: str
-
-builder = StateGraph(AgentState)
-
-builder.add_node("call_llm", call_llm)
-builder.add_node("call_tool", tool_node)
+from typing import TypedDict, List, Optional
+from langchain.schema import AIMessage, BaseMessage
+from langgraph.graph import END
+
+# Define the state schema shared by every node
+class AgentState(TypedDict):
+    messages: List[BaseMessage]   # chat history
+    input: str                    # original input
+    intent: str                   # derived or predicted intent
+    result: Optional[str]         # tool output, if any
+
+def tool_dispatcher(state: AgentState) -> AgentState:
+    last_msg = state["messages"][-1]
+
+    # Make sure it's an AI message with tool_calls
+    if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
+        tool_call = last_msg.tool_calls[0]
+        tool_name = tool_call["name"]
+        tool_input = tool_call["args"]  # Adjust based on your actual schema
+
+        tool_func = tool_map.get(tool_name, default_tool)
+
+        # If args is a dict and the tool expects unpacked values:
+        if isinstance(tool_input, dict):
+            result = tool_func.invoke(tool_input) if hasattr(tool_func, "invoke") else tool_func(**tool_input)
+        else:
+            result = tool_func.invoke(tool_input) if hasattr(tool_func, "invoke") else tool_func(tool_input)
+
+        # Save the result; optionally also append a ToolMessage to messages
+        return {
+            **state,
+            "result": result,
+            # Optionally add: "messages": state["messages"] + [ToolMessage(...)]
+        }
+
+    # No tool call detected, return state unchanged
+    return state
 
 # Decide what to do next: if tool call → call_tool, else → end
-def
+def should_call_tool(state):
     last_msg = state["messages"][-1]
     if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
         return "call_tool"
     return "end"
 
+builder = StateGraph(AgentState)
+builder.add_node("call_llm", call_llm)
+builder.add_node("call_tool", tool_dispatcher)
+
+builder.set_entry_point("call_llm")
+builder.add_conditional_edges("call_llm", should_call_tool, {
+    "call_tool": "call_tool",
+    "end": END,
+})
+
 # After tool runs, go back to the LLM
 builder.add_edge("call_tool", "call_llm")
 
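
tool_dispatcher relies on a `tool_map` registry and a `default_tool` fallback that this hunk never defines; presumably they live elsewhere in agent.py. A minimal sketch of the shape the dispatcher expects, with two hypothetical tools:

from langchain_core.tools import tool

@tool
def web_search(query: str) -> str:
    """Hypothetical search tool; stand-in for the real one."""
    return f"search results for: {query}"

@tool
def add_numbers(a: int, b: int) -> int:
    """Hypothetical arithmetic tool."""
    return a + b

@tool
def default_tool(query: str) -> str:
    """Fallback when the model names a tool that is not registered."""
    return f"no tool registered for: {query}"

# Keyed by name so tool_map.get(tool_call["name"], default_tool) resolves;
# @tool objects expose .invoke(), which the dispatcher prefers.
tool_map = {t.name: t for t in (web_search, add_numbers)}

If the tool output should also flow back to the model rather than only into state["result"], the commented-out ToolMessage hint could be realized with a small helper; tool_call_id ties the reply to the model's original request:

from langchain_core.messages import ToolMessage

def as_tool_message(tool_call: dict, result) -> ToolMessage:
    return ToolMessage(content=str(result), tool_call_id=tool_call["id"])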
@@ -1009,7 +1039,7 @@ model_config = {
 
 def build_graph(provider, model_config):
     from langchain_core.messages import SystemMessage, HumanMessage
-    from langgraph.graph import StateGraph, ToolNode
+    #from langgraph.graph import StateGraph, ToolNode
     from langchain_core.runnables import RunnableLambda
     from some_module import vector_store  # Make sure this is defined/imported
 
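
The commented-out line was a likely build breaker: ToolNode lives in langgraph.prebuilt, not langgraph.graph. If the prebuilt node is wanted inside build_graph later, the working import is:

from langgraph.prebuilt import ToolNode  # not langgraph.graph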

@@ -1078,20 +1108,39 @@ def build_graph(provider, model_config):
     else:
         return "END"
 
-    # Step 5: Define LangGraph StateGraph
-    builder = StateGraph(dict)  # Using dict as state type here
-
-    builder.add_node("retriever", retriever)
-    builder.add_node("assistant", assistant)
-    builder.add_node("tools", ToolNode(tools))
-
-    graph = builder.compile()
-    return graph
+    from langgraph.graph import StateGraph, END
+
+    # Build graph using AgentState as the shared schema
+    builder = StateGraph(AgentState)
+
+    # Add nodes
+    builder.add_node("retriever", retriever)
+    builder.add_node("assistant", assistant)
+    builder.add_node("call_llm", call_llm)
+    builder.add_node("call_tool", tool_dispatcher)  # one name is enough
+
+    # Entry point
+    builder.set_entry_point("retriever")
+
+    # Define the flow
+    builder.add_edge("retriever", "assistant")
+    builder.add_edge("assistant", "call_llm")
+
+    # Conditional edge from LLM to tool or end
+    builder.add_conditional_edges("call_llm", should_call_tool, {
+        "call_tool": "call_tool",
+        "end": END,
+    })
+
+    # Loop back after tool execution
+    builder.add_edge("call_tool", "call_llm")
+
+    # Compile
+    graph = builder.compile()
+    return graph
 
 
 # call build_graph AFTER it’s defined
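
With the graph now assembled inside build_graph, the call that the closing comment anticipates might look like the sketch below; "groq" is a placeholder provider, model_config is the dict defined earlier in agent.py, and the state keys mirror AgentState. Note that set_entry_point("retriever") is equivalent to builder.add_edge(START, "retriever") with the START sentinel imported at the top of the file.

from langchain_core.messages import HumanMessage

graph = build_graph("groq", model_config)

initial_state = {
    "messages": [HumanMessage(content="What is 6 * 7?")],
    "input": "What is 6 * 7?",
    "intent": "",
    "result": None,
}

final_state = graph.invoke(initial_state)
print(final_state["messages"][-1].content)   # last AI message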