# Hugging Face Spaces page header captured with the file (Space status: Sleeping)
import os | |
import tempfile | |
import gradio as gr | |
import pandas as pd | |
import traceback | |
from core_agent import GAIAAgent | |
from api_integration import GAIAApiClient | |
# Constants
# Base URL of the scoring service; run_and_submit_all passes it to GAIAApiClient,
# which is then used to fetch questions, download task files and submit answers.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
def save_task_file(file_content, task_id):
    """
    Save a task file to a temporary location.

    Args:
        file_content: Payload to store. Bytes are written as-is; a ``str`` is
            UTF-8 encoded first. Falsy content (``None``, ``b""``, ``""``)
            means there is no file to save.
        task_id: Identifier of the task; used to build the file name.

    Returns:
        Path of the written file inside the system temp directory, or
        ``None`` when there was nothing to save.
    """
    if not file_content:
        return None

    # The file is opened in binary mode, so a text payload must be encoded
    # before writing (otherwise f.write raises TypeError).
    if isinstance(file_content, str):
        file_content = file_content.encode("utf-8")

    # Callers pass IDs obtained from the remote questions API. Keep only
    # filename-safe characters so path separators in an ID can neither break
    # the write nor redirect it outside the temp directory.
    safe_task_id = "".join(
        ch if ch.isalnum() or ch in "-_." else "_" for ch in str(task_id)
    )

    # Create a temporary file
    temp_dir = tempfile.gettempdir()
    file_path = os.path.join(temp_dir, f"gaia_task_{safe_task_id}.txt")

    # Write content to the file
    with open(file_path, 'wb') as f:
        f.write(file_content)

    print(f"File saved to {file_path}")
    return file_path
def get_agent_configuration():
    """
    Get the agent configuration based on environment variables.

    Starts from built-in defaults, applies the ``XAI_API_KEY`` /
    ``XAI_API_BASE`` credentials when a key is present, then applies the
    ``AGENT_*`` overrides. Empty ``AGENT_*`` values are ignored (except
    ``AGENT_VERBOSE``, where any value other than "true" disables verbosity).
    Malformed numeric values (``AGENT_TEMPERATURE``, ``AGENT_TIMEOUT``) are
    reported and the current value is kept instead of raising.

    Returns:
        dict: Configuration with keys ``model_type``, ``model_id``,
        ``temperature``, ``executor_type``, ``verbose``, ``provider``,
        ``timeout`` and, when configured, ``api_key`` / ``api_base``.
    """
    # Default configuration
    config = {
        "model_type": "OpenAIServerModel",  # Default to OpenAIServerModel
        "model_id": "gpt-4o",  # Default model for OpenAI
        "temperature": 0.2,
        "executor_type": "local",
        "verbose": False,
        "provider": "hf-inference",  # For InferenceClientModel
        "timeout": 120,  # For InferenceClientModel
    }

    # If we have xAI credentials, use them
    xai_api_key = os.getenv("XAI_API_KEY")
    xai_api_base = os.getenv("XAI_API_BASE")
    if xai_api_key:
        config["api_key"] = xai_api_key
        if xai_api_base:
            config["api_base"] = xai_api_base
        # Model chosen for this endpoint; AGENT_MODEL_ID below can still override it
        config["model_id"] = "mixtral-8x7b-32768"

    # Plain string overrides (applied after the xAI block so they win)
    string_overrides = (
        ("AGENT_MODEL_TYPE", "model_type"),
        ("AGENT_MODEL_ID", "model_id"),
        ("AGENT_EXECUTOR_TYPE", "executor_type"),
        ("AGENT_API_BASE", "api_base"),
        ("AGENT_PROVIDER", "provider"),  # InferenceClientModel specific
    )
    for env_name, key in string_overrides:
        value = os.getenv(env_name)
        if value:
            config[key] = value

    # Numeric overrides: a typo in the environment must not crash start-up,
    # so an unparsable value is reported and the existing value is kept.
    numeric_overrides = (
        ("AGENT_TEMPERATURE", "temperature", float),
        ("AGENT_TIMEOUT", "timeout", int),  # InferenceClientModel specific
    )
    for env_name, key, cast in numeric_overrides:
        raw = os.getenv(env_name)
        if not raw:
            continue
        try:
            config[key] = cast(raw)
        except ValueError:
            print(f"Ignoring invalid {env_name}={raw!r}; keeping {config[key]!r}")

    # Only the literal "true" (any case, surrounding whitespace ignored) enables verbosity
    verbose = os.getenv("AGENT_VERBOSE")
    if verbose is not None:
        config["verbose"] = verbose.strip().lower() == "true"

    return config
