File size: 21,669 Bytes
a23082c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
import asyncio
import logging
import mimetypes
import os
from typing import Any, List

import gradio as gr
import pandas as pd
import requests
from dotenv import load_dotenv

from llama_index.core.agent.workflow import AgentWorkflow, ToolCallResult, ToolCall, AgentOutput
from llama_index.core.base.llms.types import ChatMessage, TextBlock, ImageBlock, AudioBlock

# The agent initializers live either in a local `agents` package or, in some
# deployment layouts, in a `final_project` package; try both, in that order.
try:
    # Existing agents
    from agents.image_analyzer_agent import initialize_image_analyzer_agent
    from agents.reasoning_agent import initialize_reasoning_agent
    from agents.text_analyzer_agent import initialize_text_analyzer_agent
    from agents.code_agent import initialize_code_agent
    from agents.math_agent import initialize_math_agent
    from agents.planner_agent import initialize_planner_agent
    from agents.research_agent import initialize_research_agent
    from agents.role_agent import initialize_role_agent
    from agents.verifier_agent import initialize_verifier_agent
    # New agents
    from agents.advanced_validation_agent import initialize_advanced_validation_agent
    from agents.figure_interpretation_agent import initialize_figure_interpretation_agent
    from agents.long_context_management_agent import initialize_long_context_management_agent
    AGENT_IMPORT_PATH = "local"
except ImportError as local_import_error:
    # Fallback for a different layout (e.g. the modules live in a nested folder).
    try:
        from final_project.image_analyzer_agent import initialize_image_analyzer_agent
        from final_project.reasoning_agent import initialize_reasoning_agent
        from final_project.text_analyzer_agent import initialize_text_analyzer_agent
        from final_project.code_agent import initialize_code_agent
        from final_project.math_agent import initialize_math_agent
        from final_project.planner_agent import initialize_planner_agent
        from final_project.research_agent import initialize_research_agent
        from final_project.role_agent import initialize_role_agent
        from final_project.verifier_agent import initialize_verifier_agent
        from final_project.advanced_validation_agent import initialize_advanced_validation_agent
        from final_project.figure_interpretation_agent import initialize_figure_interpretation_agent
        from final_project.long_context_management_agent import initialize_long_context_management_agent
        AGENT_IMPORT_PATH = "final_project"
    except ImportError as fallback_import_error:
        # Nothing can run without the agents, so abort the module import. Report
        # BOTH errors: if the `agents` package exists but one of its modules has a
        # broken dependency, the first error is the one that explains the failure.
        print(
            "Import Error: Could not find agent modules. Tried local and final_project paths. "
            f"Local error: {local_import_error}; final_project error: {fallback_import_error}"
        )
        raise RuntimeError(
            f"Failed to import agent modules (local: {local_import_error}; "
            f"final_project: {fallback_import_error})"
        ) from fallback_import_error

# Turn off HuggingFace `tokenizers` parallelism (commonly done to avoid its
# fork-related warning when the process later spawns workers).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Populate os.environ from a .env file, if one is found.
# NOTE(review): this runs AFTER the agent modules were imported above, so any
# agent module that reads environment variables at import time will not see
# values from .env — confirm whether any of them do.
load_dotenv()

# --- Logging ---
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Constants ---
# Base URL of the GAIA scoring API; overridable via the GAIA_API_URL env var.
_FALLBACK_API_URL = "https://agents-course-unit4-scoring.hf.space"
DEFAULT_API_URL = os.getenv("GAIA_API_URL", default=_FALLBACK_API_URL)

