YigitSekerci committed on
Commit 903ecf8 · 1 Parent(s): 4a4128c

implement basic agent

Files changed (1)
  1. src/agent.py +367 -0
src/agent.py ADDED
@@ -0,0 +1,367 @@
+import asyncio
+import logging
+from typing import List, Dict, Any, Optional, Tuple, Union
+from langchain_mcp_adapters.client import MultiServerMCPClient
+from langgraph.prebuilt import create_react_agent
+from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.exceptions import OutputParserException
+from dotenv import load_dotenv
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+load_dotenv()
+
+class AudioAgentError(Exception):
+    """Custom exception for AudioAgent errors"""
+    pass
+
+
+class AudioAgentInitializationError(AudioAgentError):
+    """Raised when agent initialization fails"""
+    pass
+
+
+class AudioAgentChatError(AudioAgentError):
+    """Raised when chat processing fails"""
+    pass
+
+
+class AudioAgent:
+    """
+    A class to manage an audio-focused AI agent with MCP tools integration.
+
+    This agent connects to audio tools via MCP and provides a conversational interface
+    using LangChain's robust message handling and output parsing.
+    """
+
+    def __init__(self, model_name: str = "gpt-4o", server_url: str = "http://127.0.0.1:7860/gradio_api/mcp/sse"):
+        """
+        Initialize the AudioAgent.
+
+        Args:
+            model_name: The language model to use for the agent
+            server_url: The URL of the MCP server providing audio tools
+        """
+        self.model_name = model_name
+        self.server_url = server_url
+        self._agent = None
+        self._tools = None
+        self._is_initialized = False
+        self._output_parser = StrOutputParser()
+
+        # Initialize MCP client
+        self._client = MultiServerMCPClient({
+            "audio-tools": {
+                "url": server_url,
+                "transport": "sse",
+            }
+        })
+
+    @property
+    def is_initialized(self) -> bool:
+        """Check if the agent is initialized and ready to use."""
+        return self._is_initialized
+
+    async def initialize(self) -> None:
+        """
+        Initialize the agent with tools from the MCP client.
+
+        Raises:
+            AudioAgentInitializationError: If initialization fails
+        """
+        if self._is_initialized:
+            logger.info("Agent already initialized")
+            return
+
+        try:
+            logger.info("Initializing AudioAgent...")
+
+            # Get tools from MCP client
+            self._tools = await self._client.get_tools()
+            if not self._tools:
+                raise AudioAgentInitializationError("No tools available from MCP client")
+
+            logger.info(f"Loaded {len(self._tools)} tools: {[tool.name for tool in self._tools]}")
+
+            # Create the agent
+            self._agent = create_react_agent(
+                self.model_name,
+                self._tools,
+            )
+
+            self._is_initialized = True
+            logger.info("AudioAgent initialized successfully")
+
+        except Exception as e:
+            error_msg = f"Failed to initialize AudioAgent: {str(e)}"
+            logger.error(error_msg)
+            raise AudioAgentInitializationError(error_msg) from e
+
+    def _convert_to_langchain_messages(self, history: List[Tuple[str, Optional[str]]]) -> List[BaseMessage]:
+        """
+        Convert chat history to LangChain message objects.
+
+        Args:
+            history: List of (human_message, ai_response) tuples
+
+        Returns:
+            List of LangChain BaseMessage objects
+        """
+        messages = []
+        for human_msg, ai_msg in history:
+            if human_msg and human_msg.strip():
+                messages.append(HumanMessage(content=human_msg.strip()))
+            if ai_msg and ai_msg.strip():
+                messages.append(AIMessage(content=ai_msg.strip()))
+        return messages
+
+    def _format_messages_for_agent(self, messages: List[BaseMessage]) -> List[Dict[str, str]]:
+        """
+        Convert LangChain messages to the format expected by the agent.
+
+        Args:
+            messages: List of LangChain BaseMessage objects
+
+        Returns:
+            List of message dictionaries with role and content
+        """
+        formatted_messages = []
+        for message in messages:
+            if isinstance(message, HumanMessage):
+                formatted_messages.append({"role": "user", "content": message.content})
+            elif isinstance(message, AIMessage):
+                formatted_messages.append({"role": "assistant", "content": message.content})
+            else:
+                # Handle other message types if needed
+                formatted_messages.append({"role": "user", "content": str(message.content)})
+        return formatted_messages
+
+    async def _extract_response_content(self, response: Dict[str, Any]) -> str:
+        """
+        Extract the content from the agent's response using LangChain output parser.
+
+        Args:
+            response: The response from the agent
+
+        Returns:
+            The extracted content as a string
+
+        Raises:
+            AudioAgentChatError: If response parsing fails
+        """
+        try:
+            if not response:
+                raise OutputParserException("Received empty response from agent")
+
+            if "messages" not in response or not response["messages"]:
+                raise OutputParserException("No messages found in agent response")
+
+            last_message = response["messages"][-1]
+
+            # Handle different message formats
+            if hasattr(last_message, 'content'):
+                content = last_message.content
+            elif isinstance(last_message, dict) and 'content' in last_message:
+                content = last_message['content']
+            else:
+                content = str(last_message)
+
+            # Use LangChain's output parser for robust string processing
+            parsed_content = await self._output_parser.aparse(content)
+            return parsed_content if parsed_content else "I couldn't generate a response."
+
+        except OutputParserException as e:
+            logger.warning(f"Output parsing failed: {e}")
+            raise AudioAgentChatError(f"Failed to parse agent response: {str(e)}") from e
+        except Exception as e:
+            logger.error(f"Unexpected error in response extraction: {e}")
+            raise AudioAgentChatError(f"Error extracting response content: {str(e)}") from e
+
+    def _validate_message(self, message: str) -> str:
+        """
+        Validate and sanitize the input message.
+
+        Args:
+            message: The user's message
+
+        Returns:
+            The validated and sanitized message
+
+        Raises:
+            AudioAgentChatError: If message is invalid
+        """
+        if not message:
+            raise AudioAgentChatError("Message cannot be None")
+
+        cleaned_message = message.strip()
+        if not cleaned_message:
+            raise AudioAgentChatError("Message cannot be empty or only whitespace")
+
+        if len(cleaned_message) > 10000:
+            raise AudioAgentChatError("Message is too long (max 10,000 characters)")
+
+        return cleaned_message
+
+    async def chat(self, message: str, history: Optional[List[Tuple[str, Optional[str]]]] = None) -> str:
+        """
+        Process a chat message with the agent using LangChain's robust message handling.
+
+        Args:
+            message: The user's message
+            history: Previous chat history as list of (human, ai) tuples
+
+        Returns:
+            The agent's response
+
+        Raises:
+            AudioAgentChatError: If chat processing fails
+            AudioAgentInitializationError: If agent is not initialized
+        """
+        # Validate input
+        validated_message = self._validate_message(message)
+
+        # Ensure agent is initialized
+        if not self._is_initialized:
+            await self.initialize()
+
+        try:
+            # Convert history to LangChain messages
+            langchain_messages = self._convert_to_langchain_messages(history or [])
+
+            # Add current message
+            langchain_messages.append(HumanMessage(content=validated_message))
+
+            # Format for agent
+            formatted_messages = self._format_messages_for_agent(langchain_messages)
+
+            logger.info(f"Processing message: {validated_message[:50]}{'...' if len(validated_message) > 50 else ''}")
+
+            # Get response from agent
+            response = await self._agent.ainvoke({"messages": formatted_messages})
+
+            # Extract and return content using output parser
+            content = await self._extract_response_content(response)
+            logger.info("Message processed successfully")
+            return content
+
+        except AudioAgentChatError:
+            # Re-raise our custom errors
+            raise
+        except Exception as e:
+            error_msg = f"Failed to process chat message: {str(e)}"
+            logger.error(error_msg)
+            raise AudioAgentChatError(error_msg) from e
+
+    def chat_sync(self, message: str, history: Optional[List[Tuple[str, Optional[str]]]] = None) -> str:
+        """
+        Synchronous wrapper for the async chat method.
+
+        Args:
+            message: The user's message
+            history: Previous chat history as list of (human, ai) tuples
+
+        Returns:
+            The agent's response
+        """
+        try:
+            return asyncio.run(self.chat(message, history))
+        except Exception as e:
+            logger.error(f"Error in synchronous chat: {e}")
+            raise
+
+    async def get_available_tools(self) -> List[str]:
+        """
+        Get the list of available tool names.
+
+        Returns:
+            List of tool names
+
+        Raises:
+            AudioAgentInitializationError: If initialization fails
+        """
+        try:
+            if not self._is_initialized:
+                await self.initialize()
+            return [tool.name for tool in self._tools] if self._tools else []
+        except Exception as e:
+            error_msg = f"Failed to get available tools: {str(e)}"
+            logger.error(error_msg)
+            raise AudioAgentInitializationError(error_msg) from e
+
+    async def stream_chat(self, message: str, history: Optional[List[Tuple[str, Optional[str]]]] = None):
+        """
+        Stream a chat response (if supported by the underlying agent).
+
+        Args:
+            message: The user's message
+            history: Previous chat history as list of (human, ai) tuples
+
+        Yields:
+            Chunks of the response as they become available
+
+        Raises:
+            AudioAgentChatError: If streaming fails
+        """
+        # Validate input
+        validated_message = self._validate_message(message)
+
+        # Ensure agent is initialized
+        if not self._is_initialized:
+            await self.initialize()
+
+        try:
+            # Convert history to LangChain messages
+            langchain_messages = self._convert_to_langchain_messages(history or [])
+
+            # Add current message
+            langchain_messages.append(HumanMessage(content=validated_message))
+
+            # Format for agent
+            formatted_messages = self._format_messages_for_agent(langchain_messages)
+
+            logger.info(f"Streaming message: {validated_message[:50]}{'...' if len(validated_message) > 50 else ''}")
+
+            # Check if agent supports streaming
+            if hasattr(self._agent, 'astream'):
+                async for chunk in self._agent.astream({"messages": formatted_messages}):
+                    yield chunk
+            else:
+                # Fallback to regular chat if streaming not supported
+                response = await self.chat(validated_message, history)
+                yield response
+
+        except Exception as e:
+            error_msg = f"Failed to stream chat message: {str(e)}"
+            logger.error(error_msg)
+            raise AudioAgentChatError(error_msg) from e
+
+async def main():
+    """Example usage and testing"""
+    try:
+        # Create and initialize agent
+        agent = AudioAgent()
+        await agent.initialize()
+
+        # Show available tools
+        tools = await agent.get_available_tools()
+        print(f"Available tools: {tools}")
+
+        # Test chat
+        response = await agent.chat("What tools do you have?")
+        print(f"Agent response: {response}")
+
+        # Test streaming (if supported)
+        print("\nTesting streaming:")
+        async for chunk in agent.stream_chat("Tell me about audio processing"):
+            print(f"Chunk: {chunk}")
+
+    except AudioAgentError as e:
+        logger.error(f"AudioAgent error: {e}")
+    except Exception as e:
+        logger.error(f"Unexpected error: {e}")
+
+if __name__ == "__main__":
+    asyncio.run(main())
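
For context, below is a minimal sketch of how this class might be wired into the Gradio app that serves the MCP endpoint referenced above. This wiring is not part of the commit: the gr.ChatInterface hookup, the pair-style history format, and an OPENAI_API_KEY supplied via the .env file that load_dotenv() reads are all assumptions.

# Hypothetical usage sketch (not part of this commit).
# Assumes OPENAI_API_KEY is available via the .env file, and Gradio's
# pair-style history ([user, assistant] lists), which matches the
# (human, ai) tuples that AudioAgent.chat expects.
import gradio as gr

from src.agent import AudioAgent

agent = AudioAgent()  # defaults: gpt-4o, local Gradio MCP SSE endpoint


async def respond(message, history):
    # Convert Gradio's [user, assistant] pairs into the tuples the agent expects.
    pairs = [(user, assistant) for user, assistant in history]
    return await agent.chat(message, pairs)


demo = gr.ChatInterface(respond, title="Audio Agent")

if __name__ == "__main__":
    demo.launch()

In a non-async context, the same agent can instead be driven through chat_sync, which wraps chat in asyncio.run.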