def _redact_config(config):
    """Return a copy of *config* that is safe to log (the API key is masked)."""
    return {
        key: ("***" if key == "api_key" and value else value)
        for key, value in config.items()
    }


def _download_task_file(api_client, task_id):
    """Best-effort fetch of a task's attachment; return its local path or None."""
    try:
        file_content = api_client.get_file(task_id)
        file_path = save_task_file(file_content, task_id)
    except Exception as file_e:
        # Deliberately non-fatal: the question is still answered without a file.
        print(f"No file found for task {task_id} or error: {file_e}")
        return None
    if file_path:
        # Only claim a download when something was actually saved.
        print(f"Downloaded file for task {task_id}")
    return file_path


def _answer_questions(agent, api_client, questions_data):
    """Run the agent on each valid question; return (answers_payload, results_log)."""
    results_log = []
    answers_payload = []
    total_questions = len(questions_data)
    processed = 0
    failed = 0

    for item in questions_data:
        task_id = item.get("task_id")
        question_text = item.get("question")
        if not task_id or question_text is None:
            print(f"Skipping item with missing task_id or question: {item}")
            continue

        processed += 1
        print(f"Processing question {processed}/{total_questions}: Task ID {task_id}")

        try:
            file_path = _download_task_file(api_client, task_id)
            submitted_answer = agent.answer_question(question_text, file_path)
        except Exception as e:
            failed += 1
            error_details = traceback.format_exc()
            print(f"Error running agent on task {task_id}: {e}\n{error_details}")
            # The failure is still submitted so every processed task has an entry.
            submitted_answer = f"AGENT ERROR: {e}"

        answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
        results_log.append({
            "Task ID": task_id,
            "Question": question_text,
            "Submitted Answer": submitted_answer
        })

    print(f"\nProcessing complete: {processed} questions processed, {failed} failures")
    return answers_payload, results_log


def run_and_submit_all(profile: gr.OAuthProfile | None):
    """
    Fetches all questions, runs the GAIAAgent on them, submits all answers,
    and displays the results.

    Args:
        profile: Profile of the logged-in user, or None when nobody is
            logged in (the run is refused in that case).

    Returns:
        A ``(status_message, results)`` tuple. ``results`` is a pandas
        DataFrame with one row per processed question, or None when the run
        stopped before any question was processed. Errors are reported through
        ``status_message`` rather than raised.
    """
    # Check for user login
    if not profile:
        return "Please Login to Hugging Face with the button.", None
    username = profile.username
    print(f"User logged in: {username}")

    # Get SPACE_ID for code link
    space_id = os.getenv("SPACE_ID")
    if not space_id:
        # The link is still built as before, but it will not resolve to a Space.
        print("Warning: SPACE_ID is not set; the submitted agent_code link will be invalid.")
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"

    # Initialize API client
    api_client = GAIAApiClient(DEFAULT_API_URL)

    # Initialize Agent with configuration
    try:
        agent_config = get_agent_configuration()
        # Never write the raw API key to the logs.
        print(f"Using agent configuration: {_redact_config(agent_config)}")
        agent = GAIAAgent(**agent_config)
        print("Agent initialized successfully")
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"Error initializing agent: {e}\n{error_details}")
        return f"Error initializing agent: {e}", None

    # Fetch questions
    try:
        questions_data = api_client.get_questions()
        if not questions_data:
            return "Fetched questions list is empty or invalid format.", None
        print(f"Fetched {len(questions_data)} questions.")
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"Error fetching questions: {e}\n{error_details}")
        return f"Error fetching questions: {e}", None

    # Run agent on questions
    print(f"Running agent on {len(questions_data)} questions...")
    answers_payload, results_log = _answer_questions(agent, api_client, questions_data)

    if not answers_payload:
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    # Submit answers
    print(f"Submitting {len(answers_payload)} answers for username '{username}'...")
    try:
        result_data = api_client.submit_answers(
            username.strip(),
            agent_code,
            answers_payload
        )

        # Calculate success rate
        correct_count = result_data.get('correct_count', 0)
        total_attempted = result_data.get('total_attempted', len(answers_payload))
        success_rate = (correct_count / total_attempted) * 100 if total_attempted > 0 else 0

        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({correct_count}/{total_attempted} correct, {success_rate:.1f}% success rate)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )
        print("Submission successful.")
        return final_status, pd.DataFrame(results_log)
    except Exception as e:
        error_details = traceback.format_exc()
        status_message = f"Submission Failed: {e}\n{error_details}"
        print(status_message)
        return status_message, pd.DataFrame(results_log)