# --- Agent Initialization (Singleton Pattern) ---
# Build the multi-agent workflow once, at import time. On any failure the error
# is logged and AGENT_WORKFLOW stays None; BasicAgent, run_and_submit_all and
# the __main__ guard check for None instead of the import crashing.
AGENT_WORKFLOW = None
try:
    logger.info(f"Initializing GAIA Multi-Agent Workflow (import path: {AGENT_IMPORT_PATH})...")
    # Existing agents
    role_agent = initialize_role_agent()
    code_agent = initialize_code_agent()
    math_agent = initialize_math_agent()
    planner_agent = initialize_planner_agent()
    research_agent = initialize_research_agent()
    text_analyzer_agent = initialize_text_analyzer_agent()
    verifier_agent = initialize_verifier_agent()
    image_analyzer_agent = initialize_image_analyzer_agent()
    reasoning_agent = initialize_reasoning_agent()
    # New agents
    advanced_validation_agent = initialize_advanced_validation_agent()
    figure_interpretation_agent = initialize_figure_interpretation_agent()
    long_context_management_agent = initialize_long_context_management_agent()

    # Keyed by variable name so a failed initializer can be reported by name.
    # Insertion order is the order in which the agents are handed to AgentWorkflow.
    named_agents = {
        "code_agent": code_agent,
        "role_agent": role_agent,
        "math_agent": math_agent,
        "planner_agent": planner_agent,
        "research_agent": research_agent,
        "text_analyzer_agent": text_analyzer_agent,
        "image_analyzer_agent": image_analyzer_agent,
        "verifier_agent": verifier_agent,
        "reasoning_agent": reasoning_agent,
        "advanced_validation_agent": advanced_validation_agent,
        "figure_interpretation_agent": figure_interpretation_agent,
        "long_context_management_agent": long_context_management_agent,
    }
    all_agents = list(named_agents.values())

    # An initializer signalling failure with a falsy value (e.g. None) aborts setup.
    failed_agents = [name for name, agent in named_agents.items() if not agent]
    if failed_agents:
        raise RuntimeError(f"One or more agents failed to initialize: {', '.join(failed_agents)}")

    AGENT_WORKFLOW = AgentWorkflow(
        agents=all_agents,
        # Keep planner as root as per plan. NOTE(review): must equal the `name`
        # that initialize_planner_agent() gives its agent — confirm.
        root_agent="planner_agent",
    )
    logger.info("GAIA Multi-Agent Workflow initialized successfully.")
except Exception as e:
    logger.error(f"FATAL: Error initializing agent workflow: {e}", exc_info=True)
    # AGENT_WORKFLOW remains None, BasicAgent init will fail

# --- Basic Agent Definition (Wrapper for Workflow) ---
class BasicAgent:
    """Async callable wrapper around the GAIA multi-agent ``AgentWorkflow``.

    Calling an instance runs the workflow on one question, logs the streamed
    workflow events (agent hand-offs, agent outputs, tool calls and tool
    results) and returns the final response.
    """

    def __init__(self, workflow: AgentWorkflow):
        """Wrap an already-initialized workflow.

        Args:
            workflow: The workflow to run; the module-level ``AGENT_WORKFLOW``
                is None when its initialization failed.

        Raises:
            RuntimeError: If ``workflow`` is None.
        """
        if workflow is None:
            logger.error("AgentWorkflow is None, initialization likely failed.")
            raise RuntimeError("AgentWorkflow failed to initialize. Check logs for details.")
        self.agent_workflow = workflow
        logger.info("BasicAgent wrapper initialized.")

    @staticmethod
    def _truncate(text: str, limit: int) -> str:
        """Return the first ``limit`` characters of ``text``, plus '...' if it was cut."""
        return f"{text[:limit]}{'...' if len(text) > limit else ''}"

    async def __call__(self, question: str | ChatMessage) -> Any:
        """Run the workflow on ``question`` and return its final response.

        Args:
            question: A plain-text question, or a ChatMessage whose blocks may
                also carry an attached file (text, image or audio block).

        Returns:
            The workflow result's ``response`` attribute when it has one
            (callers read ``.content`` from it), otherwise the raw result.
        """
        if isinstance(question, ChatMessage):
            first_block = question.blocks[0] if question.blocks else None
            if first_block is not None and hasattr(first_block, "text"):
                log_question = str(first_block.text)[:100]
            else:
                log_question = str(question)[:100]
        else:
            log_question = question[:100]
        logger.info(f"Agent received question (first 100 chars): {log_question}...")

        handler = self.agent_workflow.run(user_msg=question)

        current_agent = None
        async for event in handler.stream_events():
            # Hand-off detection is its own `if` (not the head of the elif chain)
            # so an event that also reports a new agent name still reaches the
            # type-specific logging below.
            if hasattr(event, "current_agent_name") and event.current_agent_name != current_agent:
                current_agent = event.current_agent_name
                logger.info("=" * 50)
                logger.info(f"🤖 Agent: {current_agent}")
                logger.info("=" * 50)

            # For token-level tracing, AgentStream / AgentInput events (importable
            # from llama_index.core.agent.workflow) could be logged here at DEBUG.
            if isinstance(event, AgentOutput):
                if event.response and getattr(event.response, "content", None):
                    logger.info(f"📤 Output: {event.response.content}")
                if event.tool_calls:
                    logger.info(
                        f"🛠️  Planning to use tools: {[call.tool_name for call in event.tool_calls]}"
                    )
            elif isinstance(event, ToolCallResult):
                logger.info(f"🔧 Tool Result ({event.tool_name}):")
                logger.info(f"  Arguments: {event.tool_kwargs}")
                # Limit output logging length; tool outputs may be very long.
                logger.info(f"  Output: {self._truncate(str(event.tool_output), 500)}")
            elif isinstance(event, ToolCall):
                logger.info(f"🔨 Calling Tool: {event.tool_name}")
                logger.info(f"  With arguments: {event.tool_kwargs}")

        answer = await handler
        if hasattr(answer, "response"):
            response = answer.response
            content = getattr(response, "content", None)
        else:
            # No `.response` attribute: hand back the raw result (callers str() it).
            response, content = answer, None
        # `content` may be None (e.g. a message with no text blocks); slicing or
        # len() on it would crash, so log str(answer) instead.
        final_content = str(content) if content is not None else str(answer)
        logger.info(f"Agent returning final answer: {self._truncate(final_content, 500)}")
        return response

