Spaces:
Runtime error
Runtime error
Update agent.py
Browse files
agent.py
CHANGED
@@ -509,25 +509,39 @@ def process_question(question):
|
|
509 |
|
510 |
|
511 |
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
519 |
filtered = [doc for doc, score in results if score < threshold]
|
520 |
|
521 |
-
# Provide a default message if no documents found
|
522 |
if not filtered:
|
523 |
-
|
524 |
else:
|
525 |
content = "\n\n".join(doc.page_content for doc in filtered)
|
526 |
-
|
527 |
-
|
528 |
-
|
529 |
|
530 |
-
return {"messages": [sys_msg] + state["messages"] + [example_msg]}
|
531 |
|
532 |
|
533 |
|
@@ -537,17 +551,29 @@ def retriever(state: MessagesState):
|
|
537 |
def get_llm(provider: str, config: dict):
|
538 |
if provider == "google":
|
539 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
540 |
-
return ChatGoogleGenerativeAI(
|
|
|
|
|
|
|
|
|
541 |
|
542 |
elif provider == "groq":
|
543 |
from langchain_groq import ChatGroq
|
544 |
-
return ChatGroq(
|
|
|
|
|
|
|
|
|
545 |
|
546 |
elif provider == "huggingface":
|
547 |
from langchain_huggingface import ChatHuggingFace
|
548 |
from langchain_huggingface import HuggingFaceEndpoint
|
549 |
return ChatHuggingFace(
|
550 |
-
llm=HuggingFaceEndpoint(
|
|
|
|
|
|
|
|
|
551 |
)
|
552 |
|
553 |
else:
|
@@ -607,7 +633,7 @@ def planner(question: str, tools: list) -> tuple:
|
|
607 |
return detected_intent, matched_tools if matched_tools else [tools[0]]
|
608 |
|
609 |
|
610 |
-
|
611 |
|
612 |
def task_classifier(question: str) -> str:
|
613 |
"""
|
@@ -889,7 +915,7 @@ def tool_dispatcher(state: AgentState) -> AgentState:
|
|
889 |
|
890 |
|
891 |
# Decide what to do next: if tool call → call_tool, else → end
|
892 |
-
def
|
893 |
last_msg = state["messages"][-1]
|
894 |
if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
|
895 |
return "call_tool"
|
@@ -904,9 +930,6 @@ class AgentState(TypedDict):
|
|
904 |
intent: str # derived or predicted intent
|
905 |
result: Optional[str] # tool output, if any
|
906 |
|
907 |
-
builder.add_node("call_tool", tool_dispatcher)
|
908 |
-
|
909 |
-
|
910 |
|
911 |
|
912 |
|
@@ -1036,14 +1059,7 @@ model_config = {
|
|
1036 |
"huggingfacehub_api_token": os.getenv("HF_TOKEN")
|
1037 |
}
|
1038 |
|
1039 |
-
|
1040 |
-
def build_graph(provider, model_config):
|
1041 |
-
from langchain_core.messages import SystemMessage, HumanMessage
|
1042 |
-
#from langgraph.graph import StateGraph, ToolNode
|
1043 |
-
from langchain_core.runnables import RunnableLambda
|
1044 |
-
from some_module import vector_store # Make sure this is defined/imported
|
1045 |
-
|
1046 |
-
# Step 1: Get LLM
|
1047 |
def get_llm(provider: str, config: dict):
|
1048 |
if provider == "huggingface":
|
1049 |
from langchain_huggingface import HuggingFaceEndpoint
|
@@ -1056,7 +1072,30 @@ def build_graph(provider, model_config):
|
|
1056 |
)
|
1057 |
else:
|
1058 |
raise ValueError(f"Unsupported provider: {provider}")
|
1059 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1060 |
llm = get_llm(provider, model_config)
|
1061 |
|
1062 |
# Step 2: Define tools
|
@@ -1071,80 +1110,52 @@ def build_graph(provider, model_config):
|
|
1071 |
wikidata_query
|
1072 |
]
|
1073 |
|
|
|
|
|
|
|
1074 |
# Step 3: Bind tools to LLM
|
1075 |
llm_with_tools = llm.bind_tools(tools)
|
1076 |
|
1077 |
# Step 4: Build stateful graph logic
|
1078 |
sys_msg = SystemMessage(content="You are a helpful assistant.")
|
1079 |
|
1080 |
-
|
1081 |
-
|
1082 |
-
|
1083 |
-
|
1084 |
-
|
1085 |
-
wiki_result = wiki_search.run(user_query)
|
1086 |
-
return {
|
1087 |
-
"messages": [
|
1088 |
-
sys_msg,
|
1089 |
-
state["messages"][0],
|
1090 |
-
HumanMessage(content=f"Using Wikipedia search:\n\n{wiki_result}")
|
1091 |
-
]
|
1092 |
-
}
|
1093 |
-
else:
|
1094 |
-
return {
|
1095 |
-
"messages": [
|
1096 |
-
sys_msg,
|
1097 |
-
state["messages"][0],
|
1098 |
-
HumanMessage(content=f"Reference:\n\n{similar_docs[0].page_content}")
|
1099 |
-
]
|
1100 |
-
}
|
1101 |
-
|
1102 |
-
def assistant(state: dict):
|
1103 |
-
return {"messages": [llm_with_tools.invoke(state["messages"])]}
|
1104 |
-
|
1105 |
-
def tools_condition(state: dict) -> str:
|
1106 |
-
if "use tool" in state["messages"][-1].content.lower():
|
1107 |
-
return "tools"
|
1108 |
-
else:
|
1109 |
-
return "END"
|
1110 |
-
|
1111 |
|
|
|
|
|
|
|
|
|
1112 |
|
1113 |
-
|
1114 |
|
1115 |
-
# Build
|
1116 |
-
builder = StateGraph(AgentState)
|
1117 |
|
1118 |
-
|
1119 |
-
builder.add_node("
|
1120 |
-
builder.add_node("
|
1121 |
-
builder.add_node("
|
1122 |
-
builder.add_node("call_tool", tool_dispatcher) # one name is enough
|
1123 |
|
1124 |
-
|
1125 |
-
builder.
|
1126 |
-
|
1127 |
-
|
1128 |
-
|
1129 |
-
|
1130 |
-
|
1131 |
-
|
1132 |
-
builder.add_conditional_edges("call_llm", should_call_tool, {
|
1133 |
-
"call_tool": "call_tool",
|
1134 |
-
"end": None
|
1135 |
-
})
|
1136 |
-
|
1137 |
-
# Loop back after tool execution
|
1138 |
-
builder.add_edge("call_tool", "call_llm")
|
1139 |
|
1140 |
-
|
1141 |
-
graph
|
1142 |
-
return graph
|
1143 |
|
1144 |
|
1145 |
|
1146 |
# call build_graph AFTER it’s defined
|
1147 |
-
agent =
|
1148 |
|
1149 |
# Now you can use the agent like this:
|
1150 |
result = agent.invoke({"messages": [HumanMessage(content=question)]})
|
|
|
509 |
|
510 |
|
511 |
|
512 |
+
from langchain.schema import HumanMessage
|
513 |
+
|
514 |
+
def retriever(state: MessagesState, k: int = 4):
|
515 |
+
"""
|
516 |
+
Retrieves documents from the vector store using similarity scores,
|
517 |
+
applies a dynamic threshold filter, and returns updated message state.
|
518 |
+
|
519 |
+
Args:
|
520 |
+
state (MessagesState): Current message state including the user's query.
|
521 |
+
k (int): Number of top results to retrieve from the vector store.
|
522 |
+
|
523 |
+
Returns:
|
524 |
+
dict: Updated messages state including relevant documents or fallback message.
|
525 |
+
"""
|
526 |
+
query = state["messages"][0].content.strip()
|
527 |
+
results = vector_store.similarity_search_with_score(query, k=k)
|
528 |
+
|
529 |
+
# Determine dynamic similarity threshold
|
530 |
+
if any(keyword in query.lower() for keyword in ["who", "what", "where", "when", "why", "how"]):
|
531 |
+
threshold = 0.75
|
532 |
+
else:
|
533 |
+
threshold = 0.8
|
534 |
+
|
535 |
filtered = [doc for doc, score in results if score < threshold]
|
536 |
|
|
|
537 |
if not filtered:
|
538 |
+
response_msg = HumanMessage(content="No relevant documents found.")
|
539 |
else:
|
540 |
content = "\n\n".join(doc.page_content for doc in filtered)
|
541 |
+
response_msg = HumanMessage(content=f"Here are relevant reference documents:\n\n{content}")
|
542 |
+
|
543 |
+
return {"messages": [sys_msg] + state["messages"] + [response_msg]}
|
544 |
|
|
|
545 |
|
546 |
|
547 |
|
|
|
551 |
def get_llm(provider: str, config: dict):
|
552 |
if provider == "google":
|
553 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
554 |
+
return ChatGoogleGenerativeAI(
|
555 |
+
model=config.get("model"),
|
556 |
+
temperature=config.get("temperature", 0.7),
|
557 |
+
google_api_key=config.get("api_key") # Optional: if needed
|
558 |
+
)
|
559 |
|
560 |
elif provider == "groq":
|
561 |
from langchain_groq import ChatGroq
|
562 |
+
return ChatGroq(
|
563 |
+
model=config.get("model"),
|
564 |
+
temperature=config.get("temperature", 0.7),
|
565 |
+
groq_api_key=config.get("api_key") # Optional: if needed
|
566 |
+
)
|
567 |
|
568 |
elif provider == "huggingface":
|
569 |
from langchain_huggingface import ChatHuggingFace
|
570 |
from langchain_huggingface import HuggingFaceEndpoint
|
571 |
return ChatHuggingFace(
|
572 |
+
llm=HuggingFaceEndpoint(
|
573 |
+
endpoint_url=config.get("url"),
|
574 |
+
temperature=config.get("temperature", 0.7),
|
575 |
+
huggingfacehub_api_token=config.get("api_key") # Optional
|
576 |
+
)
|
577 |
)
|
578 |
|
579 |
else:
|
|
|
633 |
return detected_intent, matched_tools if matched_tools else [tools[0]]
|
634 |
|
635 |
|
636 |
+
|
637 |
|
638 |
def task_classifier(question: str) -> str:
|
639 |
"""
|
|
|
915 |
|
916 |
|
917 |
# Decide what to do next: if tool call → call_tool, else → end
|
918 |
+
def should_call_tool(state):
|
919 |
last_msg = state["messages"][-1]
|
920 |
if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
|
921 |
return "call_tool"
|
|
|
930 |
intent: str # derived or predicted intent
|
931 |
result: Optional[str] # tool output, if any
|
932 |
|
|
|
|
|
|
|
933 |
|
934 |
|
935 |
|
|
|
1059 |
"huggingfacehub_api_token": os.getenv("HF_TOKEN")
|
1060 |
}
|
1061 |
|
1062 |
+
# Get LLM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1063 |
def get_llm(provider: str, config: dict):
|
1064 |
if provider == "huggingface":
|
1065 |
from langchain_huggingface import HuggingFaceEndpoint
|
|
|
1072 |
)
|
1073 |
else:
|
1074 |
raise ValueError(f"Unsupported provider: {provider}")
|
1075 |
+
|
1076 |
+
|
1077 |
+
def assistant(state: dict):
|
1078 |
+
return {"messages": [llm_with_tools.invoke(state["messages"])]}
|
1079 |
+
|
1080 |
+
def tools_condition(state: dict) -> str:
|
1081 |
+
if "use tool" in state["messages"][-1].content.lower():
|
1082 |
+
return "tools"
|
1083 |
+
else:
|
1084 |
+
return "END"
|
1085 |
+
|
1086 |
+
|
1087 |
+
from langgraph.graph import StateGraph
|
1088 |
+
|
1089 |
+
# Build graph using AgentState as the shared schema
|
1090 |
+
def build_graph() -> StateGraph:
|
1091 |
+
builder = StateGraph(AgentState)
|
1092 |
+
|
1093 |
+
from langchain_core.messages import SystemMessage, HumanMessage
|
1094 |
+
#from langgraph.graph import StateGraph, ToolNode
|
1095 |
+
from langchain_core.runnables import RunnableLambda
|
1096 |
+
from some_module import vector_store # Make sure this is defined/imported
|
1097 |
+
|
1098 |
+
|
1099 |
llm = get_llm(provider, model_config)
|
1100 |
|
1101 |
# Step 2: Define tools
|
|
|
1110 |
wikidata_query
|
1111 |
]
|
1112 |
|
1113 |
+
global tool_map
|
1114 |
+
tool_map = {t.name: t for t in tools}
|
1115 |
+
|
1116 |
# Step 3: Bind tools to LLM
|
1117 |
llm_with_tools = llm.bind_tools(tools)
|
1118 |
|
1119 |
# Step 4: Build stateful graph logic
|
1120 |
sys_msg = SystemMessage(content="You are a helpful assistant.")
|
1121 |
|
1122 |
+
# --- Define nodes ---
|
1123 |
+
retriever = RunnableLambda(lambda state: {
|
1124 |
+
**state,
|
1125 |
+
"retrieved_docs": vector_store.similarity_search(state["input"])
|
1126 |
+
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1127 |
|
1128 |
+
assistant = RunnableLambda(lambda state: {
|
1129 |
+
**state,
|
1130 |
+
"messages": state["messages"] + [SystemMessage(content="You are a helpful assistant.")]
|
1131 |
+
})
|
1132 |
|
1133 |
+
call_llm = llm.bind_tools(tools)
|
1134 |
|
1135 |
+
# --- Build the graph ---
|
1136 |
+
builder = StateGraph(AgentState)
|
1137 |
|
1138 |
+
builder.add_node("retriever", retriever)
|
1139 |
+
builder.add_node("assistant", assistant)
|
1140 |
+
builder.add_node("call_llm", call_llm)
|
1141 |
+
builder.add_node("call_tool", tool_dispatcher)
|
|
|
1142 |
|
1143 |
+
builder.set_entry_point("retriever")
|
1144 |
+
builder.add_edge("retriever", "assistant")
|
1145 |
+
builder.add_edge("assistant", "call_llm")
|
1146 |
+
builder.add_conditional_edges("call_llm", should_call_tool, {
|
1147 |
+
"call_tool": "call_tool",
|
1148 |
+
"end": None
|
1149 |
+
})
|
1150 |
+
builder.add_edge("call_tool", "call_llm")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1151 |
|
1152 |
+
graph = builder.compile()
|
1153 |
+
return graph
|
|
|
1154 |
|
1155 |
|
1156 |
|
1157 |
# call build_graph AFTER it’s defined
|
1158 |
+
agent = graph
|
1159 |
|
1160 |
# Now you can use the agent like this:
|
1161 |
result = agent.invoke({"messages": [HumanMessage(content=question)]})
|