File size: 14,689 Bytes
a23082c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
import os
import logging

from llama_index.core.agent.workflow import ReActAgent
from llama_index.core.schema import ImageDocument
from llama_index.core.tools import FunctionTool
from llama_index.llms.google_genai import GoogleGenAI

# Setup logging
logger = logging.getLogger(__name__)

# Helper function to load prompt from file
def load_prompt_from_file(filename: str, default_prompt: str) -> str:
    """Load a prompt from a text file located relative to this script.

    Args:
        filename: Path to the prompt file, relative to this script's directory
            (e.g. "../prompts/figure_interpretation_agent_prompt.txt").
        default_prompt: Fallback prompt returned when the file is missing or
            unreadable.

    Returns:
        str: The file contents, or ``default_prompt`` on any failure.
    """
    # Resolve the path before the try-block so it is always bound for logging.
    script_dir = os.path.dirname(__file__)
    prompt_path = os.path.join(script_dir, filename)
    try:
        with open(prompt_path, "r", encoding="utf-8") as f:
            prompt = f.read()
        logger.info(f"Successfully loaded prompt from {prompt_path}")
        return prompt
    except FileNotFoundError:
        # Missing prompt file is an expected, recoverable condition.
        logger.warning(f"Prompt file {filename} not found at {prompt_path}. Using default.")
        return default_prompt
    except Exception as e:
        # Any other I/O or decoding error: log with traceback, fall back.
        logger.error(f"Error loading prompt file {filename}: {e}", exc_info=True)
        return default_prompt

# --- Core Figure Interpretation Logic (using Multi-Modal LLM) ---

def interpret_figure_with_llm(image_path: str, request: str) -> str:
    """Interprets a figure in an image based on a specific request using a multi-modal LLM.
       Args:
           image_path (str): Path to the image file containing the figure.
           request (str): The specific question or interpretation task (e.g., "Describe this chart", 
                          "Extract sales for Q3", "Identify the main trend").
       Returns:
           str: The interpretation result or an error message. Errors are returned
                as strings (prefixed "Error:") rather than raised, so agent tools
                can surface them directly.
    """
    logger.info(f"Interpreting figure in image: {image_path} with request: {request}")

    # Check if image exists
    # Fail fast with a readable error string before spending an API call.
    if not os.path.exists(image_path):
        logger.error(f"Image file not found: {image_path}")
        return f"Error: Image file not found at {image_path}"

    # LLM configuration (Must be a multi-modal model)
    # Ensure the selected model supports image input (e.g., gemini-1.5-pro)
    llm_model_name = os.getenv("FIGURE_INTERPRETATION_LLM_MODEL", "models/gemini-1.5-pro") 
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    if not gemini_api_key:
        logger.error("GEMINI_API_KEY not found for figure interpretation LLM.")
        return "Error: GEMINI_API_KEY not set."

    try:
        # Initialize the multi-modal LLM
        llm = GoogleGenAI(api_key=gemini_api_key, model=llm_model_name)
        logger.info(f"Using figure interpretation LLM: {llm_model_name}")

        # Prepare the prompt for the multi-modal LLM
        # The prompt needs to guide the LLM to act as the figure interpreter
        # based on the specific request.
        prompt = (
            f"You are an expert figure interpreter. Analyze the provided image containing a chart, graph, diagram, or table. "
            f"Focus *only* on the visual information present in the image. "
            f"Fulfill the following request accurately and concisely:\n\n"
            f"REQUEST: {request}\n\n"
            f"Based *only* on the image, provide the answer:"
        )

        # Load the image data (LlamaIndex integration might handle this differently depending on version)
        # Assuming a method to load image data compatible with the LLM call
        # This might involve using ImageBlock or similar structures in newer LlamaIndex versions.
        # For simplicity here, we assume the LLM call can handle a path or loaded image object.
        
        # Example using complete (adjust based on actual LlamaIndex multi-modal API)
        # Note: The exact API for multi-modal completion might vary. 
        # This is a conceptual example.
        # Local import so a missing reader dependency is caught by the ImportError
        # handler below instead of breaking module import.
        from llama_index.core import SimpleDirectoryReader # Example import
        
        # Load the image document
        reader = SimpleDirectoryReader(input_files=[image_path])
        image_documents = reader.load_data()
        
        # NOTE(review): SimpleDirectoryReader is expected to yield an ImageDocument
        # for image files — verify against the installed LlamaIndex version.
        if not image_documents or not isinstance(image_documents[0], ImageDocument):
             logger.error(f"Failed to load image as ImageDocument: {image_path}")
             return f"Error: Could not load image file {image_path} for analysis."

        # Make the multi-modal completion call
        # NOTE(review): `complete(prompt=..., image_documents=...)` matches older
        # multi-modal LLM interfaces — confirm this signature for GoogleGenAI.
        response = llm.complete(
            prompt=prompt,
            image_documents=image_documents # Pass the loaded image document(s)
        )
        
        interpretation = response.text.strip()
        logger.info("Figure interpretation successful.")
        return interpretation

    except FileNotFoundError:
         # This might be redundant due to the initial check, but good practice
         # (the file could be removed between the check and the read).
         logger.error(f"Image file not found during LLM call: {image_path}")
         return f"Error: Image file not found at {image_path}"
    except ImportError as ie:
         logger.error(f"Missing library for multi-modal processing: {ie}")
         return f"Error: Missing required library for image processing ({ie})."
    except Exception as e:
        # Catch potential API errors or other issues
        logger.error(f"LLM call failed during figure interpretation: {e}", exc_info=True)
        # Check if the error suggests the model doesn't support images
        # (string match is a heuristic; the provider does not raise a typed error here).
        if "does not support image input" in str(e).lower():
             logger.error(f"The configured model {llm_model_name} does not support image input.")
             return f"Error: The configured LLM ({llm_model_name}) does not support image input. Please configure a multi-modal model."
        return f"Error during figure interpretation: {e}"

