import os
import logging

from llama_index.core.agent.workflow import ReActAgent
from llama_index.core.schema import ImageDocument
from llama_index.core.tools import FunctionTool
from llama_index.llms.google_genai import GoogleGenAI

# Setup logging
logger = logging.getLogger(__name__)

# Helper function to load prompt from file
def load_prompt_from_file(filename: str, default_prompt: str) -> str:
    """Loads a prompt from a text file, falling back to a default on failure."""
    try:
        script_dir = os.path.dirname(__file__)
        prompt_path = os.path.join(script_dir, filename)
        with open(prompt_path, "r") as f:
            prompt = f.read()
        logger.info(f"Successfully loaded prompt from {prompt_path}")
        return prompt
    except FileNotFoundError:
        logger.warning(f"Prompt file {filename} not found at {prompt_path}. Using default.")
        return default_prompt
    except Exception as e:
        logger.error(f"Error loading prompt file {filename}: {e}", exc_info=True)
        return default_prompt

# --- Core Figure Interpretation Logic (using Multi-Modal LLM) ---
def interpret_figure_with_llm(image_path: str, request: str) -> str:
    """Interprets a figure in an image based on a specific request using a multi-modal LLM.

    Args:
        image_path (str): Path to the image file containing the figure.
        request (str): The specific question or interpretation task (e.g., "Describe this chart",
            "Extract sales for Q3", "Identify the main trend").

    Returns:
        str: The interpretation result or an error message.
    """
    logger.info(f"Interpreting figure in image: {image_path} with request: {request}")

    # Check if the image exists
    if not os.path.exists(image_path):
        logger.error(f"Image file not found: {image_path}")
        return f"Error: Image file not found at {image_path}"

    # LLM configuration (must be a multi-modal model).
    # Ensure the selected model supports image input (e.g., gemini-1.5-pro).
    llm_model_name = os.getenv("FIGURE_INTERPRETATION_LLM_MODEL", "models/gemini-1.5-pro")
    gemini_api_key = os.getenv("GEMINI_API_KEY")
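    # Example configuration (illustrative values; set these in the environment before running):
    #   GEMINI_API_KEY=<your key>
    #   FIGURE_INTERPRETATION_LLM_MODEL=models/gemini-1.5-flash   # any Gemini model with image input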
    if not gemini_api_key:
        logger.error("GEMINI_API_KEY not found for figure interpretation LLM.")
        return "Error: GEMINI_API_KEY not set."

    try:
        # Initialize the multi-modal LLM
        llm = GoogleGenAI(api_key=gemini_api_key, model=llm_model_name)
        logger.info(f"Using figure interpretation LLM: {llm_model_name}")

        # Prompt guiding the LLM to act as a figure interpreter for the specific request
        prompt = (
            "You are an expert figure interpreter. Analyze the provided image containing a chart, "
            "graph, diagram, or table. Focus *only* on the visual information present in the image. "
            "Fulfill the following request accurately and concisely:\n\n"
            f"REQUEST: {request}\n\n"
            "Based *only* on the image, provide the answer:"
        )
        # Load the image data. LlamaIndex integrations may handle this differently by version;
        # newer releases expose ImageBlock/TextBlock structures for multi-modal input. Here we
        # assume the LLM call accepts loaded ImageDocument objects.
        # Note: the exact multi-modal completion API may vary; this is a conceptual example.
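        # Hedged alternative sketch (newer LlamaIndex API, not verified here): send the image
        # as an ImageBlock inside a ChatMessage and use llm.chat() instead of complete():
        #   from llama_index.core.llms import ChatMessage, ImageBlock, TextBlock
        #   msg = ChatMessage(role="user", blocks=[ImageBlock(path=image_path), TextBlock(text=prompt)])
        #   interpretation = llm.chat([msg]).message.content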
        from llama_index.core import SimpleDirectoryReader

        # Load the image as an ImageDocument
        reader = SimpleDirectoryReader(input_files=[image_path])
        image_documents = reader.load_data()
        if not image_documents or not isinstance(image_documents[0], ImageDocument):
            logger.error(f"Failed to load image as ImageDocument: {image_path}")
            return f"Error: Could not load image file {image_path} for analysis."

        # Make the multi-modal completion call (keyword names may differ across versions)
        response = llm.complete(
            prompt=prompt,
            image_documents=image_documents,  # Pass the loaded image document(s)
        )
        interpretation = response.text.strip()
        logger.info("Figure interpretation successful.")
        return interpretation
    except FileNotFoundError:
        # Redundant given the initial existence check, but defensive
        logger.error(f"Image file not found during LLM call: {image_path}")
        return f"Error: Image file not found at {image_path}"
    except ImportError as ie:
        logger.error(f"Missing library for multi-modal processing: {ie}")
        return f"Error: Missing required library for image processing ({ie})."
    except Exception as e:
        # Catch potential API errors or other issues
        logger.error(f"LLM call failed during figure interpretation: {e}", exc_info=True)
        # Detect errors suggesting the model lacks image support
        if "does not support image input" in str(e).lower():
            logger.error(f"The configured model {llm_model_name} does not support image input.")
            return f"Error: The configured LLM ({llm_model_name}) does not support image input. Please configure a multi-modal model."
        return f"Error during figure interpretation: {e}"

# --- Tool Functions (wrapping the core logic) ---
# These functions simply pass a tailored request to the core LLM function.

def describe_figure_tool_fn(image_path: str) -> str:
    """Provides a general description of the figure in the image (type, elements, topic)."""
    return interpret_figure_with_llm(image_path, "Describe this figure, including its type, main elements (axes, labels, legend), and overall topic.")

def extract_data_points_tool_fn(image_path: str, data_request: str) -> str:
    """Extracts specific data points or values from the figure in the image."""
    return interpret_figure_with_llm(image_path, f"Extract the following data points/values from the figure: {data_request}. If exact values are not clear, provide the closest estimate based on the visual.")

def identify_trends_tool_fn(image_path: str) -> str:
    """Identifies and describes trends or patterns shown in the figure in the image."""
    return interpret_figure_with_llm(image_path, "Analyze and describe the main trends or patterns shown in this figure.")

def compare_elements_tool_fn(image_path: str, comparison_request: str) -> str:
    """Compares different elements within the figure in the image."""
    return interpret_figure_with_llm(image_path, f"Compare the following elements within the figure: {comparison_request}. Be specific about the comparison based on the visual data.")

def summarize_figure_insights_tool_fn(image_path: str) -> str:
    """Summarizes the key insights or main message conveyed by the figure in the image."""
    return interpret_figure_with_llm(image_path, "Summarize the key insights or the main message conveyed by this figure.")

# --- Tool Definitions for Agent ---
describe_figure_tool = FunctionTool.from_defaults(
    fn=describe_figure_tool_fn,
    name="describe_figure",
    description="Provides a general description of the figure in the image (type, elements, topic). Input: image_path (str).",
)
extract_data_points_tool = FunctionTool.from_defaults(
    fn=extract_data_points_tool_fn,
    name="extract_data_points",
    description="Extracts specific data points/values from the figure. Input: image_path (str), data_request (str).",
)
identify_trends_tool = FunctionTool.from_defaults(
    fn=identify_trends_tool_fn,
    name="identify_trends",
    description="Identifies and describes trends/patterns in the figure. Input: image_path (str).",
)
compare_elements_tool = FunctionTool.from_defaults(
    fn=compare_elements_tool_fn,
    name="compare_elements",
    description="Compares different elements within the figure. Input: image_path (str), comparison_request (str).",
)
summarize_figure_insights_tool = FunctionTool.from_defaults(
    fn=summarize_figure_insights_tool_fn,
    name="summarize_figure_insights",
    description="Summarizes the key insights/main message of the figure. Input: image_path (str).",
)
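
# Hedged usage sketch: FunctionTool instances can also be invoked directly for testing;
# .call() returns a ToolOutput whose string form is the tool result ("figure.png" is illustrative):
#   output = describe_figure_tool.call(image_path="figure.png")
#   print(str(output))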

# --- Agent Initialization ---
def initialize_figure_interpretation_agent() -> ReActAgent:
    """Initializes the Figure Interpretation Agent."""
    logger.info("Initializing FigureInterpretationAgent...")

    # Configuration for the agent's main LLM (can be the same multi-modal one)
    agent_llm_model = os.getenv("FIGURE_INTERPRETATION_AGENT_LLM_MODEL", "models/gemini-1.5-pro")
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    if not gemini_api_key:
        logger.error("GEMINI_API_KEY not found for FigureInterpretationAgent.")
        raise ValueError("GEMINI_API_KEY must be set for FigureInterpretationAgent")

    try:
        # The agent's LLM does not itself need to be multi-modal, since the tools make the
        # multi-modal calls; using a multi-modal model may allow more direct interaction patterns later.
        llm = GoogleGenAI(api_key=gemini_api_key, model=agent_llm_model)
        logger.info(f"Using agent LLM: {agent_llm_model}")

        # Load system prompt
        default_system_prompt = (
            "You are FigureInterpretationAgent... [Default prompt content - replace with actual]"  # Placeholder
        )
        system_prompt = load_prompt_from_file("../prompts/figure_interpretation_agent_prompt.txt", default_system_prompt)
        if system_prompt == default_system_prompt:
            logger.warning("Using default/fallback system prompt for FigureInterpretationAgent.")

        # Define available tools
        tools = [
            describe_figure_tool,
            extract_data_points_tool,
            identify_trends_tool,
            compare_elements_tool,
            summarize_figure_insights_tool,
        ]

        # Define valid handoff targets
        valid_handoffs = [
            "planner_agent",    # To return results
            "research_agent",   # If context from the figure needs further research
            "reasoning_agent",  # If the interpretation needs logical analysis
        ]

        # Note: this agent inherently requires multi-modal input capabilities,
        # which are handled within its tools via a multi-modal LLM.
        agent = ReActAgent(
            name="figure_interpretation_agent",
            description=(
                "Analyzes and interprets visual data representations (charts, graphs, tables) from image files. "
                "Can describe figures, extract data, identify trends, compare elements, and summarize insights."
            ),
            tools=tools,
            llm=llm,
            system_prompt=system_prompt,
            can_handoff_to=valid_handoffs,
        )
        logger.info("FigureInterpretationAgent initialized successfully.")
        return agent
    except Exception as e:
        logger.error(f"Error during FigureInterpretationAgent initialization: {e}", exc_info=True)
        raise
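
# Hedged usage sketch (workflow agents run asynchronously; the message text is illustrative):
#   import asyncio
#   async def demo():
#       agent = initialize_figure_interpretation_agent()
#       result = await agent.run(user_msg="Describe the figure in figure.png")
#       print(str(result))
#   asyncio.run(demo())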

# Example usage (for testing if run directly)
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    logger.info("Running figure_interpretation_agent.py directly for testing...")

    # Check required keys
    required_keys = ["GEMINI_API_KEY"]
    missing_keys = [key for key in required_keys if not os.getenv(key)]
    if missing_keys:
        print(f"Error: Required environment variable(s) not set: {', '.join(missing_keys)}. Cannot run test.")
    else:
        # Heuristic check that a multi-modal model is likely configured
        model_name = os.getenv("FIGURE_INTERPRETATION_LLM_MODEL", "models/gemini-1.5-pro")
        if "pro" not in model_name.lower() and "vision" not in model_name.lower():
            print(f"Warning: Configured LLM {model_name} might not support image input. Tests may fail.")

        # Create a dummy image file for testing (requires Pillow)
        dummy_image_path = "dummy_figure.png"
        try:
            from PIL import Image, ImageDraw, ImageFont

            img = Image.new("RGB", (400, 200), color=(255, 255, 255))
            d = ImageDraw.Draw(img)
            # Try to load a common TrueType font; fall back to PIL's default
            try:
                font = ImageFont.truetype("arial.ttf", 15)  # Might not exist on all systems
            except IOError:
                font = ImageFont.load_default()
                print("Arial font not found, using default PIL font.")
            d.text((10, 10), "Simple Bar Chart", fill=(0, 0, 0), font=font)
            d.rectangle([50, 50, 100, 150], fill=(255, 0, 0))   # Bar 1
            d.text((60, 160), "A", fill=(0, 0, 0), font=font)
            d.rectangle([150, 80, 200, 150], fill=(0, 0, 255))  # Bar 2
            d.text((160, 160), "B", fill=(0, 0, 0), font=font)
            img.save(dummy_image_path)
            print(f"Created dummy image file: {dummy_image_path}")
            # Test the tools directly
            print("\nTesting describe_figure...")
            desc = describe_figure_tool_fn(dummy_image_path)
            print(f"Description: {desc}")

            print("\nTesting extract_data_points (qualitative)...")
            extract_req = "Height of bar A vs bar B"  # Qualitative request
            extract_res = extract_data_points_tool_fn(dummy_image_path, extract_req)
            print(f"Extraction Result: {extract_res}")

            print("\nTesting compare_elements...")
            compare_req = "Compare bar A and bar B"
            compare_res = compare_elements_tool_fn(dummy_image_path, compare_req)
            print(f"Comparison Result: {compare_res}")

            # Clean up dummy image
            os.remove(dummy_image_path)
        except ImportError:
            print("Pillow library not installed. Skipping direct tool tests that require image creation.")
            # Optionally, still try initializing the agent
            try:
                test_agent = initialize_figure_interpretation_agent()
                print("\nFigure Interpretation Agent initialized successfully (tool tests skipped).")
            except Exception as e:
                print(f"Error initializing agent: {e}")
        except Exception as e:
            print(f"Error during testing: {e}")
            if os.path.exists(dummy_image_path):
                os.remove(dummy_image_path)  # Ensure cleanup on error