import os
import json
from typing import List, Dict, Any

from dotenv import load_dotenv
from openai import OpenAI
from openai.types.chat import (
    ChatCompletionMessage,
    ChatCompletionMessageToolCall,
)
from openai.types.chat.chat_completion import ChatCompletion
from mcp.server.fastmcp import FastMCP

from tools import SearchTool, FetchTool, SummarizeTool

# mcp = FastMCP("researcher")

load_dotenv()
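
# The agent below assumes each tool imported from tools.py exposes a `name`
# attribute, a `to_json()` method returning an OpenAI function schema
# (name/description/parameters), and is callable with the parsed arguments.
# A minimal sketch of that assumed interface (hypothetical; the real
# tools.py may differ):
#
#     class SearchTool:
#         name = "search"
#
#         def to_json(self) -> dict:
#             return {
#                 "name": self.name,
#                 "description": "Search the web for a query.",
#                 "parameters": {
#                     "type": "object",
#                     "properties": {"query": {"type": "string"}},
#                     "required": ["query"],
#                 },
#             }
#
#         def __call__(self, query: str) -> str:
#             return f"Results for {query!r}"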
class ReActAgent:
    def __init__(self, client: OpenAI):
        self.client = client
        self.model = "qwen-3-32b"
        # History entries include user/assistant/tool messages, so values
        # are not always strings.
        self.conversation_history: List[Dict[str, Any]] = []
        self.max_history_length = 10  # Limit conversation history
        self.tools = [
            SearchTool(),
            FetchTool(),
            SummarizeTool(),
        ]
        # Wrap each tool's schema in the OpenAI function-calling format
        self.tools_json = [
            {
                "type": "function",
                "function": tool.to_json(),
            }
            for tool in self.tools
        ]
        self.tools_map = {tool.name: tool for tool in self.tools}
        self.process_log = []  # Store the intermediate steps
    def _execute_tool(self, tool_call: ChatCompletionMessageToolCall) -> str:
        """Execute the called tool and return the result."""
        try:
            tool_name = tool_call.function.name
            arguments = json.loads(tool_call.function.arguments)
            if tool_name not in self.tools_map:
                return f"Error: Unknown tool: {tool_name}"
            tool = self.tools_map[tool_name]
            result = tool(**arguments)
            # Log the tool execution
            self.process_log.append({
                "tool": tool_name,
                "arguments": arguments,
                "result": result,
            })
            return result
        except json.JSONDecodeError:
            error_msg = "Error: Invalid tool arguments format"
            self.process_log.append({
                "tool": tool_call.function.name,
                "arguments": tool_call.function.arguments,
                "result": error_msg,
            })
            return error_msg
        except Exception as e:
            error_msg = f"Error executing tool: {str(e)}"
            self.process_log.append({
                "tool": tool_call.function.name,
                "arguments": tool_call.function.arguments,
                "result": error_msg,
            })
            return error_msg
    def _truncate_history(self):
        """Keep only the most recent messages to prevent context overflow."""
        if len(self.conversation_history) > self.max_history_length:
            self.conversation_history = self.conversation_history[-self.max_history_length:]
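        # Caveat (an assumption about backend behavior, not verified here):
        # slicing by message count can strand a "tool" reply from the
        # assistant message that issued its tool_calls, which some
        # chat-completions endpoints reject. A sturdier variant would
        # truncate only at user-message boundaries.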
    def _format_process_log(self) -> str:
        """Format the process log into a readable string."""
        if not self.process_log:
            return "No intermediate steps were taken."
        formatted_log = ["<intermediate_steps>"]
        for i, step in enumerate(self.process_log, 1):
            formatted_log.append(f"\nStep {i}:")
            formatted_log.append(f"Tool: {step['tool']}")
            formatted_log.append(f"Arguments: {json.dumps(step['arguments'], indent=2)}")
            formatted_log.append(f"Result: {step['result']}")
        formatted_log.append("</intermediate_steps>")
        return "\n".join(formatted_log)
    def run(self, user_input: str) -> str:
        """Run the ReAct loop for a single user input."""
        if not user_input or not isinstance(user_input, str):
            return "Error: Invalid input. Please provide a valid string query."
        try:
            # Reset process log for the new query
            self.process_log = []
            # Add user input to conversation history
            self.conversation_history.append({"role": "user", "content": user_input})
            print(f"\n\nUser input: {user_input}\n--------------------------------\n")
            # Note: there is no upper bound on tool-call rounds here; a
            # production version should cap the number of iterations.
            while True:
                try:
                    # Get response from the model
                    response: ChatCompletion = self.client.chat.completions.create(
                        model=self.model,
                        messages=self.conversation_history,
                        tools=self.tools_json,
                    )
                    message: ChatCompletionMessage = response.choices[0].message
                    # Add the assistant's response to the conversation history.
                    # Serialize tool calls to plain dicts and omit the key when
                    # empty, so the history stays valid for the next request.
                    assistant_msg: Dict[str, Any] = {
                        "role": "assistant",
                        "content": message.content or "",
                    }
                    if message.tool_calls:
                        assistant_msg["tool_calls"] = [tc.model_dump() for tc in message.tool_calls]
                    self.conversation_history.append(assistant_msg)
                    # If there are no tool calls, return the response with the process log
                    if not message.tool_calls:
                        print("No tool calls\nExiting loop\n--------------------------------")
                        final_response = message.content or "No response generated"
                        process_log = self._format_process_log()
                        return f"{process_log}\n\n{final_response}"
                    # Execute the tool calls
                    tool_results = []
                    for tool_call in message.tool_calls:
                        print(f"Tool call: {tool_call.function.name}\nTool arguments: {tool_call.function.arguments}")
                        tool_result = self._execute_tool(tool_call)
                        print(f"Tool result: {tool_result}\n--------------------------------\n")
                        tool_results.append({
                            "tool_call_id": tool_call.id,
                            "role": "tool",
                            "name": tool_call.function.name,
                            "content": tool_result,
                        })
                    # Add tool results to conversation history
                    self.conversation_history.extend(tool_results)
                    self._truncate_history()
                except Exception as e:
                    error_msg = f"Error during model interaction: {str(e)}"
                    process_log = self._format_process_log()
                    return f"{error_msg}\n\n{process_log}"
        except Exception as e:
            error_msg = f"Error in research process: {str(e)}"
            process_log = self._format_process_log()
            return f"{error_msg}\n\n{process_log}"
# @mcp.tool()
def research(query: str) -> str:
    """Run the research agent on a query and return the final answer."""
    try:
        api_key = os.environ.get("CEREBRAS_API_KEY")
        if not api_key:
            return "Error: Please set the CEREBRAS_API_KEY environment variable"
        # Cerebras exposes an OpenAI-compatible endpoint
        client = OpenAI(
            base_url="https://api.cerebras.ai/v1",
            api_key=api_key,
        )
        agent = ReActAgent(client)
        return agent.run(query)
    except Exception as e:
        return f"Error in research function: {str(e)}"


# if __name__ == "__main__":
#     mcp.run()
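
# Example usage (assumes CEREBRAS_API_KEY is set in the environment or in a
# local .env file; the query below is illustrative):
#
#     if __name__ == "__main__":
#         print(research("What is the ReAct prompting pattern?"))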