Update agent.py
agent.py
CHANGED
@@ -509,25 +509,39 @@ def process_question(question):



-
-
-
-
-
-
-
+from langchain.schema import HumanMessage
+
+def retriever(state: MessagesState, k: int = 4):
+    """
+    Retrieves documents from the vector store using similarity scores,
+    applies a dynamic threshold filter, and returns updated message state.
+
+    Args:
+        state (MessagesState): Current message state including the user's query.
+        k (int): Number of top results to retrieve from the vector store.
+
+    Returns:
+        dict: Updated messages state including relevant documents or fallback message.
+    """
+    query = state["messages"][0].content.strip()
+    results = vector_store.similarity_search_with_score(query, k=k)
+
+    # Determine dynamic similarity threshold
+    if any(keyword in query.lower() for keyword in ["who", "what", "where", "when", "why", "how"]):
+        threshold = 0.75
+    else:
+        threshold = 0.8
+
     filtered = [doc for doc, score in results if score < threshold]

-    # Provide a default message if no documents found
     if not filtered:
-
+        response_msg = HumanMessage(content="No relevant documents found.")
     else:
         content = "\n\n".join(doc.page_content for doc in filtered)
-
-
-
+        response_msg = HumanMessage(content=f"Here are relevant reference documents:\n\n{content}")
+
+    return {"messages": [sys_msg] + state["messages"] + [response_msg]}

-    return {"messages": [sys_msg] + state["messages"] + [example_msg]}


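Note on the threshold filter above: with most LangChain vector stores, similarity_search_with_score returns a distance-style score where lower means more similar (FAISS L2, for example), which is why the filter keeps score < threshold; a store that returns cosine similarity (higher is better) would need the comparison flipped. A minimal, self-contained sketch of the filter logic with hypothetical (doc, score) pairs:

    # Hypothetical (doc, score) pairs; assumes distance-style scores (lower = more similar).
    results = [
        ("doc about Ada Lovelace", 0.42),
        ("loosely related doc", 0.71),
        ("off-topic doc", 0.93),
    ]
    threshold = 0.75  # the wh-question threshold from the diff
    filtered = [doc for doc, score in results if score < threshold]
    print(filtered)  # ['doc about Ada Lovelace', 'loosely related doc']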
@@ -537,17 +551,29 @@ def retriever(state: MessagesState):
 def get_llm(provider: str, config: dict):
     if provider == "google":
         from langchain_google_genai import ChatGoogleGenerativeAI
-        return ChatGoogleGenerativeAI(
+        return ChatGoogleGenerativeAI(
+            model=config.get("model"),
+            temperature=config.get("temperature", 0.7),
+            google_api_key=config.get("api_key")  # Optional: if needed
+        )

     elif provider == "groq":
         from langchain_groq import ChatGroq
-        return ChatGroq(
+        return ChatGroq(
+            model=config.get("model"),
+            temperature=config.get("temperature", 0.7),
+            groq_api_key=config.get("api_key")  # Optional: if needed
+        )

     elif provider == "huggingface":
         from langchain_huggingface import ChatHuggingFace
         from langchain_huggingface import HuggingFaceEndpoint
         return ChatHuggingFace(
-            llm=HuggingFaceEndpoint(
+            llm=HuggingFaceEndpoint(
+                endpoint_url=config.get("url"),
+                temperature=config.get("temperature", 0.7),
+                huggingfacehub_api_token=config.get("api_key")  # Optional
+            )
         )

     else:
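Usage sketch for the get_llm factory: the config keys (model, temperature, api_key, url) follow the branches above, but the model id below is a placeholder, not something this diff specifies:

    import os

    # Hypothetical call; assumes langchain_groq is installed and GROQ_API_KEY is set.
    groq_config = {
        "model": "llama-3.1-8b-instant",  # placeholder model id
        "temperature": 0.3,
        "api_key": os.getenv("GROQ_API_KEY"),
    }
    llm = get_llm("groq", groq_config)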
@@ -607,7 +633,7 @@ def planner(question: str, tools: list) -> tuple:
     return detected_intent, matched_tools if matched_tools else [tools[0]]


-
+

 def task_classifier(question: str) -> str:
     """
@@ -889,7 +915,7 @@ def tool_dispatcher(state: AgentState) -> AgentState:


 # Decide what to do next: if tool call → call_tool, else → end
-def
+def should_call_tool(state):
     last_msg = state["messages"][-1]
     if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
         return "call_tool"
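should_call_tool routes on structured tool calls: the visible branch returns "call_tool" when the last message is an AIMessage carrying tool_calls. The hunk cuts off before the fallback, which presumably returns "end" to match the conditional-edge mapping later in this diff. A sketch under that assumption:

    from langchain_core.messages import AIMessage

    # Assumes the unshown fallback branch returns "end".
    tool_call = {"name": "wiki_search", "args": {"query": "Ada Lovelace"}, "id": "call_1"}
    state = {"messages": [AIMessage(content="", tool_calls=[tool_call])]}
    print(should_call_tool(state))  # -> "call_tool"
    print(should_call_tool({"messages": [AIMessage(content="done")]}))  # -> "end" (assumed)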
@@ -904,9 +930,6 @@ class AgentState(TypedDict):
     intent: str # derived or predicted intent
     result: Optional[str] # tool output, if any

-    builder.add_node("call_tool", tool_dispatcher)
-
-



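Since AgentState is the graph's shared schema, a LangGraph node only needs to return the keys it changes and the runtime writes them back into the state. A hypothetical node (not part of this commit) showing the pattern:

    # Hypothetical node: returns only the keys it updates; LangGraph merges them in.
    def record_result(state: AgentState) -> dict:
        return {"result": f"handled intent: {state['intent']}"}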
@@ -1036,14 +1059,7 @@ model_config = {
     "huggingfacehub_api_token": os.getenv("HF_TOKEN")
 }

-
-def build_graph(provider, model_config):
-    from langchain_core.messages import SystemMessage, HumanMessage
-    #from langgraph.graph import StateGraph, ToolNode
-    from langchain_core.runnables import RunnableLambda
-    from some_module import vector_store  # Make sure this is defined/imported
-
-    # Step 1: Get LLM
+# Get LLM
 def get_llm(provider: str, config: dict):
     if provider == "huggingface":
         from langchain_huggingface import HuggingFaceEndpoint
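Worth double-checking here: model_config exposes the token as huggingfacehub_api_token, while the earlier get_llm reads config.get("api_key") and config.get("url"); if this config is fed to that factory, the token key would be ignored and api_key would come back None. A hedged sketch of a config that satisfies both spellings (the endpoint URL is a placeholder):

    import os

    # Hypothetical config covering both key spellings used in this file.
    model_config = {
        "url": "https://api-inference.huggingface.co/models/<model-id>",  # placeholder
        "temperature": 0.7,
        "api_key": os.getenv("HF_TOKEN"),
        "huggingfacehub_api_token": os.getenv("HF_TOKEN"),
    }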
@@ -1056,7 +1072,30 @@ def build_graph(provider, model_config):
         )
     else:
         raise ValueError(f"Unsupported provider: {provider}")
-
+
+
+def assistant(state: dict):
+    return {"messages": [llm_with_tools.invoke(state["messages"])]}
+
+def tools_condition(state: dict) -> str:
+    if "use tool" in state["messages"][-1].content.lower():
+        return "tools"
+    else:
+        return "END"
+
+
+from langgraph.graph import StateGraph
+
+# Build graph using AgentState as the shared schema
+def build_graph() -> StateGraph:
+    builder = StateGraph(AgentState)
+
+    from langchain_core.messages import SystemMessage, HumanMessage
+    #from langgraph.graph import StateGraph, ToolNode
+    from langchain_core.runnables import RunnableLambda
+    from some_module import vector_store  # Make sure this is defined/imported
+
+
     llm = get_llm(provider, model_config)

     # Step 2: Define tools
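Note that tools_condition routes on the literal substring "use tool" in the last message, which is brittle next to the tool_calls check in should_call_tool earlier in this diff, and it returns the plain string "END" rather than langgraph.graph.END, so any wiring must map that string explicitly. Quick sketch of the routing behavior:

    from langchain_core.messages import AIMessage

    # Substring match on the message text decides the route.
    print(tools_condition({"messages": [AIMessage(content="I will use tool: wiki_search")]}))  # -> "tools"
    print(tools_condition({"messages": [AIMessage(content="Paris is the capital of France.")]}))  # -> "END"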
@@ -1071,80 +1110,52 @@ def build_graph(provider, model_config):
         wikidata_query
     ]

+    global tool_map
+    tool_map = {t.name: t for t in tools}
+
     # Step 3: Bind tools to LLM
     llm_with_tools = llm.bind_tools(tools)

     # Step 4: Build stateful graph logic
     sys_msg = SystemMessage(content="You are a helpful assistant.")

-
-
-
-
-
-        wiki_result = wiki_search.run(user_query)
-        return {
-            "messages": [
-                sys_msg,
-                state["messages"][0],
-                HumanMessage(content=f"Using Wikipedia search:\n\n{wiki_result}")
-            ]
-        }
-    else:
-        return {
-            "messages": [
-                sys_msg,
-                state["messages"][0],
-                HumanMessage(content=f"Reference:\n\n{similar_docs[0].page_content}")
-            ]
-        }
-
-    def assistant(state: dict):
-        return {"messages": [llm_with_tools.invoke(state["messages"])]}
-
-    def tools_condition(state: dict) -> str:
-        if "use tool" in state["messages"][-1].content.lower():
-            return "tools"
-        else:
-            return "END"
-
+    # --- Define nodes ---
+    retriever = RunnableLambda(lambda state: {
+        **state,
+        "retrieved_docs": vector_store.similarity_search(state["input"])
+    })

+    assistant = RunnableLambda(lambda state: {
+        **state,
+        "messages": state["messages"] + [SystemMessage(content="You are a helpful assistant.")]
+    })

-
+    call_llm = llm.bind_tools(tools)

-    # Build
-    builder = StateGraph(AgentState)
+    # --- Build the graph ---
+    builder = StateGraph(AgentState)

-
-    builder.add_node("
-    builder.add_node("
-    builder.add_node("
-    builder.add_node("call_tool", tool_dispatcher)  # one name is enough
+    builder.add_node("retriever", retriever)
+    builder.add_node("assistant", assistant)
+    builder.add_node("call_llm", call_llm)
+    builder.add_node("call_tool", tool_dispatcher)

-
-    builder.
-
-
-
-
-
-
-    builder.add_conditional_edges("call_llm", should_call_tool, {
-        "call_tool": "call_tool",
-        "end": None
-    })
-
-    # Loop back after tool execution
-    builder.add_edge("call_tool", "call_llm")
+    builder.set_entry_point("retriever")
+    builder.add_edge("retriever", "assistant")
+    builder.add_edge("assistant", "call_llm")
+    builder.add_conditional_edges("call_llm", should_call_tool, {
+        "call_tool": "call_tool",
+        "end": None
+    })
+    builder.add_edge("call_tool", "call_llm")

-
-    graph
-    return graph
+    graph = builder.compile()
+    return graph



 # call build_graph AFTER it’s defined
-agent =
+agent = graph

 # Now you can use the agent like this:
 result = agent.invoke({"messages": [HumanMessage(content=question)]})
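End-to-end, two details matter for the caller: the retriever node reads state["input"] while the final example only passes messages, and agent = graph only works if the name graph leaks out of build_graph (agent = build_graph() looks like the intent). A hedged sketch under those assumptions:

    from langchain_core.messages import HumanMessage

    # Assumes build_graph() and its module-level dependencies are set up as in this diff.
    agent = build_graph()  # assumed intent; the diff itself assigns `agent = graph`
    question = "Who wrote The Hobbit?"
    result = agent.invoke({
        "messages": [HumanMessage(content=question)],
        "input": question,  # the retriever node reads state["input"]
    })
    print(result["messages"][-1].content)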