# --- Helper Functions for run_and_submit_all ---

async def fetch_questions(questions_url: str) -> List[dict] | None:
    """Fetch the question list from the GAIA benchmark API.

    Args:
        questions_url: Full URL of the questions endpoint.

    Returns:
        The decoded list of question dicts, or None on any failure (network or
        HTTP error, invalid JSON, empty or non-list payload). Failures are
        logged, never raised.
    """
    logger.info(f"Fetching questions from: {questions_url}")
    try:
        # `requests` is synchronous; run it in a worker thread so this coroutine
        # does not block the event loop for up to the 30 s timeout.
        response = await asyncio.to_thread(requests.get, questions_url, timeout=30)
        response.raise_for_status()
        questions_data = response.json()
    # JSONDecodeError is a subclass of RequestException, so it must be caught
    # first or this branch is unreachable.
    except requests.exceptions.JSONDecodeError as e:
        logger.error(f"Error decoding JSON response from questions endpoint: {e}", exc_info=True)
        logger.error(f"Response text: {response.text[:500]}")
        return None
    except requests.exceptions.RequestException as e:
        logger.error(f"Error fetching questions: {e}", exc_info=True)
        return None
    except Exception as e:
        logger.error(f"An unexpected error occurred fetching questions: {e}", exc_info=True)
        return None

    if not questions_data:
        logger.warning("Fetched questions list is empty.")
        return None
    # Callers slice and iterate the result as a list of dicts.
    if not isinstance(questions_data, list):
        logger.error(f"Unexpected questions payload type: {type(questions_data).__name__}")
        return None
    logger.info(f"Fetched {len(questions_data)} questions.")
    return questions_data

# Extensions whose content is always sent to the agents as decoded text, even when
# `mimetypes` reports a non-text MIME type (e.g. application/json) or none at all.
# All entries are lowercase because they are matched against `file_name.lower()`.
_TEXT_FILE_EXTENSIONS = (
    ".txt", ".csv", ".json", ".xml", ".yaml", ".yml", ".ini", ".cfg", ".toml", ".log", ".properties",
    ".html", ".htm", ".xhtml", ".css", ".scss", ".sass", ".less", ".svg", ".md", ".rst",
    ".py", ".js", ".java", ".c", ".cpp", ".h", ".hpp", ".cs", ".go", ".php", ".rb", ".swift", ".kt",
    ".sh", ".bat", ".ipynb", ".rmd", ".tex",
)


