Spaces:
Sleeping
Sleeping
Daniel Amendoeira
committed on
Update agent.py
Browse files
agent.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
import os
|
2 |
|
3 |
from langchain_openai import ChatOpenAI
|
4 |
-
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
|
|
|
5 |
from langgraph.graph import MessagesState
|
6 |
from langgraph.graph import StateGraph, START, END
|
|
|
7 |
|
8 |
from langchain_community.tools import BraveSearch # web search
|
9 |
from langchain.tools import Calculator # for basic math
|
@@ -29,7 +31,41 @@ search_tool = BraveSearch.from_api_key(
|
|
29 |
)
|
30 |
|
31 |
community_tools = [calculator_tool, python_tool, search_tool]
|
32 |
-
custom_tools = [datetime_tools, transcribe_audio_tool]
|
33 |
|
34 |
tools = community_tools + custom_tools
|
35 |
-
llm_with_tools = llm.bind_tools(tools)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
from typing import Annotated, Literal, TypedDict

from langchain_openai import ChatOpenAI
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langgraph.graph.message import add_messages
from langgraph.graph import MessagesState
from langgraph.graph import StateGraph, START, END

from langchain_community.tools import BraveSearch  # web search
from langchain.tools import Calculator  # for basic math
|
|
|
31 |
)
|
32 |
|
33 |
community_tools = [calculator_tool, python_tool, search_tool]
|
34 |
+
custom_tools = [datetime_tools, transcribe_audio_tool]
|
35 |
|
36 |
tools = community_tools + custom_tools
|
37 |
+
llm_with_tools = llm.bind_tools(tools)
|
38 |
+
|
39 |
+
tools_by_name = {tool.name: tool for tool in tools}
|
40 |
+
|
41 |
+
class MessagesState(TypedDict):  # creates the state (is like the agent's memory at any moment)
    """Graph state: the running conversation history.

    NOTE(review): this shadows the ``MessagesState`` imported from
    ``langgraph.graph`` above — presumably intentional, but confirm the
    imported one isn't needed elsewhere.
    NOTE(review): ``TypedDict`` and ``Annotated`` do not appear in the
    visible imports — confirm they are imported at the top of the file.
    """

    # add_messages is a reducer: node returns are appended to this list
    # rather than replacing it.
    messages: Annotated[list[AnyMessage], add_messages]
|
43 |
+
|
44 |
+
# LLM node
def llm_call(state: MessagesState):
    """Run the tool-bound LLM on the system prompt plus the conversation.

    Returns a partial state update: a one-element "messages" list holding
    the model's reply, which add_messages appends to the history.
    """
    prompt = [SystemMessage(content=system_prompt)]
    prompt.extend(state["messages"])
    reply = llm_with_tools.invoke(prompt)
    return {"messages": [reply]}
|
53 |
+
|
54 |
+
# Tool node
def tool_node(state: MessagesState):
    """Execute every tool call requested by the last LLM message.

    Looks each requested tool up in the module-level ``tools_by_name``
    registry, invokes it with the call's args, and returns a partial state
    update whose "messages" list holds one ToolMessage per call (paired to
    the call via ``tool_call_id``).
    """
    result = []
    # The last message is expected to carry tool_calls (AIMessage) —
    # the router only sends us here when it does.
    for tool_call in state["messages"][-1].tool_calls:
        tool = tools_by_name[tool_call["name"]]
        observation = tool.invoke(tool_call["args"])
        # Fix: coerce to str — tools may return arbitrary objects, while
        # ToolMessage content is expected to be text; a raw object can
        # fail message validation downstream.
        result.append(ToolMessage(content=str(observation), tool_call_id=tool_call["id"]))
    return {"messages": result}
|
62 |
+
|
63 |
+
def should_continue(state: MessagesState) -> Literal["Action", END]:
    """Decide if we should continue the loop or stop based upon whether the LLM made a tool call"""
    last_message = state["messages"][-1]
    # A non-empty tool_calls list means the LLM asked for a tool run;
    # otherwise the reply is final and we end the graph.
    return "Action" if last_message.tool_calls else END
|