import os
import logging
from llama_index.core.agent.workflow import ReActAgent
from llama_index.core.schema import ImageDocument
from llama_index.core.tools import FunctionTool
from llama_index.llms.google_genai import GoogleGenAI

# Setup logging
logger = logging.getLogger(__name__)


# Helper function to load prompt from file
def load_prompt_from_file(filename: str, default_prompt: str) -> str:
    """Loads a prompt from a text file."""
    try:
        script_dir = os.path.dirname(__file__)
        prompt_path = os.path.join(script_dir, filename)
        with open(prompt_path, "r") as f:
            prompt = f.read()
        logger.info(f"Successfully loaded prompt from {prompt_path}")
        return prompt
    except FileNotFoundError:
        logger.warning(f"Prompt file {filename} not found at {prompt_path}. Using default.")
        return default_prompt
    except Exception as e:
        logger.error(f"Error loading prompt file {filename}: {e}", exc_info=True)
        return default_prompt
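
# Usage sketch ("../prompts/my_agent_prompt.txt" is a hypothetical path): the
# fallback means callers always get a usable prompt string, even if the file
# is missing, e.g.
#   system_prompt = load_prompt_from_file("../prompts/my_agent_prompt.txt", "You are an agent...")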


# --- Core Figure Interpretation Logic (using Multi-Modal LLM) ---
def interpret_figure_with_llm(image_path: str, request: str) -> str:
    """Interprets a figure in an image based on a specific request using a multi-modal LLM.

    Args:
        image_path (str): Path to the image file containing the figure.
        request (str): The specific question or interpretation task (e.g., "Describe this chart",
            "Extract sales for Q3", "Identify the main trend").

    Returns:
        str: The interpretation result or an error message.
    """
    logger.info(f"Interpreting figure in image: {image_path} with request: {request}")

    # Fail fast if the image does not exist
    if not os.path.exists(image_path):
        logger.error(f"Image file not found: {image_path}")
        return f"Error: Image file not found at {image_path}"

    # LLM configuration: the selected model must support image input
    # (e.g., gemini-1.5-pro)
    llm_model_name = os.getenv("FIGURE_INTERPRETATION_LLM_MODEL", "models/gemini-1.5-pro")
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    if not gemini_api_key:
        logger.error("GEMINI_API_KEY not found for figure interpretation LLM.")
        return "Error: GEMINI_API_KEY not set."

    try:
        # Initialize the multi-modal LLM
        llm = GoogleGenAI(api_key=gemini_api_key, model=llm_model_name)
        logger.info(f"Using figure interpretation LLM: {llm_model_name}")

        # Prompt that guides the LLM to act as the figure interpreter
        # for the specific request
        prompt = (
            f"You are an expert figure interpreter. Analyze the provided image containing a chart, graph, diagram, or table. "
            f"Focus *only* on the visual information present in the image. "
            f"Fulfill the following request accurately and concisely:\n\n"
            f"REQUEST: {request}\n\n"
            f"Based *only* on the image, provide the answer:"
        )

        # Load the image data. The exact multi-modal API varies across
        # LlamaIndex versions (newer releases use ImageBlock/content-block
        # structures); this conceptual example assumes the completion call
        # accepts loaded ImageDocument objects alongside the prompt.
        from llama_index.core import SimpleDirectoryReader

        reader = SimpleDirectoryReader(input_files=[image_path])
        image_documents = reader.load_data()
        if not image_documents or not isinstance(image_documents[0], ImageDocument):
            logger.error(f"Failed to load image as ImageDocument: {image_path}")
            return f"Error: Could not load image file {image_path} for analysis."

        # Make the multi-modal completion call
        response = llm.complete(
            prompt=prompt,
            image_documents=image_documents  # Pass the loaded image document(s)
        )
        interpretation = response.text.strip()
        logger.info("Figure interpretation successful.")
        return interpretation
    except FileNotFoundError:
        # Redundant given the existence check above, but kept as a safety net
        logger.error(f"Image file not found during LLM call: {image_path}")
        return f"Error: Image file not found at {image_path}"
    except ImportError as ie:
        logger.error(f"Missing library for multi-modal processing: {ie}")
        return f"Error: Missing required library for image processing ({ie})."
    except Exception as e:
        # Catch potential API errors or other issues
        logger.error(f"LLM call failed during figure interpretation: {e}", exc_info=True)
        # Surface a clearer message if the model does not support images
        if "does not support image input" in str(e).lower():
            logger.error(f"The configured model {llm_model_name} does not support image input.")
            return f"Error: The configured LLM ({llm_model_name}) does not support image input. Please configure a multi-modal model."
        return f"Error during figure interpretation: {e}"


# --- Tool Definitions (Wrapping the core logic) ---
# These tools essentially pass the request to the core LLM function.
def describe_figure_tool_fn(image_path: str) -> str:
    """Provides a general description of the figure in the image (type, elements, topic)."""
    return interpret_figure_with_llm(image_path, "Describe this figure, including its type, main elements (axes, labels, legend), and overall topic.")


def extract_data_points_tool_fn(image_path: str, data_request: str) -> str:
    """Extracts specific data points or values from the figure in the image."""
    return interpret_figure_with_llm(image_path, f"Extract the following data points/values from the figure: {data_request}. If exact values are not clear, provide the closest estimate based on the visual.")


def identify_trends_tool_fn(image_path: str) -> str:
    """Identifies and describes trends or patterns shown in the figure in the image."""
    return interpret_figure_with_llm(image_path, "Analyze and describe the main trends or patterns shown in this figure.")


def compare_elements_tool_fn(image_path: str, comparison_request: str) -> str:
    """Compares different elements within the figure in the image."""
    return interpret_figure_with_llm(image_path, f"Compare the following elements within the figure: {comparison_request}. Be specific about the comparison based on the visual data.")


def summarize_figure_insights_tool_fn(image_path: str) -> str:
    """Summarizes the key insights or main message conveyed by the figure in the image."""
    return interpret_figure_with_llm(image_path, "Summarize the key insights or the main message conveyed by this figure.")


# --- Tool Definitions for Agent ---
describe_figure_tool = FunctionTool.from_defaults(
    fn=describe_figure_tool_fn,
    name="describe_figure",
    description="Provides a general description of the figure in the image (type, elements, topic). Input: image_path (str)."
)

extract_data_points_tool = FunctionTool.from_defaults(
    fn=extract_data_points_tool_fn,
    name="extract_data_points",
    description="Extracts specific data points/values from the figure. Input: image_path (str), data_request (str)."
)

identify_trends_tool = FunctionTool.from_defaults(
    fn=identify_trends_tool_fn,
    name="identify_trends",
    description="Identifies and describes trends/patterns in the figure. Input: image_path (str)."
)

compare_elements_tool = FunctionTool.from_defaults(
    fn=compare_elements_tool_fn,
    name="compare_elements",
    description="Compares different elements within the figure. Input: image_path (str), comparison_request (str)."
)

summarize_figure_insights_tool = FunctionTool.from_defaults(
    fn=summarize_figure_insights_tool_fn,
    name="summarize_figure_insights",
    description="Summarizes the key insights/main message of the figure. Input: image_path (str)."
)
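
# Smoke-test sketch: FunctionTool instances can be invoked directly, which is
# handy outside the agent loop. "figure.png" is a hypothetical path; `call`
# returns a ToolOutput whose `content` holds the tool's string result.
#   output = describe_figure_tool.call(image_path="figure.png")
#   print(output.content)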


# --- Agent Initialization ---
def initialize_figure_interpretation_agent() -> ReActAgent:
    """Initializes the Figure Interpretation Agent."""
    logger.info("Initializing FigureInterpretationAgent...")

    # Configuration for the agent's main LLM (can be the same multi-modal one)
    agent_llm_model = os.getenv("FIGURE_INTERPRETATION_AGENT_LLM_MODEL", "models/gemini-1.5-pro")
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    if not gemini_api_key:
        logger.error("GEMINI_API_KEY not found for FigureInterpretationAgent.")
        raise ValueError("GEMINI_API_KEY must be set for FigureInterpretationAgent")

    try:
        # The agent's own LLM does not need to be multi-modal, since the tools
        # make the multi-modal calls; using one anyway keeps the option of
        # more direct image-handling patterns later.
        llm = GoogleGenAI(api_key=gemini_api_key, model=agent_llm_model)
        logger.info(f"Using agent LLM: {agent_llm_model}")

        # Load the system prompt, falling back to a placeholder default
        default_system_prompt = (
            "You are FigureInterpretationAgent... [Default prompt content - replace with actual]"  # Placeholder
        )
        system_prompt = load_prompt_from_file("../prompts/figure_interpretation_agent_prompt.txt", default_system_prompt)
        if system_prompt == default_system_prompt:
            logger.warning("Using default/fallback system prompt for FigureInterpretationAgent.")

        # Define available tools
        tools = [
            describe_figure_tool,
            extract_data_points_tool,
            identify_trends_tool,
            compare_elements_tool,
            summarize_figure_insights_tool,
        ]

        # Define valid handoff targets
        valid_handoffs = [
            "planner_agent",    # To return results
            "research_agent",   # If context from the figure needs further research
            "reasoning_agent",  # If the interpretation needs logical analysis
        ]

        # Note: this agent inherently requires multi-modal input capabilities,
        # which its tools provide via a multi-modal LLM.
        agent = ReActAgent(
            name="figure_interpretation_agent",
            description=(
                "Analyzes and interprets visual data representations (charts, graphs, tables) from image files. "
                "Can describe figures, extract data, identify trends, compare elements, and summarize insights."
            ),
            tools=tools,
            llm=llm,
            system_prompt=system_prompt,
            can_handoff_to=valid_handoffs,
        )
        logger.info("FigureInterpretationAgent initialized successfully.")
        return agent
    except Exception as e:
        logger.error(f"Error during FigureInterpretationAgent initialization: {e}", exc_info=True)
        raise
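
# Hedged usage sketch: the exact entry point varies across llama-index
# versions, but recent workflow agents can typically be run directly, e.g.:
#   import asyncio
#   agent = initialize_figure_interpretation_agent()
#   response = asyncio.run(agent.run(user_msg="Describe the figure in chart.png"))
#   print(str(response))
# ("chart.png" is a hypothetical path; verify `run` against your installed version.)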


# Example usage (for testing if run directly)
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logger.info("Running figure_interpretation_agent.py directly for testing...")

    # Check required keys
    required_keys = ["GEMINI_API_KEY"]
    missing_keys = [key for key in required_keys if not os.getenv(key)]
    if missing_keys:
        print(f"Error: Required environment variable(s) not set: {', '.join(missing_keys)}. Cannot run test.")
    else:
        # Check if a multi-modal model is likely configured (heuristic)
        model_name = os.getenv("FIGURE_INTERPRETATION_LLM_MODEL", "models/gemini-1.5-pro")
        if "pro" not in model_name.lower() and "vision" not in model_name.lower():
            print(f"Warning: Configured LLM {model_name} might not support image input. Tests may fail.")

        # Create a dummy image file for testing (requires Pillow)
        dummy_image_path = "dummy_figure.png"
        try:
            from PIL import Image, ImageDraw, ImageFont

            img = Image.new('RGB', (400, 200), color=(255, 255, 255))
            d = ImageDraw.Draw(img)
            # Try to load a common font; fall back to PIL's default if missing
            try:
                font = ImageFont.truetype("arial.ttf", 15)
            except IOError:
                font = ImageFont.load_default()
                print("Arial font not found, using default PIL font.")
            d.text((10, 10), "Simple Bar Chart", fill=(0, 0, 0), font=font)
            d.rectangle([50, 50, 100, 150], fill=(255, 0, 0))  # Bar 1
            d.text((60, 160), "A", fill=(0, 0, 0), font=font)
            d.rectangle([150, 80, 200, 150], fill=(0, 0, 255))  # Bar 2
            d.text((160, 160), "B", fill=(0, 0, 0), font=font)
            img.save(dummy_image_path)
            print(f"Created dummy image file: {dummy_image_path}")

            # Test the tools directly
            print("\nTesting describe_figure...")
            desc = describe_figure_tool_fn(dummy_image_path)
            print(f"Description: {desc}")

            print("\nTesting extract_data_points (qualitative)...")
            extract_req = "Height of bar A vs bar B"  # Qualitative request
            extract_res = extract_data_points_tool_fn(dummy_image_path, extract_req)
            print(f"Extraction Result: {extract_res}")

            print("\nTesting compare_elements...")
            compare_req = "Compare bar A and bar B"
            compare_res = compare_elements_tool_fn(dummy_image_path, compare_req)
            print(f"Comparison Result: {compare_res}")

            # Clean up dummy image
            os.remove(dummy_image_path)
        except ImportError:
            print("Pillow library not installed. Skipping direct tool tests that require image creation.")
            # Optionally, still try initializing the agent
            try:
                test_agent = initialize_figure_interpretation_agent()
                print("\nFigure Interpretation Agent initialized successfully (tool tests skipped).")
            except Exception as e:
                print(f"Error initializing agent: {e}")
        except Exception as e:
            print(f"Error during testing: {e}")
            if os.path.exists(dummy_image_path):
                os.remove(dummy_image_path)  # Ensure cleanup on error