# Build Gradio Interface
# Components are created in display order inside the Blocks context:
# title, instructions, login button, run button, status box, results table.
with gr.Blocks() as demo:
    gr.Markdown("# GAIA Agent Evaluation Runner")
    # User-facing help text: usage steps plus the environment variables read by
    # get_agent_configuration().
    gr.Markdown(
        """
        **Instructions:**
        1. Log in to your Hugging Face account using the button below.
        2. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.
        **Configuration:**
        You can configure the agent by setting these environment variables:
        - `AGENT_MODEL_TYPE`: Model type (HfApiModel, InferenceClientModel, LiteLLMModel, OpenAIServerModel)
        - `AGENT_MODEL_ID`: Model ID
        - `AGENT_TEMPERATURE`: Temperature for generation (0.0-1.0)
        - `AGENT_EXECUTOR_TYPE`: Type of executor ('local' or 'e2b')
        - `AGENT_VERBOSE`: Enable verbose logging (true/false)
        - `AGENT_API_BASE`: Base URL for API calls (for OpenAIServerModel)
        **xAI Support:**
        - `XAI_API_KEY`: Your xAI API key
        - `XAI_API_BASE`: Base URL for xAI API (default: https://api.groq.com/openai/v1)
        - When using xAI, set AGENT_MODEL_TYPE=OpenAIServerModel and AGENT_MODEL_ID=mixtral-8x7b-32768
        **InferenceClientModel specific settings:**
        - `AGENT_PROVIDER`: Provider for InferenceClientModel (e.g., "hf-inference")
        - `AGENT_TIMEOUT`: Timeout in seconds for API calls
        """
    )
    # NOTE(review): get_agent_configuration() in this file applies no default
    # for XAI_API_BASE; the default quoted in the help text above is not
    # established here - confirm against core_agent.
    gr.LoginButton()
    run_button = gr.Button("Run Evaluation & Submit All Answers")
    status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
    # No `inputs` are wired: run_and_submit_all's only parameter is the OAuth
    # profile, presumably supplied by Gradio from the type annotation - confirm
    # against the Gradio OAuth documentation. The returned (status, table)
    # tuple fills the two outputs in order.
    run_button.click(
        fn=run_and_submit_all,
        outputs=[status_output, results_table]
    )
# Script entry point: log the effective configuration, then start the UI.
if __name__ == "__main__":
    print("\n" + "-"*30 + " App Starting " + "-"*30)
    # Check for environment variables
    config = get_agent_configuration()
    # The configuration may carry an API key; mask it so the secret never
    # reaches the start-up logs.
    printable_config = {
        key: ("***" if key == "api_key" and value else value)
        for key, value in config.items()
    }
    print(f"Agent configuration: {printable_config}")
    # Run the Gradio app
    demo.launch(debug=True, share=False)