from typing import List, TypedDict, Dict, Any, Literal
from langgraph.graph import StateGraph, START, END
from langgraph.types import Command
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain_anthropic import ChatAnthropic
from langchain_core.tools import tool
from langchain_core.prompts import ChatPromptTemplate
from langgraph.prebuilt import ToolNode
import os
from dotenv import load_dotenv
from datetime import datetime
from tavily import TavilyClient
from langfuse.callback import CallbackHandler
import requests
import json
import time
from daytona_sdk import Daytona, DaytonaConfig
import yt_dlp
import tempfile
from pathlib import Path
# NOTE: the original `from litellm.models import LiteLLMModel` import path does not exist.
# A LangChain-compatible Gemini chat model is assumed here instead, so that the
# `.invoke()` and `.bind_tools()` calls used below work as written.
from langchain_google_genai import ChatGoogleGenerativeAI

# Load environment variables
load_dotenv()

# Define the state schema with messages that ToolNode can use
class AgentState(TypedDict):
    messages: List
    current_question: str
    final_answer: str
    validation_result: str
    worker_iterations: int
    supervisor_satisfaction: bool
    validator_approval: bool
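
# Illustrative example of the state dict that flows through the graph
# (placeholder values only; BasicAgent.__call__ builds the real initial state):
#   {
#       "messages": [],
#       "current_question": "What is ...?",
#       "final_answer": "",
#       "validation_result": "",
#       "worker_iterations": 0,
#       "supervisor_satisfaction": False,
#       "validator_approval": False,
#   }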


# Define tools following the LangGraph guide
@tool
def search_web_tavily(query: str) -> str:
    """Search the web for information using the Tavily search API."""
    # Initialize the Tavily client with the API key from environment variables
    client = TavilyClient(os.getenv("TAVILY_API_KEY"))
    # Perform the search
    response = client.search(query=query)
    # Process the results into a readable format
    results = []
    for i, result in enumerate(response.get("results", []), 1):
        results.append(f"{i}. {result.get('title')}\n URL: {result.get('url')}\n {result.get('content')}\n")
    # Format the final response
    formatted_response = f"Search results for '{query}':\n\n" + "\n".join(results)
    return formatted_response
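

# Optional smoke test for the Tavily tool - a minimal sketch that assumes TAVILY_API_KEY
# is set in the environment. It is never called automatically; with the @tool decorator
# the function becomes a LangChain tool and is invoked with a dict of arguments.
def _smoke_test_search_web_tavily() -> None:
    """Hypothetical helper: print one sample search result to verify the tool wiring."""
    print(search_web_tavily.invoke({"query": "LangGraph Command goto example"}))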


@tool
def search_web_serper(query: str, result_limit: int = 5, search_type: str = "search") -> str:
    """Search the web for information using the Serper.dev API.

    This tool provides comprehensive search results including:
    1. Knowledge Graph data when available (title, description, attributes)
    2. Organic search results (titles, links, snippets)
    3. Related questions from the "People Also Ask" section
    4. Top stories and news articles related to the query

    It's particularly useful for gathering factual information, current events,
    and general knowledge from across the web. The results are formatted in a
    readable structure with clear sections.

    Parameters:
    - query: The search query string
    - result_limit: Maximum number of results to return per section (default: 5)
    - search_type: Type of search ('search', 'news', 'places', 'images', 'shopping')
    """
    # API URL and headers setup
    url = "https://google.serper.dev/search"
    headers = {
        'X-API-KEY': os.getenv("SERPER_API_KEY"),
        'Content-Type': 'application/json'
    }
    # Prepare the payload with the query and search type
    payload = json.dumps({
        "q": query,
        "type": search_type
    })
    try:
        # Make the API request
        response = requests.request("POST", url, headers=headers, data=payload, timeout=30)
        response.raise_for_status()  # Raise exception for HTTP errors
        # Parse the JSON response
        data = response.json()
        # Format the results
        results = []
        # Add the knowledge graph if available
        if "knowledgeGraph" in data:
            kg = data["knowledgeGraph"]
            results.append(f"Knowledge Graph:\n{kg.get('title', 'Unknown')} - {kg.get('type', '')}")
            results.append(f"Description: {kg.get('description', 'No description available')}")
            if "attributes" in kg:
                results.append("Attributes:")
                for key, value in kg["attributes"].items():
                    results.append(f"- {key}: {value}")
            results.append("")  # Empty line for separation
        # Add organic search results
        if "organic" in data:
            results.append("Organic Search Results:")
            for i, result in enumerate(data["organic"][:result_limit], 1):
                results.append(f"{i}. {result.get('title', 'No title')}")
                results.append(f" URL: {result.get('link', 'No link')}")
                results.append(f" {result.get('snippet', 'No snippet')}")
                results.append("")  # Empty line for separation
        # Add "People Also Ask" if available
        if "peopleAlsoAsk" in data and data["peopleAlsoAsk"]:
            results.append("People Also Ask:")
            for i, qa in enumerate(data["peopleAlsoAsk"][:min(3, result_limit)], 1):
                results.append(f"{i}. Q: {qa.get('question', 'No question')}")
                results.append(f" A: {qa.get('snippet', 'No answer')}")
                results.append("")  # Empty line for separation
        # Add top stories if available
        if "topStories" in data and data["topStories"]:
            results.append("Top Stories:")
            for i, story in enumerate(data["topStories"][:min(3, result_limit)], 1):
                results.append(f"{i}. {story.get('title', 'No title')}")
                results.append(f" Source: {story.get('source', 'Unknown source')}")
                if "date" in story:
                    results.append(f" Published: {story.get('date')}")
                results.append(f" URL: {story.get('link', 'No link')}")
                results.append("")  # Empty line for separation
        # Format the final response
        formatted_response = f"Search results for '{query}':\n\n" + "\n".join(results)
        return formatted_response
    except requests.exceptions.Timeout:
        return "Error: Request to Serper API timed out after 30 seconds"
    except requests.exceptions.RequestException as e:
        return f"Error making request to Serper API: {str(e)}"
    except json.JSONDecodeError:
        return "Error: Received invalid JSON response from Serper API"
    except Exception as e:
        return f"Error processing search results: {str(e)}"


# Initialize a global Daytona sandbox for reuse
_daytona_sandbox = None


@tool
def execute_code_securely(code: str, language: str = "python", timeout: int = 300) -> str:
    """Execute code securely in an isolated sandbox environment using Daytona.

    This tool runs code in a secure, isolated environment to prevent security issues.
    It's particularly useful for solving computational problems, data processing tasks,
    mathematical calculations, and other scenarios where code execution is needed.
    The tool supports multiple languages, with Python as the default.

    Parameters:
    - code: The code to execute
    - language: The programming language (default: "python")
    - timeout: Maximum execution time in seconds (default: 300)

    Returns:
    - The execution result or error message
    """
    global _daytona_sandbox
    try:
        # Initialize the Daytona client if not already done
        if _daytona_sandbox is None:
            api_key = os.getenv("DAYTONA_API_KEY")
            if not api_key:
                return "Error: DAYTONA_API_KEY environment variable not set"
            # Initialize the Daytona client and create a sandbox
            config = DaytonaConfig(api_key=api_key)
            daytona_client = Daytona(config)
            _daytona_sandbox = daytona_client.create()
        # Execute the code based on the specified language
        if language.lower() == "python":
            response = _daytona_sandbox.process.code_run(code, timeout=timeout)
        else:
            # For non-Python languages, create a temporary file and execute it
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
            file_extension = {
                "javascript": "js",
                "nodejs": "js",
                "ruby": "rb",
                "php": "php",
                "bash": "sh",
                "shell": "sh",
                "powershell": "ps1",
                "c": "c",
                "cpp": "cpp",
                "java": "java",
                "go": "go",
                "rust": "rs",
            }.get(language.lower(), "txt")
            filename = f"/tmp/code_{timestamp}.{file_extension}"
            # Upload the code file to the sandbox
            _daytona_sandbox.fs.upload_file(filename, code.encode('utf-8'))
            # Prepare the execution command based on language
            exec_cmd = {
                "javascript": f"node {filename}",
                "nodejs": f"node {filename}",
                "ruby": f"ruby {filename}",
                "php": f"php {filename}",
                "bash": f"bash {filename}",
                "shell": f"sh {filename}",
                "powershell": f"pwsh {filename}",
                "c": f"gcc {filename} -o /tmp/prog_{timestamp} && /tmp/prog_{timestamp}",
                "cpp": f"g++ {filename} -o /tmp/prog_{timestamp} && /tmp/prog_{timestamp}",
                "java": f"javac {filename} && java -cp /tmp {os.path.basename(filename).split('.')[0]}",
                "go": f"go run {filename}",
                "rust": f"rustc {filename} -o /tmp/prog_{timestamp} && /tmp/prog_{timestamp}",
            }.get(language.lower(), f"cat {filename}")
            # Execute the command
            response = _daytona_sandbox.process.exec(exec_cmd, cwd="/tmp", timeout=timeout)
        # Extract and return the result
        if hasattr(response, 'result'):
            result = response.result
        elif hasattr(response, 'stdout'):
            result = response.stdout
        else:
            result = str(response)
        return f"Code Execution Result ({language}):\n{result}"
    except Exception as e:
        # Reset the sandbox reference on error so a fresh one is created next time
        _daytona_sandbox = None
        return f"Error executing code: {str(e)}"


@tool
def execute_shell_command(command: str, working_dir: str = "/tmp", timeout: int = 300) -> str:
    """Execute a shell command securely in an isolated sandbox environment using Daytona.

    This tool runs shell commands in a secure, isolated environment to prevent security issues.
    It's useful for file operations, system tasks, and other command-line operations.

    Parameters:
    - command: The shell command to execute
    - working_dir: The working directory (default: "/tmp")
    - timeout: Maximum execution time in seconds (default: 300)

    Returns:
    - The command execution output or error message
    """
    global _daytona_sandbox
    try:
        # Initialize the Daytona client if not already done
        if _daytona_sandbox is None:
            api_key = os.getenv("DAYTONA_API_KEY")
            if not api_key:
                return "Error: DAYTONA_API_KEY environment variable not set"
            # Initialize the Daytona client and create a sandbox
            config = DaytonaConfig(api_key=api_key)
            daytona_client = Daytona(config)
            _daytona_sandbox = daytona_client.create()
        # Execute the command
        response = _daytona_sandbox.process.exec(command, cwd=working_dir, timeout=timeout)
        # Extract and return the result
        if hasattr(response, 'result'):
            result = response.result
        elif hasattr(response, 'stdout'):
            result = response.stdout
        else:
            result = str(response)
        return f"Shell Command Execution Result:\n{result}"
    except Exception as e:
        # Reset the sandbox reference on error so a fresh one is created next time
        _daytona_sandbox = None
        return f"Error executing shell command: {str(e)}"


@tool
def sandbox_file_operation(operation: str, file_path: str, content: str = "", target_path: str = "") -> str:
    """Perform file operations in the secure sandbox environment.

    This tool allows secure file manipulation in an isolated sandbox.
    It supports creating, reading, writing, moving, copying and deleting files.

    Parameters:
    - operation: The operation to perform ('create', 'read', 'write', 'append', 'delete', 'move', 'copy', 'list')
    - file_path: Path to the file to operate on
    - content: Content to write (for 'create', 'write', 'append' operations)
    - target_path: Target path for 'move' and 'copy' operations

    Returns:
    - Operation result or file content
    """
    global _daytona_sandbox
    try:
        # Initialize the Daytona client if not already done
        if _daytona_sandbox is None:
            api_key = os.getenv("DAYTONA_API_KEY")
            if not api_key:
                return "Error: DAYTONA_API_KEY environment variable not set"
            # Initialize the Daytona client and create a sandbox
            config = DaytonaConfig(api_key=api_key)
            daytona_client = Daytona(config)
            _daytona_sandbox = daytona_client.create()
        # Perform the requested operation
        operation = operation.lower()
        if operation == "create" or operation == "write":
            # Create or overwrite the file
            _daytona_sandbox.fs.upload_file(file_path, content.encode('utf-8'))
            return f"File {file_path} created/written successfully"
        elif operation == "append":
            # First try to read the existing content
            try:
                existing_content = _daytona_sandbox.fs.download_file(file_path).decode('utf-8')
            except Exception:
                existing_content = ""
            # Append the new content and write it back
            new_content = existing_content + content
            _daytona_sandbox.fs.upload_file(file_path, new_content.encode('utf-8'))
            return f"Content appended to {file_path} successfully"
        elif operation == "read":
            # Read the file content
            try:
                content = _daytona_sandbox.fs.download_file(file_path).decode('utf-8')
                return f"Content of {file_path}:\n{content}"
            except Exception as e:
                return f"Error reading {file_path}: {str(e)}"
        elif operation == "delete":
            # Delete the file
            response = _daytona_sandbox.process.exec(f"rm -f {file_path}", cwd="/tmp")
            return f"File {file_path} deleted"
        elif operation == "move":
            # Move the file
            if not target_path:
                return "Error: Target path required for move operation"
            response = _daytona_sandbox.process.exec(f"mv {file_path} {target_path}", cwd="/tmp")
            return f"File moved from {file_path} to {target_path}"
        elif operation == "copy":
            # Copy the file
            if not target_path:
                return "Error: Target path required for copy operation"
            response = _daytona_sandbox.process.exec(f"cp {file_path} {target_path}", cwd="/tmp")
            return f"File copied from {file_path} to {target_path}"
        elif operation == "list":
            # List directory contents
            response = _daytona_sandbox.process.exec(f"ls -la {file_path}", cwd="/tmp")
            if hasattr(response, 'result'):
                result = response.result
            elif hasattr(response, 'stdout'):
                result = response.stdout
            else:
                result = str(response)
            return f"Directory listing of {file_path}:\n{result}"
        else:
            return f"Unsupported operation: {operation}"
    except Exception as e:
        return f"Error performing file operation: {str(e)}"


def cleanup_daytona_sandbox():
    """Clean up the Daytona sandbox when it's no longer needed."""
    global _daytona_sandbox
    try:
        if _daytona_sandbox is not None:
            # Get the Daytona client
            api_key = os.getenv("DAYTONA_API_KEY")
            if api_key:
                config = DaytonaConfig(api_key=api_key)
                daytona_client = Daytona(config)
                # Remove the sandbox
                daytona_client.remove(_daytona_sandbox)
            _daytona_sandbox = None
            print("Daytona sandbox cleaned up successfully")
    except Exception as e:
        print(f"Error cleaning up Daytona sandbox: {str(e)}")


# Track the last execution time for rate limiting
_last_extract_url_time = 0


@tool
def extract_document_data(input_method: str, files: list, prompt: str, json_mode: bool = False) -> str:
    """Extract structured data from documents using Dumpling AI.

    This tool allows you to extract information from various document formats including PDFs,
    Office documents, images, and many other file types. It uses vision-capable Large Language
    Models (LLMs) to interpret and extract data based on your specific prompt.

    Parameters:
    - input_method: How to input files, either "url" or "base64"
    - files: List of file URLs or base64-encoded strings, depending on input_method
    - prompt: Specific instructions for what data to extract from the document
    - json_mode: Whether to return structured JSON (true) or free text (false)

    Returns:
    - Extracted data from the document based on your prompt

    Supported file extensions include PDFs, Word docs, Excel files, PowerPoint, images, HTML, and many others.
    """
    api_key = os.getenv("DUMPLING_API_KEY")
    if not api_key:
        return "Error: DUMPLING_API_KEY environment variable not set"
    try:
        url = "https://app.dumplingai.com/api/v1/extract-document"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }
        data = {
            "inputMethod": input_method,
            "files": files,
            "prompt": prompt,
            "jsonMode": json_mode
        }
        response = requests.post(url, headers=headers, json=data, timeout=120)
        response.raise_for_status()
        result = response.json()
        # Format the response in a readable way
        formatted_response = "Document Extraction Results:\n\n"
        formatted_response += f"Extracted Data:\n{result.get('results', 'No results found')}\n\n"
        formatted_response += f"Pages Processed: {result.get('pages', 'Unknown')}\n"
        formatted_response += f"Files Processed: {result.get('fileCount', 'Unknown')}\n"
        formatted_response += f"Credit Usage: {result.get('creditUsage', 'Unknown')}\n"
        return formatted_response
    except requests.exceptions.Timeout:
        return "Error: Request to Dumpling AI API timed out after 120 seconds"
    except requests.exceptions.HTTPError as e:
        error_detail = f"HTTP Error: {e.response.status_code}"
        try:
            error_json = e.response.json()
            error_detail += f" - {error_json.get('detail', error_json)}"
        except Exception:
            error_detail += f" - {e.response.text[:500]}"
        return error_detail
    except requests.exceptions.RequestException as e:
        return f"Error making request to Dumpling AI API: {str(e)}"
    except Exception as e:
        return f"Error extracting document data: {str(e)}"


@tool
def extract_image_data(input_method: str, images: list, prompt: str, json_mode: bool = False) -> str:
    """Extract visual information from images using Dumpling AI.

    This tool allows you to extract detailed descriptions or specific information from images
    using vision-capable Large Language Models (LLMs). It can identify objects, scenes, text,
    and other visual elements based on your specific prompt.

    Parameters:
    - input_method: How to input images, either "url" or "base64"
    - images: List of image URLs or base64-encoded strings, depending on input_method
    - prompt: Specific instructions for what information to extract from the image
    - json_mode: Whether to return structured JSON (true) or free text (false)

    Returns:
    - Extracted visual data from the image based on your prompt
    """
    api_key = os.getenv("DUMPLING_API_KEY")
    if not api_key:
        return "Error: DUMPLING_API_KEY environment variable not set"
    try:
        url = "https://app.dumplingai.com/api/v1/extract-image"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }
        data = {
            "inputMethod": input_method,
            "images": images,
            "prompt": prompt,
            "jsonMode": json_mode
        }
        response = requests.post(url, headers=headers, json=data, timeout=120)
        response.raise_for_status()
        result = response.json()
        # Format the response in a readable way
        formatted_response = "Image Analysis Results:\n\n"
        formatted_response += f"Extracted Data:\n{result.get('results', 'No results found')}\n\n"
        formatted_response += f"Images Processed: {result.get('imageCount', 'Unknown')}\n"
        formatted_response += f"Credit Usage: {result.get('creditUsage', 'Unknown')}\n"
        return formatted_response
    except requests.exceptions.Timeout:
        return "Error: Request to Dumpling AI API timed out after 120 seconds"
    except requests.exceptions.HTTPError as e:
        error_detail = f"HTTP Error: {e.response.status_code}"
        try:
            error_json = e.response.json()
            error_detail += f" - {error_json.get('detail', error_json)}"
        except Exception:
            error_detail += f" - {e.response.text[:500]}"
        return error_detail
    except requests.exceptions.RequestException as e:
        return f"Error making request to Dumpling AI API: {str(e)}"
    except Exception as e:
        return f"Error extracting image data: {str(e)}"


@tool
def extract_url_content(url: str) -> str:
    """Extract content from a URL using the Diffbot API (supports webpages, articles, PDFs, etc.).

    This function is rate-limited to execute no more frequently than once every 20 seconds."""
    global _last_extract_url_time
    # Check if we need to wait before executing
    current_time = time.time()
    time_since_last_call = current_time - _last_extract_url_time
    if time_since_last_call < 20 and _last_extract_url_time > 0:
        # Calculate how long to wait
        wait_time = 20 - time_since_last_call
        print(f"Rate limiting: waiting {wait_time:.2f} seconds before next API call")
        time.sleep(wait_time)
        current_time = time.time()  # Update the current time after sleeping
    # Update the last execution time
    _last_extract_url_time = current_time
    # Diffbot token from the environment
    token = os.getenv("DIFFBOT_TOKEN")
    if not token:
        return "Error: DIFFBOT_TOKEN environment variable not set"
    # Set up the API endpoint
    api_url = "https://api.diffbot.com/v3/article"
    # Parameters for the request
    params = {
        "token": token,
        "url": url
    }
    try:
        # Make the API request with a timeout
        response = requests.get(api_url, params=params, timeout=60)  # 60 second timeout
        response.raise_for_status()  # Raise exception for HTTP errors
        # Parse the response
        data = response.json()
        # Extract relevant information
        if "objects" in data and len(data["objects"]) > 0:
            obj = data["objects"][0]
            # Create a formatted result
            result = f"Title: {obj.get('title', 'No title')}\n\n"
            if "text" in obj:
                result += f"Content:\n{obj.get('text')}\n\n"
            # if "html" in obj:
            #     result += f"HTML Content:\n{obj.get('html')}\n\n"
            if "categories" in obj and obj["categories"]:
                categories = ", ".join([f"{cat.get('name')} ({cat.get('score', 0):.2f})"
                                        for cat in obj["categories"]])
                result += f"Categories: {categories}\n"
            result += f"Source: {obj.get('siteName', 'Unknown')}\n"
            result += f"URL: {obj.get('pageUrl', url)}"
            return result
        else:
            return f"No content could be extracted from {url}. Response: {data}"
    except requests.exceptions.Timeout:
        return f"Error: Request to extract content from {url} timed out after 60 seconds"
    except requests.exceptions.RequestException as e:
        return f"Error: Failed to extract content from {url}: {str(e)}"
    except Exception as e:
        return f"Error extracting content from {url}: {str(e)}"


@tool
def get_youtube_transcript(url: str) -> str:
    """Get the transcript (captions) from a YouTube video as text.

    This tool extracts the transcript text from a YouTube video and returns it as a string.

    Parameters:
    - url: The YouTube video URL

    Returns:
    - The transcript as a string, or an error message if the transcript couldn't be obtained
    """
    # Create a temporary directory to store subtitle files
    temp_dir = tempfile.mkdtemp()
    current_dir = os.getcwd()
    try:
        # Change to the temp directory for the download
        os.chdir(temp_dir)
        ydl_opts = {
            'writesubtitles': True,       # Download subtitles
            'writeautomaticsub': True,    # Download automatic subtitles
            'subtitleslangs': ['en'],     # Specify English language
            'skip_download': True,        # Skip downloading the video, only get subtitles
            'outtmpl': 'subtitle',        # Simple output template
        }
        # Download the subtitles
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info_dict = ydl.extract_info(url, download=True)
            video_title = info_dict.get('title', 'Unknown Title')
        # Look for subtitle files in the temp directory
        subtitle_content = ""
        subtitle_files = list(Path(temp_dir).glob("*.vtt")) + list(Path(temp_dir).glob("*.srt"))
        if subtitle_files:
            # Read the first subtitle file found
            with open(subtitle_files[0], 'r', encoding='utf-8') as f:
                subtitle_content = f.read()
            # Clean up the subtitle content to remove timestamps and formatting.
            # This is a simple cleaning - more complex parsing may be needed for perfect results.
            lines = subtitle_content.split('\n')
            cleaned_lines = []
            for line in lines:
                # Skip time codes, numbering and empty lines
                if line.strip() and not line.strip().isdigit() and '-->' not in line and not line.startswith('WEBVTT'):
                    cleaned_lines.append(line)
            subtitle_content = ' '.join(cleaned_lines)
            return f"Transcript from YouTube video: '{video_title}'\n\n{subtitle_content}"
        else:
            return f"No transcript found for YouTube video: '{video_title}'"
    except Exception as e:
        return f"Error retrieving YouTube transcript: {str(e)}"
    finally:
        # Change back to the original directory and clean up
        os.chdir(current_dir)
        # Clean up temporary files (best effort)
        try:
            for file in os.listdir(temp_dir):
                os.remove(os.path.join(temp_dir, file))
            os.rmdir(temp_dir)
        except Exception:
            pass


class BasicAgent:
    def __init__(self):
        print("BasicAgent initialized.")
        # Initialize the Langfuse callback handler
        self.langfuse_handler = CallbackHandler()
        # Gemini chat models. Assumption: ChatGoogleGenerativeAI stands in for the
        # original LiteLLMModel wrapper (whose import path does not exist) because it
        # provides the .invoke() and .bind_tools() methods used below.
        # Supervisor model using Gemini
        self.supervisor_model = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash-lite",
            google_api_key=os.getenv("GEMINI_API_KEY"),
            temperature=0.5,
            max_output_tokens=1024,
        )
        # Validator model using Gemini
        self.validator_model = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash-lite",
            google_api_key=os.getenv("GEMINI_API_KEY"),
            temperature=0.5,
            max_output_tokens=1024,
        )
        # Worker base model using Gemini (token limit kept from the original configuration)
        self.worker_model_base = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash-lite",
            google_api_key=os.getenv("GEMINI_API_KEY"),
            temperature=0.75,
            max_output_tokens=20000,
        )
        # Bind tools to the worker model
        self.tools = [search_web_tavily, search_web_serper, execute_code_securely,
                      execute_shell_command, sandbox_file_operation, extract_document_data,
                      extract_image_data, extract_url_content, get_youtube_transcript]
        self.worker_model = self.worker_model_base.bind_tools(self.tools)
        # Create the tool node for executing tools
        self.tool_node = ToolNode(self.tools)
        # Create the workflow
        self.app = self._create_workflow()

    def _process_messages_after_tools(self, messages):
        """Process messages to ensure tool calls and tool results are properly paired.

        This helps prevent the Anthropic error: unexpected `tool_use_id` found in `tool_result` blocks."""
        # Create a mapping of tool_call_id to AIMessage index
        tool_call_map = {}
        for i, msg in enumerate(messages):
            if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
                for tool_call in msg.tool_calls:
                    if "id" in tool_call:
                        tool_call_map[tool_call["id"]] = i
        # Filter out ToolMessages that don't have a matching AIMessage with tool_calls
        processed_messages = []
        for i, msg in enumerate(messages):
            if isinstance(msg, ToolMessage) and hasattr(msg, "tool_call_id"):
                # Only include it if there is a matching AIMessage with this tool_call_id
                if msg.tool_call_id in tool_call_map:
                    ai_msg_index = tool_call_map[msg.tool_call_id]
                    # Make sure this tool message comes right after its AIMessage
                    if i > ai_msg_index and not any(
                        isinstance(messages[j], ToolMessage) and
                        hasattr(messages[j], "tool_call_id") and
                        messages[j].tool_call_id == msg.tool_call_id
                        for j in range(ai_msg_index + 1, i)
                    ):
                        processed_messages.append(msg)
            else:
                processed_messages.append(msg)
        return processed_messages

    def _create_workflow(self):
        workflow = StateGraph(AgentState)
        # Add nodes
        workflow.add_node("supervisor", self._supervisor_agent)
        workflow.add_node("worker", self._worker_agent)
        workflow.add_node("tools", self._handle_tools)
        workflow.add_node("validator", self._validation_agent)
        # Add edges using the START and END constants
        workflow.add_edge(START, "supervisor")
        # All nodes use Command to specify their next destination, so we don't need
        # conditional edges; each node's Command(goto=...) names the next node
        # Compile the graph
        return workflow.compile()
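
    # Control-flow sketch of the compiled graph (routing is driven by the
    # Command(goto=...) values returned by each node rather than explicit edges):
    #   START -> supervisor -> worker -> (tools -> worker)* -> supervisor
    #   supervisor -> validator once satisfied; validator -> END on approval,
    #   otherwise back to supervisor with the worker loop reset.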

    def _supervisor_agent(self, state: AgentState) -> Command[Literal["worker", "validator"]]:
        """Supervisor agent that coordinates the workflow."""
        # Get the question from state
        current_question = state["current_question"]
        messages = state["messages"]
        worker_iterations = state.get("worker_iterations", 0)
        # If we have messages and this isn't the first iteration, evaluate the worker's response
        if messages and worker_iterations > 0:
            # Find the last worker response
            worker_response = None
            for msg in reversed(messages):
                if isinstance(msg, AIMessage) and not getattr(msg, "tool_calls", None):
                    worker_response = msg.content
                    break
            if worker_response:
                # Evaluate the worker's response
                eval_prompt = ChatPromptTemplate.from_messages([
                    ("system", """You are a supervisor agent evaluating a worker's research report about the user's question.
                    Analyze whether the report and its answer completely and accurately answer the question.
                    Your evaluation criteria:
                    - Completeness: Does the answer address all aspects of the question?
                    - Accuracy: Are the facts, references and reasoning correct?
                    - Path clarity: Is the path to the answer logical and well-explained?
                    - Evidence quality: Are the references reliable and directly relevant?
                    The worker has access to search and web content extraction tools, as well as a Python code execution tool.
                    The tasks given to you are not casual questions from random humans, but tricky contest puzzles that test LLM capabilities.
                    If all criteria are met, respond with "SATISFIED".
                    If any criteria are not met, respond with "UNSATISFIED: [specific detailed feedback]".
                    Be precise in your feedback so the worker knows exactly what to improve."""),
                    ("human", f"Question: {current_question}\nWorker's report with answer: {worker_response}")
                ])
                evaluation = self.supervisor_model.invoke(eval_prompt.format_prompt().to_messages()).content
                # Determine whether the supervisor is satisfied
                supervisor_satisfaction = evaluation.startswith("SATISFIED")
                if supervisor_satisfaction:
                    # If satisfied, prepare to move to the validator
                    return Command(
                        goto="validator",
                        update={
                            "supervisor_satisfaction": True
                        }
                    )
                else:
                    # If not satisfied, give feedback to the worker
                    feedback = evaluation.replace("UNSATISFIED: ", "")
                    prompt = ChatPromptTemplate.from_messages([
                        ("system", """You are a supervisor agent providing targeted feedback to the worker agent.
                        Your role is to guide the worker to improve their research report by:
                        1) Highlighting specific areas that need improvement
                        2) Providing clear, actionable guidance on what additional research is needed
                        3) Explaining exactly how the worker should revise their approach
                        4) Reminding them of any specific formatting requirements in the original question
                        The worker has access to the following tools:
                        - Web search (using Tavily and Serper)
                        - Web content extraction
                        - Image analysis (can extract visual information from images)
                        - Document data extraction (from PDFs, documents, etc.)
                        - Secure code execution (for Python and other languages)
                        - Secure shell command execution
                        - Secure file operations
                        For computational puzzles, math problems, data processing, or tasks requiring exact precision,
                        recommend using the code execution tools rather than relying on reasoning alone.
                        The tasks given to you are not casual questions from random humans, but tricky contest puzzles that test LLM capabilities.
                        Focus on being constructive and precise. The worker should understand exactly what to do next."""),
                        ("human", f"Question: {current_question}\nWorker's current response: {worker_response}\nImprovement needed: {feedback}")
                    ])
                    feedback_message = self.supervisor_model.invoke(prompt.format_prompt().to_messages()).content
                    # Update messages with feedback and increment worker iterations
                    return Command(
                        goto="worker",
                        update={
                            "messages": messages + [HumanMessage(content=feedback_message)],
                            "worker_iterations": worker_iterations + 1,
                            "supervisor_satisfaction": False
                        }
                    )
        # First iteration: provide initial instructions
        prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a supervisor agent responsible for coordinating a research workflow.
            Your responsibilities:
            1) Analyze the question to identify required knowledge, tools, and research strategy
            2) Provide clear, specific instructions to the worker agent
            3) Specify exactly what information to gather and what analysis to perform
            The worker will prepare a concise research report containing:
            1) Their research path - the logical sequence of steps taken to reach the answer
            2) The specific references used, with clear citations
            3) A proposed final answer, formatted EXACTLY as requested in the question, in a separate section
            The worker has access to the following powerful tools:
            - Web search (using Tavily and Serper)
            - Web content extraction
            - Image analysis (can extract visual information from images)
            - Document data extraction (can extract data from PDFs, documents, etc.)
            - Secure code execution (for Python and other languages)
            - Secure shell command execution
            - Secure file operations
            You must understand that plain LLMs struggle with puzzles that can only be solved reliably by code execution,
            for example math problems, word/character flipping, counting, and similar tasks.
            For such tasks, the worker should use the code execution tools to solve the puzzle.
            The tasks given to you are not casual questions from random humans, but tricky contest puzzles that test LLM capabilities.
            The worker should give you a full report with all sections for you to evaluate."""
            ),
            ("human", current_question)
        ])
        response = self.supervisor_model.invoke(prompt.format_prompt().to_messages()).content
        # Use the Command pattern to update state and move to the worker
        return Command(
            goto="worker",
            update={
                "messages": [HumanMessage(content=current_question), AIMessage(content=response)],
                "worker_iterations": 1,
                "supervisor_satisfaction": False
            }
        )

    def _worker_agent(self, state: AgentState) -> Command[Literal["tools", "supervisor"]]:
        """Worker agent that performs the actual work, using tools when needed."""
        messages = state["messages"]
        # Process messages to ensure proper tool call-result pairing
        processed_messages = self._process_messages_after_tools(messages)
        # Filter out any ToolMessages that don't have a corresponding AIMessage with tool_calls.
        # This helps prevent the "unexpected tool_use_id" error with Anthropic.
        filtered_messages = []
        tool_call_ids = set()
        # First pass: collect all tool_call_ids from AIMessages
        for msg in processed_messages:
            if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
                for tool_call in msg.tool_calls:
                    if "id" in tool_call:
                        tool_call_ids.add(tool_call["id"])
        # Second pass: only include ToolMessages that have a corresponding tool_call_id
        for msg in processed_messages:
            if isinstance(msg, ToolMessage) and getattr(msg, "tool_call_id", None):
                if msg.tool_call_id in tool_call_ids:
                    filtered_messages.append(msg)
            else:
                filtered_messages.append(msg)
        # Invoke the tool-enabled model on the filtered message history
        response = self.worker_model.invoke(filtered_messages)
        # Update messages - add the response to the original messages
        # (we don't want to lose the original message history)
        updated_messages = messages + [response]
        # Determine the next step using the Command pattern
        if response.tool_calls:
            # If tool calls are present, go to tools
            return Command(
                goto="tools",
                update={"messages": updated_messages}
            )
        else:
            # No tool calls, return to the supervisor for evaluation
            return Command(
                goto="supervisor",
                update={"messages": updated_messages}
            )

    def _validation_agent(self, state: AgentState) -> Command[Literal["supervisor", "__end__"]]:
        """Agent that validates the final answer."""
        messages = state["messages"]
        question = state["current_question"]
        # Get the final answer from the last message
        final_answer = ""
        for msg in reversed(messages):
            if isinstance(msg, AIMessage) and not getattr(msg, "tool_calls", None):
                final_answer = msg.content
                break
        prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a quality assurance agent responsible for final verification of research reports and precise formatting of final answers.
            Your critical responsibilities:
            1) Verify the factual accuracy and completeness of the report, ensuring you can extract and format the final answer exactly as requested in the question
            2) Ensure EXACT compliance with any formatting instructions in the question by producing a properly structured final answer
            Pay extremely close attention to formatting requirements. The user may request:
            - Only specific parts of information (first/last names, specific data points, numerical values)
            - Particular ordering (alphabetical, chronological, size-based, relevance-based)
            - Special formatting (bullet points, numbered lists, specific separators, tables)
            - Exact text case, spacing, punctuation, or other presentational elements
            Exact formatting compliance is MANDATORY for this challenge evaluation. Your role is to ensure the final answer meets all specified requirements.
            If numerical values are requested, ensure they are formatted as numbers, not text.
            Remember that the worker had access to:
            - Web search tools
            - Web content extraction
            - Image analysis (can extract visual information from images)
            - Document data extraction (from PDFs, documents, etc.)
            - Secure code execution
            - Secure shell commands
            - Secure file operations
            For computational or precision-based questions, check whether code execution was appropriately used and validate the results.
            When evaluating the answer:
            - Check if all required information is present and accurate
            - Verify that the answer directly addresses the specific question asked
            - Ensure any numerical values, dates, names, or technical terms are correct
            - Confirm that the formatting precisely matches what was requested
            - Do not add units to the final answer if not explicitly requested
            - Do not use money symbols (for example $) in the final answer if not explicitly requested
            - Don't use comma separators for integers: write 1000000, not 1,000,000
            - Answers tend to be as short as possible, so do not add extra data unless explicitly requested
            If the answer report is correct, format it exactly as asked in the question, and respond with:
            "APPROVED: [THE PROPERLY FORMATTED ANSWER]"
            If there are issues with overall answer quality and you cannot format the final answer as requested, respond with:
            "REJECTED: [DETAILED EXPLANATION OF ISSUES]"
            Be extremely precise in your evaluation - the success of this task depends on your attention to detail.
            """
            ),
            ("human", f"Question: {question}\nReport to validate: {final_answer}")
        ])
        validation_result = self.validator_model.invoke(prompt.format_prompt().to_messages()).content
        validator_approval = validation_result.startswith("APPROVED")
        if validator_approval:
            # Approved - end the workflow
            return Command(
                goto=END,
                update={
                    "final_answer": validation_result[10:],  # Remove the "APPROVED: " prefix
                    "validation_result": validation_result,
                    "validator_approval": True
                }
            )
        else:
            # Rejected - restart from the supervisor with reset state
            return Command(
                goto="supervisor",
                update={
                    "messages": [HumanMessage(content=question)],
                    "validation_result": validation_result,
                    "validator_approval": False,
                    "worker_iterations": 0,
                    "supervisor_satisfaction": False
                }
            )

    def _handle_tools(self, state: AgentState) -> Command[Literal["worker"]]:
        """Custom wrapper around ToolNode to ensure proper message handling."""
        # Execute the tool calls using the tool node
        tool_result = self.tool_node.invoke(state)
        # Process the result to ensure proper message ordering
        if "messages" in tool_result:
            # Get the original messages
            original_messages = state["messages"]
            # Get all existing AIMessages with tool calls and their indices
            ai_indices = {}
            for i, msg in enumerate(original_messages):
                if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
                    for tool_call in msg.tool_calls:
                        if "id" in tool_call:
                            ai_indices[tool_call["id"]] = i
            # Add the new tool messages, ensuring they come right after their corresponding tool call
            updated_messages = list(original_messages)
            for msg in tool_result["messages"]:
                if isinstance(msg, ToolMessage) and hasattr(msg, "tool_call_id"):
                    tool_id = msg.tool_call_id
                    if tool_id in ai_indices:
                        # Insert after the AIMessage with the matching tool call
                        insert_idx = ai_indices[tool_id] + 1
                        # Move past any existing tool messages for this AI message
                        while insert_idx < len(updated_messages) and \
                                isinstance(updated_messages[insert_idx], ToolMessage) and \
                                hasattr(updated_messages[insert_idx], "tool_call_id") and \
                                updated_messages[insert_idx].tool_call_id != tool_id:
                            insert_idx += 1
                        updated_messages.insert(insert_idx, msg)
                        # Update subsequent indices
                        for call_id in ai_indices:
                            if ai_indices[call_id] >= insert_idx:
                                ai_indices[call_id] += 1
                    else:
                        # No matching tool call found, just append
                        updated_messages.append(msg)
            return Command(
                goto="worker",
                update={"messages": updated_messages}
            )
        # If there are no message updates, just return the tool result as-is
        return Command(
            goto="worker",
            update=tool_result
        )

    def __call__(self, question: str) -> str:
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        # Initialize the state
        initial_state = {
            "messages": [],
            "current_question": question,
            "final_answer": "",
            "validation_result": "",
            "worker_iterations": 0,
            "supervisor_satisfaction": False,
            "validator_approval": False
        }
        try:
            # Run the workflow
            final_state = self.app.invoke(initial_state, config={"callbacks": [self.langfuse_handler], "recursion_limit": 35})
            # Return the final answer
            answer = final_state.get("final_answer", "")
            if not answer and final_state["messages"]:
                for msg in reversed(final_state["messages"]):
                    if isinstance(msg, AIMessage) and not getattr(msg, "tool_calls", None):
                        answer = msg.content
                        break
            print(f"Agent returning answer: {answer[:50]}...")
            return answer
        except Exception as e:
            print(f"Error in agent processing: {str(e)}")
            # Fall back to a plain error message if the workflow fails
            return f"I encountered an error while processing your question: {str(e)}. Please try reformulating your question."
        finally:
            # Clean up resources
            cleanup_daytona_sandbox()
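

# Minimal usage sketch - assumes the required API keys (GEMINI_API_KEY, TAVILY_API_KEY,
# etc.) are available in the environment or a .env file; the question is a placeholder.
if __name__ == "__main__":
    agent = BasicAgent()
    sample_question = "What is the capital of France? Answer with a single word."
    print(agent(sample_question))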