# --- Tool Definitions (Wrapping the core logic) ---
# These tools essentially pass the request to the core LLM function.

def describe_figure_tool_fn(image_path: str) -> str:
    """Return a general description of the figure in the image (type, elements, topic)."""
    request = (
        "Describe this figure, including its type, main elements "
        "(axes, labels, legend), and overall topic."
    )
    return interpret_figure_with_llm(image_path, request)

def extract_data_points_tool_fn(image_path: str, data_request: str) -> str:
    """Extract specific data points or values from the figure in the image."""
    request = (
        "Extract the following data points/values from the figure: "
        f"{data_request}. If exact values are not clear, provide the "
        "closest estimate based on the visual."
    )
    return interpret_figure_with_llm(image_path, request)

def identify_trends_tool_fn(image_path: str) -> str:
    """Identify and describe trends or patterns shown in the figure in the image."""
    request = "Analyze and describe the main trends or patterns shown in this figure."
    return interpret_figure_with_llm(image_path, request)

def compare_elements_tool_fn(image_path: str, comparison_request: str) -> str:
    """Compare different elements within the figure in the image."""
    request = (
        f"Compare the following elements within the figure: {comparison_request}. "
        "Be specific about the comparison based on the visual data."
    )
    return interpret_figure_with_llm(image_path, request)

def summarize_figure_insights_tool_fn(image_path: str) -> str:
    """Summarize the key insights or main message conveyed by the figure in the image."""
    request = "Summarize the key insights or the main message conveyed by this figure."
    return interpret_figure_with_llm(image_path, request)

# --- Tool Definitions for Agent ---
# Each FunctionTool wraps one of the plain functions above so the ReAct agent
# can invoke it by name. The `description` strings are what the agent's LLM
# reads when deciding which tool to call, so they spell out expected inputs.
describe_figure_tool = FunctionTool.from_defaults(
    fn=describe_figure_tool_fn,
    name="describe_figure",
    description="Provides a general description of the figure in the image (type, elements, topic). Input: image_path (str)."
)

extract_data_points_tool = FunctionTool.from_defaults(
    fn=extract_data_points_tool_fn,
    name="extract_data_points",
    description="Extracts specific data points/values from the figure. Input: image_path (str), data_request (str)."
)

identify_trends_tool = FunctionTool.from_defaults(
    fn=identify_trends_tool_fn,
    name="identify_trends",
    description="Identifies and describes trends/patterns in the figure. Input: image_path (str)."
)

compare_elements_tool = FunctionTool.from_defaults(
    fn=compare_elements_tool_fn,
    name="compare_elements",
    description="Compares different elements within the figure. Input: image_path (str), comparison_request (str)."
)

summarize_figure_insights_tool = FunctionTool.from_defaults(
    fn=summarize_figure_insights_tool_fn,
    name="summarize_figure_insights",
    description="Summarizes the key insights/main message of the figure. Input: image_path (str)."
)

