laverdes commited on
Commit
20fac12
·
verified ·
1 Parent(s): fa70e96

feat: basic agent-with-tools workflow

Browse files
Files changed (1) hide show
  1. basic_agent.py +152 -0
basic_agent.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ from rich import print as rich_print
4
+ from rich.panel import Panel
5
+ from rich.console import Console
6
+ from rich.pretty import Pretty
7
+ from rich.markdown import Markdown
8
+ from rich.json import JSON
9
+
10
+ from typing import TypedDict, Sequence, Annotated
11
+ from langchain_core.messages import BaseMessage
12
+ from langgraph.graph.message import add_messages
13
+ from langgraph.graph import StateGraph, START, END
14
+ from langchain_openai import ChatOpenAI
15
+ from langgraph.prebuilt import ToolNode, tools_condition
16
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
17
+ from tqdm import tqdm
18
+
19
+
20
+ def print_conversation(messages):
21
+ console = Console(width=200, soft_wrap=True)
22
+
23
+ for msg in messages:
24
+ role = msg.get("role", "unknown").capitalize()
25
+ content = msg.get("content", "")
26
+
27
+ try:
28
+ parsed_json = json.loads(content)
29
+ rendered_content = JSON.from_data(parsed_json)
30
+ except (json.JSONDecodeError, TypeError):
31
+ rendered_content = Markdown(content.strip())
32
+
33
+ panel = Panel(
34
+ rendered_content,
35
+ title=f"[bold blue]{role}[/]",
36
+ border_style="green" if role == "User" else "magenta",
37
+ expand=True
38
+ )
39
+
40
+ console.print(panel)
41
+
42
+
43
+ class AgentState(TypedDict):
44
+ messages: Annotated[Sequence[BaseMessage], add_messages]
45
+
46
+
47
+ class BasicOpenAIAgentWorkflow:
48
+ """Basic custom class from an agent prompted for tool-use pattern"""
49
+
50
+ def __init__(self, tools: list, model='gpt-4o', backstory:str="", streaming=False):
51
+ self.name = "Basic OpenAI Agent Workflow"
52
+ self.tools = tools
53
+ self.llm = ChatOpenAI(model=model, temperature=0, streaming=streaming)
54
+ self.graph = None
55
+ self.history = []
56
+ self.history_messages = [] # Store messages in LangChain format
57
+ self.backstory = backstory if backstory else "You are a helpful assistant that can use tools to answer questions. Your name is Gaia."
58
+
59
+ role_message = {'role': 'system', 'content': self.backstory}
60
+ self.history.append(role_message)
61
+
62
+
63
+ def _call_llm(self, state: AgentState):
64
+ """invokes the assigned llm"""
65
+ return {'messages': [self.llm.invoke(state['messages'])]}
66
+
67
+
68
+ def _convert_history_to_messages(self):
69
+ """Convert self.history to LangChain-compatible messages"""
70
+ converted = []
71
+ for msg in self.history:
72
+ content = msg['content']
73
+
74
+ if not isinstance(content, str):
75
+ raise ValueError(f"Expected string content, got: {type(content)} — {content}")
76
+
77
+ if msg['role'] == 'user':
78
+ converted.append(HumanMessage(content=content))
79
+ elif msg['role'] == 'assistant':
80
+ converted.append(AIMessage(content=content))
81
+ elif msg['role'] == 'system':
82
+ converted.append(SystemMessage(content=content))
83
+ else:
84
+ raise ValueError(f"Unknown role in message: {msg}")
85
+ self.history_messages = converted
86
+
87
+
88
+ def create_basic_tool_use_agent_state_graph(self, custom_tools_nm="tools"):
89
+ """Binds tools, creates and compiles graph"""
90
+ self.llm = self.llm.bind_tools(self.tools)
91
+
92
+ # Graph Init
93
+ graph = StateGraph(AgentState)
94
+
95
+ # Nodes
96
+ graph.add_node('agent', self._call_llm)
97
+ tools_node = ToolNode(self.tools)
98
+ graph.add_node(custom_tools_nm, tools_node)
99
+
100
+ # Edges
101
+ graph.add_edge(START, "agent")
102
+ graph.add_conditional_edges('agent', tools_condition, {'tools': custom_tools_nm, END: END})
103
+
104
+ self.graph = graph.compile()
105
+
106
+
107
+ def chat(self, query, verbose=2):
108
+ """Simple agent call"""
109
+ if isinstance(query, dict):
110
+ query = query["messages"]
111
+
112
+ user_message = {'role': 'user', 'content': query}
113
+ self.history.append(user_message)
114
+
115
+ # Ensure history has at least 1 message
116
+ if not self.history:
117
+ raise ValueError("History is empty. Cannot proceed.")
118
+
119
+ self._convert_history_to_messages()
120
+
121
+ if not self.history_messages:
122
+ raise ValueError("Converted message history is empty. Something went wrong.")
123
+
124
+ response = self.graph.invoke({'messages': self.history_messages}) # invoke with all the history
125
+ response = response['messages'][-1].content
126
+ assistant_message = {'role': 'assistant', 'content': response}
127
+ self.history.append(assistant_message)
128
+
129
+ if verbose==2:
130
+ print_conversation(self.history)
131
+ elif verbose==1:
132
+ print_conversation([response])
133
+
134
+ return response
135
+
136
+
137
+ def invoke(self, input_str: str):
138
+ """Invoke the compiled graph with the input data"""
139
+ _ = self.chat(input_str) # prints response in terminal
140
+ self._convert_history_to_messages()
141
+ return {'messages': self.history_messages}
142
+
143
+
144
+ def chat_batch(self, queries=None):
145
+ """Send several simple agent calls to the llm using the compiled graph"""
146
+ if queries is None:
147
+ queries = []
148
+ for i, query in tqdm(enumerate(queries, start=1)):
149
+ if i == len(queries):
150
+ self.chat(query, verbose=2)
151
+ else:
152
+ self.chat(query, verbose=0)