wt002 committed
Commit 7622d0c · verified · 1 Parent(s): 23ba2f5

Update agent.py

Files changed (1)
  1. agent.py +99 -88
agent.py CHANGED
@@ -509,25 +509,39 @@ def process_question(question):
 
 
 
-def retriever(state: MessagesState):
-    """Retriever node using similarity scores for filtering"""
-    query = state["messages"][0].content
-    results = vector_store.similarity_search_with_score(query, k=4)  # top 4 matches
-
-    # Dynamically adjust threshold based on query complexity
-    threshold = 0.75 if "who" in query else 0.8
+from langchain.schema import HumanMessage
+
+def retriever(state: MessagesState, k: int = 4):
+    """
+    Retrieves documents from the vector store using similarity scores,
+    applies a dynamic threshold filter, and returns updated message state.
+
+    Args:
+        state (MessagesState): Current message state including the user's query.
+        k (int): Number of top results to retrieve from the vector store.
+
+    Returns:
+        dict: Updated messages state including relevant documents or fallback message.
+    """
+    query = state["messages"][0].content.strip()
+    results = vector_store.similarity_search_with_score(query, k=k)
+
+    # Determine dynamic similarity threshold
+    if any(keyword in query.lower() for keyword in ["who", "what", "where", "when", "why", "how"]):
+        threshold = 0.75
+    else:
+        threshold = 0.8
+
     filtered = [doc for doc, score in results if score < threshold]
 
-    # Provide a default message if no documents found
     if not filtered:
-        example_msg = HumanMessage(content="No relevant documents found.")
+        response_msg = HumanMessage(content="No relevant documents found.")
     else:
         content = "\n\n".join(doc.page_content for doc in filtered)
-        example_msg = HumanMessage(
-            content=f"Here are relevant reference documents:\n\n{content}"
-        )
+        response_msg = HumanMessage(content=f"Here are relevant reference documents:\n\n{content}")
+
+    return {"messages": [sys_msg] + state["messages"] + [response_msg]}
 
-    return {"messages": [sys_msg] + state["messages"] + [example_msg]}
 
 
 
@@ -537,17 +551,29 @@ def retriever(state: MessagesState):
 def get_llm(provider: str, config: dict):
     if provider == "google":
         from langchain_google_genai import ChatGoogleGenerativeAI
-        return ChatGoogleGenerativeAI(model=config["model"], temperature=config["temperature"])
+        return ChatGoogleGenerativeAI(
+            model=config.get("model"),
+            temperature=config.get("temperature", 0.7),
+            google_api_key=config.get("api_key")  # Optional: if needed
+        )
 
     elif provider == "groq":
         from langchain_groq import ChatGroq
-        return ChatGroq(model=config["model"], temperature=config["temperature"])
+        return ChatGroq(
+            model=config.get("model"),
+            temperature=config.get("temperature", 0.7),
+            groq_api_key=config.get("api_key")  # Optional: if needed
+        )
 
     elif provider == "huggingface":
         from langchain_huggingface import ChatHuggingFace
         from langchain_huggingface import HuggingFaceEndpoint
         return ChatHuggingFace(
-            llm=HuggingFaceEndpoint(url=config["url"], temperature=config["temperature"])
+            llm=HuggingFaceEndpoint(
+                endpoint_url=config.get("url"),
+                temperature=config.get("temperature", 0.7),
+                huggingfacehub_api_token=config.get("api_key")  # Optional
+            )
         )
 
     else:
@@ -607,7 +633,7 @@ def planner(question: str, tools: list) -> tuple:
     return detected_intent, matched_tools if matched_tools else [tools[0]]
 
 
-import json
+
 
 def task_classifier(question: str) -> str:
     """
@@ -889,7 +915,7 @@ def tool_dispatcher(state: AgentState) -> AgentState:
 
 
 # Decide what to do next: if tool call → call_tool, else → end
-def call_tool(state):
+def should_call_tool(state):
     last_msg = state["messages"][-1]
     if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
         return "call_tool"
@@ -904,9 +930,6 @@ class AgentState(TypedDict):
     intent: str              # derived or predicted intent
     result: Optional[str]    # tool output, if any
 
-builder.add_node("call_tool", tool_dispatcher)
-
-
 
 
 
@@ -1036,14 +1059,7 @@ model_config = {
     "huggingfacehub_api_token": os.getenv("HF_TOKEN")
 }
 
-
-def build_graph(provider, model_config):
-    from langchain_core.messages import SystemMessage, HumanMessage
-    #from langgraph.graph import StateGraph, ToolNode
-    from langchain_core.runnables import RunnableLambda
-    from some_module import vector_store  # Make sure this is defined/imported
-
-    # Step 1: Get LLM
+# Get LLM
 def get_llm(provider: str, config: dict):
     if provider == "huggingface":
         from langchain_huggingface import HuggingFaceEndpoint
@@ -1056,7 +1072,30 @@ def build_graph(provider, model_config):
         )
     else:
         raise ValueError(f"Unsupported provider: {provider}")
-
+
+
+def assistant(state: dict):
+    return {"messages": [llm_with_tools.invoke(state["messages"])]}
+
+def tools_condition(state: dict) -> str:
+    if "use tool" in state["messages"][-1].content.lower():
+        return "tools"
+    else:
+        return "END"
+
+
+from langgraph.graph import StateGraph
+
+# Build graph using AgentState as the shared schema
+def build_graph() -> StateGraph:
+    builder = StateGraph(AgentState)
+
+    from langchain_core.messages import SystemMessage, HumanMessage
+    #from langgraph.graph import StateGraph, ToolNode
+    from langchain_core.runnables import RunnableLambda
+    from some_module import vector_store  # Make sure this is defined/imported
+
+
     llm = get_llm(provider, model_config)
 
     # Step 2: Define tools
@@ -1071,80 +1110,52 @@ def build_graph(provider, model_config):
         wikidata_query
     ]
 
+    global tool_map
+    tool_map = {t.name: t for t in tools}
+
     # Step 3: Bind tools to LLM
     llm_with_tools = llm.bind_tools(tools)
 
     # Step 4: Build stateful graph logic
    sys_msg = SystemMessage(content="You are a helpful assistant.")
 
-    def retriever(state: dict):
-        user_query = state["messages"][0].content
-        similar_docs = vector_store.similarity_search(user_query)
-
-        if not similar_docs:
-            wiki_result = wiki_search.run(user_query)
-            return {
-                "messages": [
-                    sys_msg,
-                    state["messages"][0],
-                    HumanMessage(content=f"Using Wikipedia search:\n\n{wiki_result}")
-                ]
-            }
-        else:
-            return {
-                "messages": [
-                    sys_msg,
-                    state["messages"][0],
-                    HumanMessage(content=f"Reference:\n\n{similar_docs[0].page_content}")
-                ]
-            }
-
-    def assistant(state: dict):
-        return {"messages": [llm_with_tools.invoke(state["messages"])]}
-
-    def tools_condition(state: dict) -> str:
-        if "use tool" in state["messages"][-1].content.lower():
-            return "tools"
-        else:
-            return "END"
-
+    # --- Define nodes ---
+    retriever = RunnableLambda(lambda state: {
+        **state,
+        "retrieved_docs": vector_store.similarity_search(state["input"])
+    })
 
+    assistant = RunnableLambda(lambda state: {
+        **state,
+        "messages": state["messages"] + [SystemMessage(content="You are a helpful assistant.")]
+    })
 
-    from langgraph.graph import StateGraph
+    call_llm = llm.bind_tools(tools)
 
-    # Build graph using AgentState as the shared schema
-    builder = StateGraph(AgentState)
+    # --- Build the graph ---
+    builder = StateGraph(AgentState)
 
-    # Add nodes
-    builder.add_node("retriever", retriever)
-    builder.add_node("assistant", assistant)
-    builder.add_node("call_llm", call_llm)
-    builder.add_node("call_tool", tool_dispatcher)  # one name is enough
+    builder.add_node("retriever", retriever)
+    builder.add_node("assistant", assistant)
+    builder.add_node("call_llm", call_llm)
+    builder.add_node("call_tool", tool_dispatcher)
 
-    # Entry point
-    builder.set_entry_point("retriever")
-
-    # Define the flow
-    builder.add_edge("retriever", "assistant")
-    builder.add_edge("assistant", "call_llm")
-
-    # Conditional edge from LLM to tool or end
-    builder.add_conditional_edges("call_llm", should_call_tool, {
-        "call_tool": "call_tool",
-        "end": None
-    })
-
-    # Loop back after tool execution
-    builder.add_edge("call_tool", "call_llm")
+    builder.set_entry_point("retriever")
+    builder.add_edge("retriever", "assistant")
+    builder.add_edge("assistant", "call_llm")
+    builder.add_conditional_edges("call_llm", should_call_tool, {
+        "call_tool": "call_tool",
+        "end": None
+    })
+    builder.add_edge("call_tool", "call_llm")
 
-    # Compile
-    graph = builder.compile()
-    return graph
+    graph = builder.compile()
+    return graph
 
 
 
 # call build_graph AFTER it’s defined
-agent = build_graph(provider, model_config)
+agent = graph
 
 # Now you can use the agent like this:
 result = agent.invoke({"messages": [HumanMessage(content=question)]})
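Note on the updated get_llm: it now reads model, temperature, api_key, and (for the Hugging Face branch) url through config.get(...). A minimal sketch of config dicts matching those keys follows; the model id, endpoint URL, and environment-variable names are illustrative assumptions, not part of this commit.

import os

# Sketch only: key names follow the updated get_llm; the values are assumptions.
groq_config = {
    "model": "llama3-70b-8192",             # assumed example Groq model id
    "temperature": 0.7,                     # get_llm falls back to 0.7 if omitted
    "api_key": os.getenv("GROQ_API_KEY"),   # read via config.get("api_key")
}

hf_config = {
    "url": "https://api-inference.huggingface.co/models/your-model",  # assumed endpoint URL
    "temperature": 0.7,
    "api_key": os.getenv("HF_TOKEN"),       # forwarded as huggingfacehub_api_token
}

llm = get_llm("groq", groq_config)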
 
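A minimal end-to-end usage sketch, assuming build_graph() from this commit is importable and is called explicitly to obtain the compiled graph; the question string is an arbitrary example, not taken from the commit.

from langchain_core.messages import HumanMessage

# Sketch only: build the compiled graph and run a single question through it.
graph = build_graph()

question = "What is the capital of France?"  # example input, not from the commit
result = graph.invoke({"messages": [HumanMessage(content=question)]})

# The model's reply is typically the last message in the returned state.
print(result["messages"][-1].content)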