# --- Agent Initialization ---
def initialize_figure_interpretation_agent() -> ReActAgent:
    """Initializes the Figure Interpretation Agent.

    Returns:
        ReActAgent: A configured agent exposing the figure-interpretation tools.

    Raises:
        ValueError: If GEMINI_API_KEY is not set.
        Exception: Propagates any failure from LLM or agent construction.
    """
    logger.info("Initializing FigureInterpretationAgent...")

    # Configuration for the agent's main LLM (can be the same multi-modal one)
    agent_llm_model = os.getenv("FIGURE_INTERPRETATION_AGENT_LLM_MODEL", "models/gemini-1.5-pro")
    gemini_api_key = os.getenv("GEMINI_API_KEY")

    if not gemini_api_key:
        logger.error("GEMINI_API_KEY not found for FigureInterpretationAgent.")
        raise ValueError("GEMINI_API_KEY must be set for FigureInterpretationAgent")

    try:
        # Agent's LLM doesn't necessarily need to be multi-modal itself,
        # if the tools handle the multi-modal calls.
        # However, using a multi-modal one might allow more direct interaction patterns later.
        llm = GoogleGenAI(api_key=gemini_api_key, model=agent_llm_model)
        logger.info(f"Using agent LLM: {agent_llm_model}")

        # Load system prompt
        # NOTE(review): the default below is a placeholder; the real prompt is
        # expected to live in the prompts/ directory one level up.
        default_system_prompt = ("You are FigureInterpretationAgent... [Default prompt content - replace with actual]" # Placeholder
                              )
        system_prompt = load_prompt_from_file("../prompts/figure_interpretation_agent_prompt.txt", default_system_prompt)
        if system_prompt == default_system_prompt:
             logger.warning("Using default/fallback system prompt for FigureInterpretationAgent.")

        # Define available tools
        tools = [
            describe_figure_tool,
            extract_data_points_tool,
            identify_trends_tool,
            compare_elements_tool,
            summarize_figure_insights_tool
        ]

        # Define valid handoff targets
        # (agent names must match those registered in the wider workflow).
        valid_handoffs = [
            "planner_agent", # To return results
            "research_agent", # If context from figure needs further research
            "reasoning_agent" # If interpretation needs logical analysis
        ]

        agent = ReActAgent(
            name="figure_interpretation_agent",
            description=(
                "Analyzes and interprets visual data representations (charts, graphs, tables) from image files. "
                "Can describe figures, extract data, identify trends, compare elements, and summarize insights."
            ),
            tools=tools,
            llm=llm,
            system_prompt=system_prompt,
            can_handoff_to=valid_handoffs,
            # Note: This agent inherently requires multi-modal input capabilities, 
            # which are handled within its tools via a multi-modal LLM.
        )
        logger.info("FigureInterpretationAgent initialized successfully.")
        return agent

    except Exception as e:
        # Log with traceback, then re-raise so callers can decide how to fail.
        logger.error(f"Error during FigureInterpretationAgent initialization: {e}", exc_info=True)
        raise

# Example usage (for testing if run directly)
# Smoke test: builds a tiny bar-chart image with Pillow and runs the tools
# against it. Requires GEMINI_API_KEY and performs live LLM calls.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logger.info("Running figure_interpretation_agent.py directly for testing...")

    # Check required keys
    required_keys = ["GEMINI_API_KEY"]
    missing_keys = [key for key in required_keys if not os.getenv(key)]
    if missing_keys:
        print(f"Error: Required environment variable(s) not set: {', '.join(missing_keys)}. Cannot run test.")
    else:
        # Check if a multi-modal model is likely configured (heuristic)
        # — name sniffing only; does not guarantee image support.
        model_name = os.getenv("FIGURE_INTERPRETATION_LLM_MODEL", "models/gemini-1.5-pro")
        if "pro" not in model_name.lower() and "vision" not in model_name.lower():
             print(f"Warning: Configured LLM {model_name} might not support image input. Tests may fail.")
             
        # Create a dummy image file for testing (requires Pillow)
        dummy_image_path = "dummy_figure.png"
        try:
            # Imported lazily so the script degrades gracefully without Pillow
            # (handled by the ImportError branch below).
            from PIL import Image, ImageDraw, ImageFont
            img = Image.new('RGB', (400, 200), color = (255, 255, 255))
            d = ImageDraw.Draw(img)
            # Try to load a default font, handle if not found
            try:
                 font = ImageFont.truetype("arial.ttf", 15) # Common font, might not exist
            except IOError:
                 font = ImageFont.load_default()
                 print("Arial font not found, using default PIL font.")
            # Draw a two-bar chart: red bar "A" (taller) and blue bar "B".
            d.text((10,10), "Simple Bar Chart", fill=(0,0,0), font=font)
            d.rectangle([50, 50, 100, 150], fill=(255,0,0)) # Bar 1
            d.text((60, 160), "A", fill=(0,0,0), font=font)
            d.rectangle([150, 80, 200, 150], fill=(0,0,255)) # Bar 2
            d.text((160, 160), "B", fill=(0,0,0), font=font)
            img.save(dummy_image_path)
            print(f"Created dummy image file: {dummy_image_path}")

            # Test the tools directly
            print("\nTesting describe_figure...")
            desc = describe_figure_tool_fn(dummy_image_path)
            print(f"Description: {desc}")
            
            print("\nTesting extract_data_points (qualitative)...")
            extract_req = "Height of bar A vs Bar B" # Qualitative request
            extract_res = extract_data_points_tool_fn(dummy_image_path, extract_req)
            print(f"Extraction Result: {extract_res}")
            
            print("\nTesting compare_elements...")
            compare_req = "Compare bar A and bar B" 
            compare_res = compare_elements_tool_fn(dummy_image_path, compare_req)
            print(f"Comparison Result: {compare_res}")

            # Clean up dummy image
            os.remove(dummy_image_path)

        except ImportError:
            print("Pillow library not installed. Skipping direct tool tests that require image creation.")
            # Optionally, still try initializing the agent
            try:
                 test_agent = initialize_figure_interpretation_agent()
                 print("\nFigure Interpretation Agent initialized successfully (tool tests skipped).")
            except Exception as e:
                 print(f"Error initializing agent: {e}")
        except Exception as e:
            print(f"Error during testing: {e}")
            if os.path.exists(dummy_image_path):
                 os.remove(dummy_image_path) # Ensure cleanup on error