def _decode_text_file(content: bytes, file_name: str) -> str:
    """Decode downloaded file bytes as UTF-8, falling back to latin-1.

    latin-1 maps every byte value to a character, so the fallback cannot fail.
    """
    try:
        return content.decode("utf-8")
    except UnicodeDecodeError:
        text = content.decode("latin-1")
        logger.warning(f"Decoded file {file_name} using latin-1 fallback.")
        return text


def _build_file_block(
    file_name: str, mime_type: str | None, content: bytes, fetch_file_url: str
) -> TextBlock | ImageBlock | AudioBlock | None:
    """Turn a downloaded task file into a chat content block.

    Returns None for unhandled file types (the question is then sent alone).
    """
    mime = mime_type or ""
    # Text check first: text-like files with another MIME type (e.g. .svg is
    # image/svg+xml) are sent as text. The extension check also covers files
    # for which `mimetypes` returns no type at all.
    if mime.startswith("text/") or file_name.lower().endswith(_TEXT_FILE_EXTENSIONS):
        return TextBlock(block_type="text", text=_decode_text_file(content, file_name))
    if mime.startswith("image/"):
        # Pass image content directly for multi-modal models.
        return ImageBlock(url=fetch_file_url, image=content)
    if mime.startswith("audio/"):
        # Pass audio content directly.
        return AudioBlock(url=fetch_file_url, audio=content)
    if mime == "application/pdf":
        # PDFs are not decoded here; only the URL is passed for the agents to handle.
        logger.info(f"PDF file detected: {file_name}. Passing reference URL.")
        return TextBlock(text=f"[Reference PDF file available at: {fetch_file_url}]")
    return None


