GAIA_Agent / agents /figure_interpretation_agent.py
Aedelon's picture
agent enhancement (#3)
b8f6b7f verified
raw
history blame
14.7 kB
import os
import logging
from llama_index.core.agent.workflow import ReActAgent
from llama_index.core.schema import ImageDocument
from llama_index.core.tools import FunctionTool
from llama_index.llms.google_genai import GoogleGenAI
# Setup logging: module-level logger used by the functions defined in this file.
logger = logging.getLogger(__name__)
# Helper function to load a prompt from a file, falling back to a default.
def load_prompt_from_file(filename: str, default_prompt: str) -> str:
    """Load a prompt from a text file located relative to this module.

    Args:
        filename: Path of the prompt file. Relative paths are resolved against
            the directory containing this module; absolute paths are used
            unchanged (``os.path.join`` discards the base for them).
        default_prompt: Text returned when the file is missing or cannot be
            read/decoded.

    Returns:
        The file contents, or ``default_prompt`` on any failure. Failures are
        logged, never raised.
    """
    try:
        script_dir = os.path.dirname(__file__)
        prompt_path = os.path.join(script_dir, filename)
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on
        # Windows) can fail on or silently mis-decode non-ASCII prompt text.
        with open(prompt_path, "r", encoding="utf-8") as f:
            prompt = f.read()
    except FileNotFoundError:
        logger.warning(f"Prompt file {filename} not found at {prompt_path}. Using default.")
        return default_prompt
    except Exception as e:
        # Includes UnicodeDecodeError and permission errors: degrade to the default.
        logger.error(f"Error loading prompt file {filename}: {e}", exc_info=True)
        return default_prompt
    logger.info(f"Successfully loaded prompt from {prompt_path}")
    return prompt
# --- Core Figure Interpretation Logic (using Multi-Modal LLM) ---
def interpret_figure_with_llm(image_path: str, request: str) -> str:
    """Interpret a figure in an image according to a specific request.

    The request (wrapped in an interpreter prompt) and the image are sent
    together as a single user chat message made of a text block and an image
    block, so the multi-modal model receives the image itself.

    Args:
        image_path (str): Path to the image file containing the figure.
        request (str): The specific question or interpretation task (e.g.
            "Describe this chart", "Extract sales for Q3",
            "Identify the main trend").

    Returns:
        str: The interpretation result, or a message starting with "Error:"
        when the file is missing or not an image, GEMINI_API_KEY is unset,
        a required library is missing, or the LLM call fails. Errors are
        returned rather than raised so agent tools always receive a string.

    Environment:
        FIGURE_INTERPRETATION_LLM_MODEL: multi-modal model name
            (default "models/gemini-1.5-pro").
        GEMINI_API_KEY: required API key.
    """
    import mimetypes

    logger.info(f"Interpreting figure in image: {image_path} with request: {request}")

    # isfile rather than exists: a directory path must be rejected as well.
    if not os.path.isfile(image_path):
        logger.error(f"Image file not found: {image_path}")
        return f"Error: Image file not found at {image_path}"

    # LLM configuration (must be a multi-modal model, e.g. gemini-1.5-pro).
    llm_model_name = os.getenv("FIGURE_INTERPRETATION_LLM_MODEL", "models/gemini-1.5-pro")
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    if not gemini_api_key:
        logger.error("GEMINI_API_KEY not found for figure interpretation LLM.")
        return "Error: GEMINI_API_KEY not set."

    # Reject non-image files before calling the API. The guessed MIME type is
    # also attached to the image block so the provider knows the format.
    image_mimetype, _ = mimetypes.guess_type(image_path)
    if not image_mimetype or not image_mimetype.startswith("image/"):
        logger.error(f"File does not look like an image (MIME type {image_mimetype!r}): {image_path}")
        return f"Error: Could not load image file {image_path} for analysis."

    # The prompt guides the model to act as a figure interpreter for this request.
    prompt = (
        "You are an expert figure interpreter. Analyze the provided image containing a chart, graph, diagram, or table. "
        "Focus *only* on the visual information present in the image. "
        "Fulfill the following request accurately and concisely:\n\n"
        f"REQUEST: {request}\n\n"
        "Based *only* on the image, provide the answer:"
    )

    try:
        # Imported here so a missing/outdated llama_index is reported through
        # the ImportError handler below as a tool error string.
        from llama_index.core.llms import ChatMessage, ImageBlock, TextBlock

        llm = GoogleGenAI(api_key=gemini_api_key, model=llm_model_name)
        logger.info(f"Using figure interpretation LLM: {llm_model_name}")

        # Images are sent as content blocks on a ChatMessage. The generic
        # LLM.complete() interface takes only a text prompt, so passing an
        # `image_documents=` kwarg to it (legacy MultiModalLLM style) is not
        # guaranteed to deliver the image to the model.
        message = ChatMessage(
            role="user",
            blocks=[
                TextBlock(text=prompt),
                ImageBlock(path=image_path, image_mimetype=image_mimetype),
            ],
        )
        response = llm.chat([message])
        # ChatMessage.content can be None when no text is returned.
        interpretation = (response.message.content or "").strip()
        logger.info("Figure interpretation successful.")
        return interpretation
    except FileNotFoundError:
        # The file may disappear between the isfile() check and the read.
        logger.error(f"Image file not found during LLM call: {image_path}")
        return f"Error: Image file not found at {image_path}"
    except ImportError as ie:
        logger.error(f"Missing library for multi-modal processing: {ie}")
        return f"Error: Missing required library for image processing ({ie})."
    except Exception as e:
        # Catch API errors or other issues and report them as a tool result.
        logger.error(f"LLM call failed during figure interpretation: {e}", exc_info=True)
        # Give a clearer message when the model rejects image input.
        if "does not support image input" in str(e).lower():
            logger.error(f"The configured model {llm_model_name} does not support image input.")
            return f"Error: The configured LLM ({llm_model_name}) does not support image input. Please configure a multi-modal model."
        return f"Error during figure interpretation: {e}"
# --- Tool Definitions (wrapping the core logic) ---
# Each function below is a thin adapter: it turns a tool call into a concrete
# natural-language request and delegates to interpret_figure_with_llm.
def describe_figure_tool_fn(image_path: str) -> str:
    """Give a general description of the figure in the image (type, elements, topic)."""
    description_request = (
        "Describe this figure, including its type, main elements "
        "(axes, labels, legend), and overall topic."
    )
    return interpret_figure_with_llm(image_path, description_request)
def extract_data_points_tool_fn(image_path: str, data_request: str) -> str:
    """Pull specific data points or values out of the figure in the image.

    The model is told to give its closest visual estimate when exact values
    are not legible.
    """
    llm_request = (
        "Extract the following data points/values from the figure: "
        f"{data_request}. "
        "If exact values are not clear, provide the closest estimate based on the visual."
    )
    return interpret_figure_with_llm(image_path, llm_request)
def identify_trends_tool_fn(image_path: str) -> str:
    """Find and describe the trends or patterns visible in the figure in the image."""
    trends_request = "Analyze and describe the main trends or patterns shown in this figure."
    return interpret_figure_with_llm(image_path=image_path, request=trends_request)
def compare_elements_tool_fn(image_path: str, comparison_request: str) -> str:
    """Compare elements within the figure in the image, as described by the caller."""
    llm_request = (
        "Compare the following elements within the figure: "
        f"{comparison_request}. "
        "Be specific about the comparison based on the visual data."
    )
    return interpret_figure_with_llm(image_path, llm_request)
def summarize_figure_insights_tool_fn(image_path: str) -> str:
    """Condense the key insights or main message of the figure in the image."""
    summary_request = "Summarize the key insights or the main message conveyed by this figure."
    return interpret_figure_with_llm(image_path=image_path, request=summary_request)
# --- Tool Definitions for Agent ---
# Wrap each adapter function as a FunctionTool with an explicit tool name and
# a description that lists the expected inputs.
describe_figure_tool = FunctionTool.from_defaults(
    name="describe_figure",
    description="Provides a general description of the figure in the image (type, elements, topic). Input: image_path (str).",
    fn=describe_figure_tool_fn,
)

extract_data_points_tool = FunctionTool.from_defaults(
    name="extract_data_points",
    description="Extracts specific data points/values from the figure. Input: image_path (str), data_request (str).",
    fn=extract_data_points_tool_fn,
)

identify_trends_tool = FunctionTool.from_defaults(
    name="identify_trends",
    description="Identifies and describes trends/patterns in the figure. Input: image_path (str).",
    fn=identify_trends_tool_fn,
)

compare_elements_tool = FunctionTool.from_defaults(
    name="compare_elements",
    description="Compares different elements within the figure. Input: image_path (str), comparison_request (str).",
    fn=compare_elements_tool_fn,
)

summarize_figure_insights_tool = FunctionTool.from_defaults(
    name="summarize_figure_insights",
    description="Summarizes the key insights/main message of the figure. Input: image_path (str).",
    fn=summarize_figure_insights_tool_fn,
)
# --- Agent Initialization ---
def initialize_figure_interpretation_agent() -> ReActAgent:
    """Build the ReAct agent that interprets figures via the tools in this module.

    Environment:
        FIGURE_INTERPRETATION_AGENT_LLM_MODEL: model for the agent's own
            reasoning loop (default "models/gemini-1.5-pro").
        GEMINI_API_KEY: required.

    Returns:
        A ReActAgent named "figure_interpretation_agent" with the five figure
        tools and handoff targets planner_agent, research_agent and
        reasoning_agent.

    Raises:
        ValueError: if GEMINI_API_KEY is not set.
        Exception: any error raised while constructing the LLM or agent is
            logged and re-raised.
    """
    logger.info("Initializing FigureInterpretationAgent...")
    # Configuration for the agent's main LLM (can be the same multi-modal one).
    agent_llm_model = os.getenv("FIGURE_INTERPRETATION_AGENT_LLM_MODEL", "models/gemini-1.5-pro")
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    if not gemini_api_key:
        logger.error("GEMINI_API_KEY not found for FigureInterpretationAgent.")
        raise ValueError("GEMINI_API_KEY must be set for FigureInterpretationAgent")
    try:
        # The agent's LLM only orchestrates tool calls; the image itself is
        # sent to the multi-modal model inside the tools.
        llm = GoogleGenAI(api_key=gemini_api_key, model=agent_llm_model)
        logger.info(f"Using agent LLM: {agent_llm_model}")

        # Fallback system prompt used when the prompt file is missing or
        # unreadable. It must be usable on its own, so it names every tool and
        # handoff target configured below.
        default_system_prompt = (
            "You are FigureInterpretationAgent, a specialist in analyzing visual data "
            "representations (charts, graphs, diagrams, tables) contained in image files.\n"
            "Answer requests about a figure by calling your tools; every tool requires the "
            "image_path of the figure:\n"
            "- describe_figure: general description (type, elements, topic).\n"
            "- extract_data_points: specific values (also requires data_request).\n"
            "- identify_trends: main trends or patterns.\n"
            "- compare_elements: comparison of elements (also requires comparison_request).\n"
            "- summarize_figure_insights: key insights / main message.\n"
            "Base every statement only on what the tools report about the image, and clearly "
            "label estimated values as estimates.\n"
            "When finished, hand off to planner_agent with your results. Hand off to "
            "research_agent if the figure's context needs further research, or to "
            "reasoning_agent if the interpretation needs logical analysis."
        )
        system_prompt = load_prompt_from_file("../prompts/figure_interpretation_agent_prompt.txt", default_system_prompt)
        if system_prompt == default_system_prompt:
            logger.warning("Using default/fallback system prompt for FigureInterpretationAgent.")

        # Define available tools
        tools = [
            describe_figure_tool,
            extract_data_points_tool,
            identify_trends_tool,
            compare_elements_tool,
            summarize_figure_insights_tool,
        ]

        # Define valid handoff targets
        valid_handoffs = [
            "planner_agent",  # To return results
            "research_agent",  # If context from figure needs further research
            "reasoning_agent",  # If interpretation needs logical analysis
        ]

        agent = ReActAgent(
            name="figure_interpretation_agent",
            description=(
                "Analyzes and interprets visual data representations (charts, graphs, tables) from image files. "
                "Can describe figures, extract data, identify trends, compare elements, and summarize insights."
            ),
            tools=tools,
            llm=llm,
            system_prompt=system_prompt,
            can_handoff_to=valid_handoffs,
            # Multi-modal input is handled inside the tools via a multi-modal LLM.
        )
        logger.info("FigureInterpretationAgent initialized successfully.")
        return agent
    except Exception as e:
        logger.error(f"Error during FigureInterpretationAgent initialization: {e}", exc_info=True)
        raise
# Example usage (for testing if run directly)
if __name__ == "__main__":
    import tempfile

    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logger.info("Running figure_interpretation_agent.py directly for testing...")

    # Check required keys
    required_keys = ["GEMINI_API_KEY"]
    missing_keys = [key for key in required_keys if not os.getenv(key)]
    if missing_keys:
        print(f"Error: Required environment variable(s) not set: {', '.join(missing_keys)}. Cannot run test.")
    else:
        # Heuristic check that a multi-modal model is configured. Gemini "flash"
        # models accept images too, so they must not trigger the warning.
        model_name = os.getenv("FIGURE_INTERPRETATION_LLM_MODEL", "models/gemini-1.5-pro")
        if not any(tag in model_name.lower() for tag in ("pro", "flash", "vision")):
            print(f"Warning: Configured LLM {model_name} might not support image input. Tests may fail.")

        # Create the dummy image inside a private temporary directory: this never
        # overwrites a user's "dummy_figure.png" in the working directory, and
        # the directory is removed on every exit path, including errors.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dummy_image_path = os.path.join(tmp_dir, "dummy_figure.png")
            try:
                from PIL import Image, ImageDraw, ImageFont  # requires Pillow

                img = Image.new('RGB', (400, 200), color=(255, 255, 255))
                d = ImageDraw.Draw(img)
                # Try to load a common font; fall back to PIL's built-in one.
                try:
                    font = ImageFont.truetype("arial.ttf", 15)
                except IOError:
                    font = ImageFont.load_default()
                    print("Arial font not found, using default PIL font.")
                d.text((10, 10), "Simple Bar Chart", fill=(0, 0, 0), font=font)
                d.rectangle([50, 50, 100, 150], fill=(255, 0, 0))  # Bar 1
                d.text((60, 160), "A", fill=(0, 0, 0), font=font)
                d.rectangle([150, 80, 200, 150], fill=(0, 0, 255))  # Bar 2
                d.text((160, 160), "B", fill=(0, 0, 0), font=font)
                img.save(dummy_image_path)
                print(f"Created dummy image file: {dummy_image_path}")

                # Test the tools directly
                print("\nTesting describe_figure...")
                desc = describe_figure_tool_fn(dummy_image_path)
                print(f"Description: {desc}")

                print("\nTesting extract_data_points (qualitative)...")
                extract_req = "Height of bar A vs Bar B"  # Qualitative request
                extract_res = extract_data_points_tool_fn(dummy_image_path, extract_req)
                print(f"Extraction Result: {extract_res}")

                print("\nTesting compare_elements...")
                compare_req = "Compare bar A and bar B"
                compare_res = compare_elements_tool_fn(dummy_image_path, compare_req)
                print(f"Comparison Result: {compare_res}")
            except ImportError:
                print("Pillow library not installed. Skipping direct tool tests that require image creation.")
                # Still try initializing the agent so the configuration is exercised.
                try:
                    initialize_figure_interpretation_agent()
                    print("\nFigure Interpretation Agent initialized successfully (tool tests skipped).")
                except Exception as e:
                    print(f"Error initializing agent: {e}")
            except Exception as e:
                print(f"Error during testing: {e}")