Update agent.py
Browse files
agent.py
CHANGED
@@ -354,51 +354,44 @@ def wiki_search(query : str) -> str:
|
|
354 |
tools = [weather_tool, wiki_search, web_search,
|
355 |
add, subtract, multiply, divide, square, cube, power, factorial, mean, standard_deviation]
|
356 |
|
|
|
|
|
357 |
llm = ChatGroq(
|
358 |
temperature=0,
|
359 |
-
model_name="qwen-qwq-32b",
|
360 |
groq_api_key=os.getenv("GROQ_API_KEY")
|
361 |
)
|
362 |
|
363 |
-
|
|
|
|
|
364 |
|
365 |
-
|
366 |
|
367 |
-
|
368 |
-
"""Retrieve similar context and inject"""
|
369 |
-
query = state["messages"][0].content
|
370 |
-
similar_docs = vector_store.similarity_search(query)
|
371 |
|
372 |
-
|
373 |
-
|
374 |
-
content=f"Here is a similar question and answer for reference:\n\n{similar_docs[0].page_content}"
|
375 |
-
)
|
376 |
-
return {"messages": [sys_msg] + state["messages"] + [ref_msg]}
|
377 |
-
else:
|
378 |
-
return {"messages": [sys_msg] + state["messages"]}
|
379 |
|
380 |
-
def assistant(state:
|
381 |
-
""
|
382 |
-
return {"messages": [llm_with_tools.invoke(state["messages"])]}
|
383 |
|
384 |
-
# ===
|
385 |
|
386 |
def build_graph():
|
387 |
-
builder = StateGraph(
|
388 |
-
builder.add_node("retriever", retriever)
|
389 |
builder.add_node("assistant", assistant)
|
390 |
-
builder.add_node("tools", ToolNode(
|
391 |
|
392 |
-
builder.set_entry_point("
|
393 |
-
builder.add_edge("retriever", "assistant")
|
394 |
builder.add_conditional_edges("assistant", tools_condition)
|
395 |
builder.add_edge("tools", "assistant")
|
396 |
-
|
397 |
return builder.compile()
|
398 |
|
399 |
-
# ===
|
|
|
400 |
if __name__ == "__main__":
|
401 |
-
question = "When
|
402 |
graph = build_graph()
|
403 |
messages = [HumanMessage(content=question)]
|
404 |
result = graph.invoke({"messages": messages})
|
|
|
354 |
tools = [weather_tool, wiki_search, web_search,
|
355 |
add, subtract, multiply, divide, square, cube, power, factorial, mean, standard_deviation]
|
356 |
|
357 |
+
# === LLM with Tools ===
|
358 |
+
|
359 |
llm = ChatGroq(
|
360 |
temperature=0,
|
361 |
+
model_name="qwen-qwq-32b",
|
362 |
groq_api_key=os.getenv("GROQ_API_KEY")
|
363 |
)
|
364 |
|
365 |
+
tools = [weather_tool, wiki_search, web_search,
|
366 |
+
add, subtract, multiply, divide, square, cube,
|
367 |
+
power, factorial, mean, standard_deviation]
|
368 |
|
369 |
+
llm_with_tools = llm.bind_tools(tools)
|
370 |
|
371 |
+
# === LangGraph State ===
|
|
|
|
|
|
|
372 |
|
373 |
+
class ToolAgentState(TypedDict):
|
374 |
+
messages: Annotated[List[HumanMessage], "Messages in the conversation"]
|
|
|
|
|
|
|
|
|
|
|
375 |
|
376 |
+
def assistant(state: ToolAgentState):
|
377 |
+
return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}
|
|
|
378 |
|
379 |
+
# === Build Graph ===
|
380 |
|
381 |
def build_graph():
|
382 |
+
builder = StateGraph(ToolAgentState)
|
|
|
383 |
builder.add_node("assistant", assistant)
|
384 |
+
builder.add_node("tools", ToolNode(tools))
|
385 |
|
386 |
+
builder.set_entry_point("assistant")
|
|
|
387 |
builder.add_conditional_edges("assistant", tools_condition)
|
388 |
builder.add_edge("tools", "assistant")
|
|
|
389 |
return builder.compile()
|
390 |
|
391 |
+
# === Run ===
|
392 |
+
|
393 |
if __name__ == "__main__":
|
394 |
+
question = "When did India won a world cup in cricket before 2000?"
|
395 |
graph = build_graph()
|
396 |
messages = [HumanMessage(content=question)]
|
397 |
result = graph.invoke({"messages": messages})
|