async def process_question(agent: BasicAgent, item: dict, base_fetch_file_url: str) -> dict | None:
    """Run the agent on one GAIA question item and return a result row.

    Args:
        agent: Wrapper around the initialized workflow.
        item: Question record; reads "task_id", "question" and optional "file_name".
        base_fetch_file_url: Base URL; the attached file is fetched from
            "<base_fetch_file_url>/<task_id>".

    Returns:
        A {"Task ID", "Question", "Submitted Answer"} row, or None when the item
        lacks a task_id or question. Fetch/agent failures are reported in the row
        as "AGENT ERROR: ..." instead of being raised.
    """
    task_id = item.get("task_id")
    question_text = item.get("question")
    file_name = item.get("file_name")

    if not task_id or question_text is None:
        logger.warning(f"Skipping item with missing task_id or question: {item}")
        return None

    file_block: TextBlock | ImageBlock | AudioBlock | None = None
    if file_name:
        fetch_file_url = f"{base_fetch_file_url}/{task_id}"
        logger.info(f"Fetching file '{file_name}' for task {task_id} from {fetch_file_url}")
        try:
            # `requests` is blocking; run it in a worker thread so the event loop stays free.
            response = await asyncio.to_thread(requests.get, fetch_file_url, timeout=60)
            response.raise_for_status()
            mime_type, _ = mimetypes.guess_type(file_name)
            logger.info(f"File '{file_name}' MIME type guessed as: {mime_type}")
            file_block = _build_file_block(file_name, mime_type, response.content, fetch_file_url)
        except requests.exceptions.RequestException as e:
            logger.error(f"Error fetching file for task {task_id}: {e}", exc_info=True)
            return {"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: Failed to fetch file {file_name} - {e}"}
        except Exception as e:
            logger.error(f"Error processing file for task {task_id}: {e}", exc_info=True)
            return {"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: Failed to process file {file_name} - {e}"}

        if file_block is None:
            logger.warning(f"File type for '{file_name}' (MIME: {mime_type}) not directly supported for block creation or no block created (e.g., unsupported). Passing text question only.")

    # Build the message and run the agent under one handler so that any failure
    # (including message validation) becomes an error row instead of aborting the run.
    try:
        blocks = [TextBlock(text=question_text)]
        if file_block is not None:
            blocks.append(file_block)
        message = ChatMessage(role="user", blocks=blocks)

        logger.info(f"Running agent on task {task_id}...")
        answer_response = await agent(message)
        if hasattr(answer_response, "content"):
            # `content` may be None (no text blocks); treat that as an empty answer.
            submitted_answer = str(answer_response.content or "")
        else:
            submitted_answer = str(answer_response)

        logger.info(f"👍 Agent submitted answer for task {task_id}: {submitted_answer[:200]}{'...' if len(submitted_answer) > 200 else ''}")
        return {"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer}
    except Exception as e:
        logger.error(f"Error running agent on task {task_id}: {e}", exc_info=True)
        return {"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"}

def _describe_http_error(response: requests.Response) -> str:
    """Summarize an HTTP error response: status code plus the server's detail text."""
    detail = f"Server responded with status {response.status_code}."
    try:
        error_json = response.json()
    except requests.exceptions.JSONDecodeError:
        return f"{detail} Response: {response.text[:500]}"
    if isinstance(error_json, dict):
        return f"{detail} Detail: {error_json.get('detail', response.text)}"
    # Valid JSON but not an object (e.g. a list): show the raw body rather than
    # raising AttributeError from inside the caller's except-handler.
    return f"{detail} Response: {response.text[:500]}"


async def submit_answers(submit_url: str, username: str, agent_code: str, results: List[dict]) -> tuple[str, pd.DataFrame]:
    """Submit the collected answers to the GAIA benchmark API.

    Rows whose answer starts with "AGENT ERROR:" are left out of the payload but
    still appear in the returned table.

    Args:
        submit_url: Full URL of the submission endpoint.
        username: Submitting user (whitespace is stripped in the payload).
        agent_code: Agent identifier sent with the submission.
        results: Rows produced by process_question.

    Returns:
        (status message, DataFrame of all result rows). Network/server errors
        are reported in the status message rather than raised.
    """
    # Every return path reports the full results table.
    results_df = pd.DataFrame(results)
    answers_payload = [
        {"task_id": r["Task ID"], "submitted_answer": r["Submitted Answer"]}
        for r in results if "Submitted Answer" in r and not str(r["Submitted Answer"]).startswith("AGENT ERROR:")
    ]

    if not answers_payload:
        logger.warning("Agent did not produce any valid answers to submit.")
        return "Agent did not produce any valid answers to submit.", results_df

    submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
    logger.info(f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'...")
    logger.info(f"Submitting to: {submit_url}")

    try:
        # `requests` is blocking; run it in a worker thread so the event loop stays free.
        response = await asyncio.to_thread(requests.post, submit_url, json=submission_data, timeout=120)
        response.raise_for_status()
        result_data = response.json()
        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )
    except requests.exceptions.HTTPError as e:
        status_message = f"Submission Failed: {_describe_http_error(e.response)}"
        logger.error(status_message)
        return status_message, results_df
    except requests.exceptions.Timeout:
        status_message = "Submission Failed: The request timed out."
        logger.error(status_message)
        return status_message, results_df
    # Must precede RequestException (its base class). A 2xx with a non-JSON body
    # is not a network error and the submission may in fact have been recorded.
    except requests.exceptions.JSONDecodeError as e:
        status_message = f"Submission status unknown: server returned HTTP {response.status_code} with a non-JSON body - {e}"
        logger.error(status_message)
        return status_message, results_df
    except requests.exceptions.RequestException as e:
        status_message = f"Submission Failed: Network error - {e}"
        logger.error(status_message)
        return status_message, results_df
    except Exception as e:
        status_message = f"Submission Failed: An unexpected error occurred during submission - {e}"
        logger.error(status_message, exc_info=True)
        return status_message, results_df

    logger.info("Submission successful.")
    return final_status, results_df

# --- Main Function for Batch Processing ---
async def run_and_submit_all(
        username: str,
        agent_code: str,
        api_url: str = DEFAULT_API_URL,
        level: int = 1,
        max_questions: int = 0, # 0 means all questions for the level
        progress=gr.Progress(track_tqdm=True)
    ) -> tuple[str, pd.DataFrame]:
    """Fetch all questions for a level, run the agent on each, and submit the answers.

    Args:
        username: Submission username; must not be blank.
        agent_code: Agent identifier sent with the submission.
        api_url: Base URL of the GAIA scoring API. A blank value falls back to
            DEFAULT_API_URL and a trailing slash is ignored.
        level: Benchmark level, sent as the `level` query parameter.
        max_questions: Process at most this many questions; 0 (or empty) means
            all. Floats such as 5.0 (what gr.Number delivers by default) are accepted.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        (status message, DataFrame of per-question results). Setup failures
        (no workflow, blank username, failed fetch) are returned as the status
        message with an empty DataFrame.
    """
    if not AGENT_WORKFLOW:
        error_msg = "Agent Workflow is not initialized. Cannot run benchmark."
        logger.error(error_msg)
        return error_msg, pd.DataFrame()

    if not username or not username.strip():
        error_msg = "Username cannot be empty."
        logger.error(error_msg)
        return error_msg, pd.DataFrame()

    # Avoid "//questions" when the URL typed in the UI ends with "/".
    api_url = (api_url or "").strip().rstrip("/") or DEFAULT_API_URL
    # Slice bounds must be ints; gr.Number yields floats (or None when cleared).
    max_questions = int(max_questions or 0)

    questions_url = f"{api_url}/questions?level={level}"
    submit_url = f"{api_url}/submit"
    base_fetch_file_url = f"{api_url}/get_file"

    questions = await fetch_questions(questions_url)
    if questions is None:
        error_msg = f"Failed to fetch questions for level {level}. Check logs."
        return error_msg, pd.DataFrame()

    if max_questions > 0:
        questions = questions[:max_questions]
        logger.info(f"Processing a maximum of {max_questions} questions for level {level}.")
    else:
        logger.info(f"Processing all {len(questions)} questions for level {level}.")

    agent = BasicAgent(AGENT_WORKFLOW)
    results = []
    # Sequential on purpose: one question at a time through the shared workflow.
    for item in progress.tqdm(questions, desc=f"Processing Level {level} Questions"):
        result = await process_question(agent, item, base_fetch_file_url)
        if result:
            results.append(result)

    return await submit_answers(submit_url, username, agent_code, results)

# --- Gradio Interface ---
def create_gradio_interface():
    """Build the Gradio Blocks UI that runs the benchmark and submits the answers.

    Returns:
        The (not yet launched) ``gr.Blocks`` app; its button calls
        ``run_and_submit_all``.
    """
    logger.info("Creating Gradio interface...")
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# GAIA Benchmark Agent Runner")
        gr.Markdown("Run the initialized multi-agent system against the GAIA benchmark questions and submit the results.")

        with gr.Row():
            username = gr.Textbox(label="Username", placeholder="Enter your username (e.g., your Hugging Face username)")
            agent_code = gr.Textbox(label="Agent Code", placeholder="Enter a short code for your agent (e.g., v1.0)")
        with gr.Row():
            level = gr.Dropdown(label="Benchmark Level", choices=[1, 2, 3], value=1)
            # precision=0 makes the component deliver an int, usable as a slice bound.
            max_questions = gr.Number(label="Max Questions (0 for all)", value=0, minimum=0, step=1, precision=0)
            api_url = gr.Textbox(label="GAIA API URL", value=DEFAULT_API_URL)

        run_button = gr.Button("Run Benchmark and Submit", variant="primary")

        with gr.Accordion("Results", open=False):
            status_output = gr.Textbox(label="Submission Status", lines=5)
            results_dataframe = gr.DataFrame(label="Detailed Results")

        # Input order must match run_and_submit_all's positional parameters.
        run_button.click(
            fn=run_and_submit_all,
            inputs=[username, agent_code, api_url, level, max_questions],
            outputs=[status_output, results_dataframe],
        )
    logger.info("Gradio interface created.")
    return demo

# --- Main Execution ---
if __name__ == "__main__":
    if not AGENT_WORKFLOW:
        print("ERROR: Agent Workflow failed to initialize. Cannot start Gradio app.")
        print("Please check logs for initialization errors (e.g., missing API keys, import issues).")
        # Exit non-zero so whatever launched the script can detect the failure.
        raise SystemExit(1)

    gradio_app = create_gradio_interface()
    # share=True would additionally create a public link (use with caution).
    # server_name="0.0.0.0" listens on all interfaces, allowing network access.
    gradio_app.launch(server_name="0.0.0.0", server_port=7860)