# Spaces: Sleeping
# (Hugging Face Space status banner captured together with the file; not part of the program.)
import base64
import inspect
import io
import json
import os
import re
import shutil
import tempfile
import threading
import time
import urllib.parse
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import gradio as gr
import pandas as pd
import requests
from huggingface_hub import InferenceClient
from PIL import Image
from smolagents import DuckDuckGoSearchTool
# --- Constants ---
# Base URL of the scoring service; "/questions" and "/submit" are appended to it.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

# --- Global Cache for Answers ---
# Maps task_id -> {"question": ..., "answer": ...}; filled by generate_answers_async.
cached_answers = dict()
# Raw question records exactly as returned by the questions endpoint.
cached_questions = list()
# Progress record shared between the UI callbacks and the background worker.
processing_status = dict(is_processing=False, progress=0, total=0)
# --- File Download Utility ---
def _attachment_extension(content_type: str) -> str:
    """Map a Content-Type header value to the extension used for generated file names."""
    content_type = content_type.lower()
    if 'image' in content_type:
        if 'jpeg' in content_type or 'jpg' in content_type:
            return ".jpg"
        if 'png' in content_type:
            return ".png"
        return ".img"
    if 'audio' in content_type:
        if 'mp3' in content_type:
            return ".mp3"
        if 'wav' in content_type:
            return ".wav"
        return ".audio"
    if 'python' in content_type or 'text' in content_type:
        return ".py"
    return ".file"


def _unique_path(directory: str, filename: str) -> str:
    """Return a path for *filename* inside *directory* that no existing file occupies."""
    stem, ext = os.path.splitext(filename)
    candidate = os.path.join(directory, filename)
    counter = 1
    while os.path.exists(candidate):
        candidate = os.path.join(directory, f"{stem}_{counter}{ext}")
        counter += 1
    return candidate


def download_attachment(url: str, temp_dir: str) -> Optional[str]:
    """
    Download an attachment from URL to a temporary directory.

    The file name is taken from the (percent-decoded) last path segment of the
    URL. When that segment is empty or has no extension, a name of the form
    ``attachment_<unix time><ext>`` is generated, with the extension derived
    from the response's Content-Type header. A numeric suffix is appended when
    the target name already exists, so two downloads never overwrite each other.

    Args:
        url: Location of the attachment.
        temp_dir: Existing directory the file is written into.

    Returns:
        The local file path if successful, None otherwise. Every failure
        (network error, HTTP error status, unwritable directory) is reported
        on stdout and turned into None rather than raised.
    """
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()

        # Decode before taking the basename so an encoded "/" cannot smuggle
        # a directory component into the file name.
        parsed_url = urllib.parse.urlparse(url)
        filename = os.path.basename(urllib.parse.unquote(parsed_url.path))
        if filename in (".", ".."):
            filename = ""

        if not filename or '.' not in filename:
            content_type = response.headers.get('content-type', '')
            filename = f"attachment_{int(time.time())}{_attachment_extension(content_type)}"

        file_path = _unique_path(temp_dir, filename)
        with open(file_path, 'wb') as f:
            f.write(response.content)

        print(f"Downloaded attachment: {url} -> {file_path}")
        return file_path
    except Exception as e:
        # Best-effort by design: callers skip attachments that come back as None.
        print(f"Failed to download attachment {url}: {e}")
        return None
# --- Code Processing Tool ---
class CodeAnalysisTool:
    """Summarises a source file by sending its text to a chat-completion model."""

    def __init__(self, model_name: str = "meta-llama/Llama-3.1-8B-Instruct"):
        self.client = InferenceClient(model=model_name, provider="sambanova")

    def analyze_code(self, code_path: str) -> str:
        """
        Read the file at *code_path* and return the model's summary of it.

        Only the first 5000 characters are sent to the model. Any failure
        (unreadable file, remote error) is returned as a
        "Code analysis failed: ..." string instead of being raised.
        """
        try:
            with open(code_path, 'r', encoding='utf-8') as handle:
                source = handle.read()

            # Keep the request small: cap what is sent to the model.
            max_chars = 5000
            if len(source) <= max_chars:
                snippet = source
            else:
                snippet = source[:max_chars] + "\n... (truncated)"

            prompt = f"""Analyze this Python code and provide a concise summary of:
1. What the code does (main functionality)
2. Key functions/classes
3. Any notable patterns or issues
4. Input/output behavior if applicable
Code:
```python
{snippet}
```
Provide a brief, focused analysis:"""

            completion = self.client.chat_completion(
                messages=[{"role": "user", "content": prompt}],
                max_tokens=500,
                temperature=0.3,
            )
            reply = completion.choices[0].message.content
            return reply.strip()
        except Exception as e:
            return f"Code analysis failed: {e}"
# --- Image Processing Tool ---
class ImageAnalysisTool:
    """Captions images and extracts printed text from them via image-to-text models."""

    # Model tried when the primary captioning model fails.
    FALLBACK_CAPTION_MODEL = "Salesforce/blip-image-captioning-large"
    # Model used by extract_text_from_image.
    OCR_MODEL = "microsoft/trocr-base-printed"

    def __init__(self, model_name: str = "microsoft/Florence-2-large"):
        # Remembered so analyze_image honours the caller's choice instead of a
        # hard-coded model id.
        self.model_name = model_name
        self.client = InferenceClient(model=model_name)

    @staticmethod
    def _generated_text(response, default: str) -> str:
        """Return the generated text of an image_to_text response, or *default* if absent/empty.

        Accepts a plain string, a mapping with a "generated_text" key, or an
        object exposing a ``generated_text`` attribute.
        """
        if isinstance(response, str):
            return response or default
        text = None
        if isinstance(response, dict):
            text = response.get("generated_text")
        if text is None:
            text = getattr(response, "generated_text", None)
        return text if text else default

    @staticmethod
    def _read_bytes(path: str) -> bytes:
        """Read the whole file at *path* as bytes."""
        with open(path, "rb") as f:
            return f.read()

    def analyze_image(self, image_path: str, prompt: str = "Describe this image in detail") -> str:
        """
        Analyze an image and return a description.

        Args:
            image_path: Path of the image file to caption.
            prompt: Currently unused; kept so existing callers keep working.

        Returns:
            The caption from the configured model, or from the fallback model
            if the first call fails. If the file cannot be read, or both
            models fail, an "Image analysis failed: ..." string is returned.
        """
        # Read first: without the bytes there is nothing for a fallback model to do.
        try:
            image_bytes = self._read_bytes(image_path)
        except OSError as e:
            return f"Image analysis failed: {e}"

        try:
            response = self.client.image_to_text(image=image_bytes, model=self.model_name)
            return self._generated_text(response, "Could not analyze image")
        except Exception as e:
            try:
                # Fallback: use a different vision model
                response = self.client.image_to_text(
                    image=image_bytes,
                    model=self.FALLBACK_CAPTION_MODEL,
                )
                return self._generated_text(response, f"Image analysis error: {e}")
            except Exception:
                # Report the primary failure; it is the more informative one.
                return f"Image analysis failed: {e}"

    def extract_text_from_image(self, image_path: str) -> str:
        """
        Extract text from an image using OCR.

        Returns the recognised text, "No text found in image" when the model
        produces none, or an "OCR failed: ..." string on any error.
        """
        try:
            image_bytes = self._read_bytes(image_path)
            response = self.client.image_to_text(image=image_bytes, model=self.OCR_MODEL)
            return self._generated_text(response, "No text found in image")
        except Exception as e:
            return f"OCR failed: {e}"
# --- Audio Processing Tool ---
class AudioTranscriptionTool:
    """Transcribes audio files through a speech-recognition model."""

    # Model tried when the client's default model fails.
    FALLBACK_ASR_MODEL = "facebook/wav2vec2-large-960h-lv60-self"

    def __init__(self, model_name: str = "openai/whisper-large-v3"):
        self.client = InferenceClient(model=model_name)

    @staticmethod
    def _transcript(response, default: str) -> str:
        """Return the text of a speech-recognition response, or *default* if absent/empty.

        Accepts a plain string, a mapping with a "text" key, or an object
        exposing a ``text`` attribute.
        """
        if isinstance(response, str):
            return response or default
        text = None
        if isinstance(response, dict):
            text = response.get("text")
        if text is None:
            text = getattr(response, "text", None)
        return text if text else default

    def transcribe_audio(self, audio_path: str) -> str:
        """
        Transcribe audio file to text.

        Returns the transcript from the client's configured model, or from the
        fallback model if the first call fails. If the file cannot be read, or
        both models fail, an "Audio transcription failed: ..." string is
        returned.
        """
        # Read first: without the bytes there is nothing for a fallback model to do.
        try:
            with open(audio_path, "rb") as f:
                audio_bytes = f.read()
        except OSError as e:
            return f"Audio transcription failed: {e}"

        try:
            response = self.client.automatic_speech_recognition(audio=audio_bytes)
            return self._transcript(response, "Could not transcribe audio")
        except Exception as e:
            try:
                # Fallback to a different ASR model
                response = self.client.automatic_speech_recognition(
                    audio=audio_bytes,
                    model=self.FALLBACK_ASR_MODEL,
                )
                return self._transcript(response, f"Audio transcription error: {e}")
            except Exception:
                # Report the primary failure; it is the more informative one.
                return f"Audio transcription failed: {e}"
# --- Enhanced Intelligent Agent with Media Processing ---
class IntelligentAgent:
    """Answers questions with an LLM, optionally backed by web search and attachments.

    For each question the agent (1) turns any attachments into text
    (code summary, image caption/OCR, audio transcript), (2) asks the LLM
    whether a web search is needed, and (3) answers either from search results
    plus the LLM or from the LLM alone.
    """

    # Pauses (seconds) applied after the search decision and before a search.
    # Presumably there to stay under provider rate limits -- confirm before tuning.
    DECISION_DELAY_SECONDS = 5
    SEARCH_DELAY_SECONDS = 10

    # Question-record fields that may carry attachments.
    ATTACHMENT_FIELDS = ('attachments', 'files', 'media', 'resources')

    # File extensions used to sort downloaded files into categories.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp')
    AUDIO_EXTENSIONS = ('.mp3', '.wav', '.m4a', '.ogg', '.flac')
    CODE_EXTENSIONS = ('.py', '.txt')

    # URLs embedded in the question text are treated as attachments too.
    URL_PATTERN = re.compile(
        r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
    )

    def __init__(self, debug: bool = True, model_name: str = "meta-llama/Llama-3.1-8B-Instruct"):
        self.search = DuckDuckGoSearchTool()
        self.client = InferenceClient(model=model_name, provider="sambanova")
        self.image_tool = ImageAnalysisTool()
        self.audio_tool = AudioTranscriptionTool()
        self.code_tool = CodeAnalysisTool(model_name)
        self.debug = debug
        if self.debug:
            print(f"IntelligentAgent initialized with model: {model_name}")

    # ------------------------------------------------------------------ LLM access

    def _chat_completion(self, prompt: str, max_tokens: int = 500, temperature: float = 0.3) -> str:
        """
        Send *prompt* to the model and return the stripped reply.

        Chat completion is tried first; if it raises, plain text generation is
        used as a fallback. If both fail, the fallback's exception propagates.
        """
        messages = [{"role": "user", "content": prompt}]
        try:
            response = self.client.chat_completion(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
            return response.choices[0].message.content.strip()
        except Exception as chat_error:
            if self.debug:
                print(f"Chat completion failed: {chat_error}, trying text generation...")

        try:
            # text_generation is the client method that accepts these keyword
            # arguments (max_new_tokens / do_sample).
            response = self.client.text_generation(
                prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=temperature > 0,
            )
            return response.strip()
        except Exception as e:
            if self.debug:
                print(f"Both chat completion and text generation failed: {e}")
            raise

    # ------------------------------------------------------------------ attachments

    def _categorize(self, file_path: str, file_type: str) -> Optional[str]:
        """Return 'image', 'audio' or 'code' for a downloaded file, or None if unrecognised.

        A declared *file_type* (may be empty) and the file extension are both
        consulted; either one matching is enough.
        """
        file_ext = Path(file_path).suffix.lower()
        if 'image' in file_type or file_ext in self.IMAGE_EXTENSIONS:
            return 'image'
        if 'audio' in file_type or file_ext in self.AUDIO_EXTENSIONS:
            return 'audio'
        if 'python' in file_type or 'code' in file_type or file_ext in self.CODE_EXTENSIONS:
            return 'code'
        return None

    def _detect_and_download_attachments(self, question_data: dict) -> Tuple[List[str], List[str], List[str]]:
        """
        Detect and download attachments from question data.

        Attachments are collected from the fields listed in ATTACHMENT_FIELDS
        (a list or a single string each) and from URLs found in the question
        text. Each entry is either a URL string or a dict with a
        'url'/'link'/'file_url' key and an optional 'type'.

        Returns (image_files, audio_files, code_files) as lists of local paths.
        A temporary directory is only created when there is something to
        download, and it is removed again if no file ends up categorised.
        """
        image_files: List[str] = []
        audio_files: List[str] = []
        code_files: List[str] = []

        attachments = []
        for field in self.ATTACHMENT_FIELDS:
            field_data = question_data.get(field)
            if isinstance(field_data, list):
                attachments.extend(field_data)
            elif isinstance(field_data, str):
                attachments.append(field_data)

        # Also check if the question text contains URLs
        question_text = question_data.get('question', '') or ''
        if 'http' in question_text:
            attachments.extend(self.URL_PATTERN.findall(question_text))

        if attachments:
            temp_dir = tempfile.mkdtemp(prefix="agent_attachments_")
            buckets = {'image': image_files, 'audio': audio_files, 'code': code_files}

            for attachment in attachments:
                if isinstance(attachment, dict):
                    url = attachment.get('url') or attachment.get('link') or attachment.get('file_url')
                    file_type = str(attachment.get('type') or '').lower()
                else:
                    url = attachment
                    file_type = ''
                if not url:
                    continue

                file_path = download_attachment(url, temp_dir)
                if not file_path:
                    continue

                category = self._categorize(file_path, file_type)
                if category:
                    buckets[category].append(file_path)

            # Nothing usable was kept, so the caller has no path from which to
            # find (and later delete) the directory -- remove it here.
            if not (image_files or audio_files or code_files):
                shutil.rmtree(temp_dir, ignore_errors=True)

        if self.debug:
            print(f"Downloaded attachments: {len(image_files)} images, {len(audio_files)} audio, {len(code_files)} code files")
        return image_files, audio_files, code_files

    def _process_attachments(self, image_files: List[str] = None, audio_files: List[str] = None, code_files: List[str] = None) -> str:
        """
        Process all types of attachments and return their content as text.

        Code files contribute a (truncated) listing plus an analysis, images a
        description plus any extracted text, audio files a transcription.
        Entries that are empty or do not exist on disk are skipped; per-file
        errors are reported inline. Sections are joined by blank lines; the
        result is "" when there is nothing to report.
        """
        attachment_content = []

        # Process code files
        for code_file in code_files or []:
            if not (code_file and os.path.exists(code_file)):
                continue
            try:
                # First, include the raw code content (truncated)
                with open(code_file, 'r', encoding='utf-8') as f:
                    code_content = f.read()
                if len(code_content) > 1000:
                    code_preview = code_content[:1000] + "\n... (truncated)"
                else:
                    code_preview = code_content
                attachment_content.append(f"Code File Content:\n```python\n{code_preview}\n```")
                # Then add analysis
                code_analysis = self.code_tool.analyze_code(code_file)
                attachment_content.append(f"Code Analysis: {code_analysis}")
            except Exception as e:
                attachment_content.append(f"Error processing code file {code_file}: {e}")

        # Process images
        for image_file in image_files or []:
            if not (image_file and os.path.exists(image_file)):
                continue
            try:
                image_description = self.image_tool.analyze_image(image_file)
                attachment_content.append(f"Image Analysis: {image_description}")
                # Try to extract text from image
                extracted_text = self.image_tool.extract_text_from_image(image_file)
                if extracted_text and "No text found" not in extracted_text:
                    attachment_content.append(f"Text from Image: {extracted_text}")
            except Exception as e:
                attachment_content.append(f"Error processing image {image_file}: {e}")

        # Process audio files
        for audio_file in audio_files or []:
            if not (audio_file and os.path.exists(audio_file)):
                continue
            try:
                transcription = self.audio_tool.transcribe_audio(audio_file)
                attachment_content.append(f"Audio Transcription: {transcription}")
            except Exception as e:
                attachment_content.append(f"Error processing audio {audio_file}: {e}")

        return "\n\n".join(attachment_content)

    def _cleanup_downloads(self, file_paths: List[str]) -> None:
        """Delete the directories that contain *file_paths* (best effort)."""
        try:
            for temp_dir in {os.path.dirname(path) for path in file_paths}:
                shutil.rmtree(temp_dir, ignore_errors=True)
        except Exception as cleanup_error:
            if self.debug:
                print(f"Cleanup error: {cleanup_error}")

    # ------------------------------------------------------------------ decision + answering

    def _should_search(self, question: str, attachment_context: str = "") -> bool:
        """
        Use LLM to determine if search is needed for the question, considering attachment context.

        Returns True if search is recommended, False otherwise. If the decision
        call itself fails, search is used only when there is no attachment
        context to answer from.
        """
        if attachment_context:
            attachment_note = f"Attachment Context Available: {attachment_context[:500]}..."
        else:
            attachment_note = "No attachment context available."

        decision_prompt = f"""Analyze this question and decide if it requires real-time information, recent data, or specific facts that might not be in your training data.
SEARCH IS NEEDED for:
- Current events, news, recent developments
- Real-time data (weather, stock prices, sports scores)
- Specific factual information that changes frequently
- Recent product releases, company information
- Current status of people, organizations, or projects
- Location-specific current information
SEARCH IS NOT NEEDED for:
- General knowledge questions
- Mathematical calculations
- Programming concepts and syntax
- Historical facts (older than 1 year)
- Definitions of well-established concepts
- How-to instructions for common tasks
- Creative writing or opinion-based responses
- Questions that can be answered from attached files (code, images, audio)
- Code analysis, debugging, or explanation questions
- Questions about uploaded content
Question: "{question}"
{attachment_note}
Respond with only "SEARCH" or "NO_SEARCH" followed by a brief reason (max 20 words).
Example responses:
- "SEARCH - Current weather data needed"
- "NO_SEARCH - Mathematical concept, general knowledge sufficient"
- "NO_SEARCH - Can be answered from attached code/image content"
"""
        try:
            response = self._chat_completion(decision_prompt, max_tokens=50, temperature=0.1)
            decision = response.strip().upper()
            # The examples in the prompt are quoted/bulleted, so tolerate a
            # reply that echoes that decoration before the keyword.
            should_search = decision.lstrip(' "\'`*-').startswith("SEARCH")
            time.sleep(self.DECISION_DELAY_SECONDS)
            if self.debug:
                print(f"Decision for '{question}': {decision}")
            return should_search
        except Exception as e:
            if self.debug:
                print(f"Error in search decision: {e}, defaulting to no search for attachment questions")
            # Default to no search if decision fails and there are attachments
            return len(attachment_context) == 0

    def _answer_with_llm(self, question: str, attachment_context: str = "") -> str:
        """
        Generate answer using LLM without search, considering attachment context.

        Returns the model's reply, or an apology string containing the error
        if the model call fails.
        """
        context_section = f"\n\nAttachment Context:\n{attachment_context}" if attachment_context else ""
        answer_prompt = f"""You are a general AI assistant. I will ask you a question.
YOUR ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. Do not add a dot after the numbers.
{context_section}
Question: {question}
Answer:"""
        try:
            return self._chat_completion(answer_prompt, max_tokens=500, temperature=0.3)
        except Exception as e:
            return f"Sorry, I encountered an error generating the response: {e}"

    @staticmethod
    def _format_search_results(search_results) -> str:
        """Render search results as prompt text: a string as-is, otherwise the top three entries."""
        if isinstance(search_results, str):
            return search_results
        formatted_results = []
        for result in search_results[:3]:  # Use top 3 results
            if isinstance(result, dict):
                title = result.get("title", "No title")
                snippet = result.get("snippet", "").strip()
                link = result.get("link", "")
                formatted_results.append(f"Title: {title}\nContent: {snippet}\nSource: {link}")
            else:
                formatted_results.append(str(result))
        return "\n\n".join(formatted_results)

    @staticmethod
    def _raw_search_fallback(search_results) -> str:
        """Turn non-empty search results into an answer without the LLM (first hit only)."""
        if isinstance(search_results, str):
            return search_results
        if isinstance(search_results, list) and search_results:
            first_result = search_results[0]
            if isinstance(first_result, dict):
                title = first_result.get("title", "Search Result")
                snippet = first_result.get("snippet", "").strip()
                link = first_result.get("link", "")
                source_line = f"Source: {link}" if link else ""
                return f"**{title}**\n\n{snippet}\n\n{source_line}"
            return str(first_result)
        return str(search_results)

    def _answer_with_search(self, question: str, attachment_context: str = "") -> str:
        """
        Generate answer using search results and LLM, considering attachment context.

        Falls back to the LLM-only answer (with an explanatory prefix) when the
        search returns nothing or raises, and to the raw first search hit when
        the LLM call on the search results fails.
        """
        try:
            time.sleep(self.SEARCH_DELAY_SECONDS)
            search_results = self.search(question)
            if self.debug:
                print(f"Search results type: {type(search_results)}")
            if not search_results:
                return "No search results found. Let me try to answer based on my knowledge:\n\n" + self._answer_with_llm(question, attachment_context)

            search_context = self._format_search_results(search_results)
            context_section = f"\n\nAttachment Context:\n{attachment_context}" if attachment_context else ""
            answer_prompt = f"""You are a general AI assistant. I will ask you a question.
Based on the search results and the context section below, provide an answer to the question.
If the search results don't fully answer the question, you can supplement with your general knowledge.
Your ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
Do not add dot if your answer is a number.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
Do nor report on your thoughts.
Question: {question}
Search Results:
{search_context}
{context_section}
Answer:"""
            try:
                return self._chat_completion(answer_prompt, max_tokens=600, temperature=0.3)
            except Exception as e:
                if self.debug:
                    print(f"LLM generation error: {e}")
                # Fallback to simple search result formatting
                return self._raw_search_fallback(search_results)
        except Exception as e:
            return f"Search failed: {e}. Let me try to answer based on my knowledge:\n\n" + self._answer_with_llm(question, attachment_context)

    def _route_question(self, question: str, attachment_context: str) -> str:
        """Answer *question* via search or the LLM alone, as decided by _should_search."""
        if self._should_search(question, attachment_context):
            if self.debug:
                print("Using search-based approach")
            return self._answer_with_search(question, attachment_context)
        if self.debug:
            print("Using LLM-only approach")
        return self._answer_with_llm(question, attachment_context)

    # ------------------------------------------------------------------ entry points

    def process_question_with_attachments(self, question_data: dict) -> str:
        """
        Process a question that may have attachments.

        *question_data* is a question record with a 'question' text and,
        optionally, attachment fields. Attachments are downloaded, converted to
        text context and removed again whether or not answering succeeds.
        Errors are returned as a "Sorry, I encountered an error: ..." string.
        """
        question_text = question_data.get('question', '') or ''
        if self.debug:
            print(f"Processing question with potential attachments: {question_text[:100]}...")

        downloaded: List[str] = []
        try:
            image_files, audio_files, code_files = self._detect_and_download_attachments(question_data)
            downloaded = image_files + audio_files + code_files

            attachment_context = self._process_attachments(image_files, audio_files, code_files)
            if self.debug and attachment_context:
                print(f"Attachment context: {attachment_context[:200]}...")

            answer = self._route_question(question_text, attachment_context)
        except Exception as e:
            answer = f"Sorry, I encountered an error: {e}"
        finally:
            # Cleanup temporary files, on the error path as well.
            if downloaded:
                self._cleanup_downloads(downloaded)

        if self.debug:
            print(f"Agent returning answer: {answer[:100]}...")
        return answer

    def __call__(self, question: str, image_files: List[str] = None, audio_files: List[str] = None) -> str:
        """
        Main entry point for manual testing - process media files and generate response.

        The given files are caller-owned local paths and are not deleted.
        A blank question returns "Please provide a valid question."; other
        errors are returned as a "Sorry, I encountered an error: ..." string.
        """
        if self.debug:
            print(f"Agent received question: {question}")
            print(f"Image files: {image_files}")
            print(f"Audio files: {audio_files}")

        # Early validation
        if not question or not question.strip():
            return "Please provide a valid question."

        try:
            attachment_context = self._process_attachments(image_files, audio_files, [])
            if self.debug and attachment_context:
                print(f"Media context: {attachment_context[:200]}...")
            answer = self._route_question(question, attachment_context)
        except Exception as e:
            answer = f"Sorry, I encountered an error: {e}"

        if self.debug:
            print(f"Agent returning answer: {answer[:100]}...")
        return answer
def fetch_questions() -> Tuple[str, Optional[pd.DataFrame]]:
    """
    Fetch questions from the API and cache them.

    Stores the raw response in the module-level ``cached_questions`` and
    returns ``(status message, overview DataFrame)``. The DataFrame has one row
    per question with its task id, a preview of the text (first 100 characters)
    and a note on any attachments. On an empty response or any error the second
    element is None.
    """
    global cached_questions

    questions_url = f"{DEFAULT_API_URL}/questions"
    print(f"Fetching questions from: {questions_url}")

    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
        if not questions_data:
            return "Fetched questions list is empty.", None

        cached_questions = questions_data

        # One overview row per question.
        rows = []
        for item in questions_data:
            # Collect a short note for every place an attachment was spotted.
            notes = []
            for field in ('attachments', 'files', 'media', 'resources'):
                if field in item and item[field]:
                    value = item[field]
                    if isinstance(value, list):
                        notes.append(f"{len(value)} {field}")
                    else:
                        notes.append(f"{field}")

            question_text = item.get("question", "")
            if 'http' in question_text:
                notes.append("URLs in text")

            if len(question_text) > 100:
                preview = question_text[:100] + "..."
            else:
                preview = question_text

            rows.append({
                "Task ID": item.get("task_id", "Unknown"),
                "Question": preview,
                "Has Attachments": "Yes" if notes else "No",
                "Attachment Info": ", ".join(notes),
            })

        df = pd.DataFrame(rows)
        attachment_count = len([row for row in rows if row["Has Attachments"] == "Yes"])
        status_msg = f"Successfully fetched {len(questions_data)} questions. {attachment_count} questions have attachments. Ready to generate answers."
        return status_msg, df
    except requests.exceptions.RequestException as e:
        return f"Error fetching questions: {e}", None
    except Exception as e:
        return f"An unexpected error occurred: {e}", None
def generate_answers_async(model_name: str = "meta-llama/Llama-3.1-8B-Instruct", progress_callback=None):
    """
    Generate answers for all cached questions using the intelligent agent.

    Runs synchronously in the calling thread (start_answer_generation runs it
    on a background thread). Results go into the module-level
    ``cached_answers`` and progress into ``processing_status``. The loop stops
    early when ``processing_status["is_processing"]`` is switched off.

    Args:
        model_name: Model id handed to IntelligentAgent.
        progress_callback: Optional callable invoked as
            ``progress_callback(done, total)`` after each answered question.

    Returns:
        A message string when there are no cached questions, otherwise None.
    """
    global cached_answers, processing_status

    if not cached_questions:
        return "No questions available. Please fetch questions first."

    processing_status["is_processing"] = True
    processing_status["progress"] = 0
    processing_status["total"] = len(cached_questions)

    try:
        agent = IntelligentAgent(debug=True, model_name=model_name)
        cached_answers = {}

        for position, question_data in enumerate(cached_questions, start=1):
            # A cleared flag means the run was cancelled from outside.
            if not processing_status["is_processing"]:
                break

            task_id = question_data.get("task_id")
            question_text = question_data.get("question")
            if not task_id or question_text is None:
                continue

            try:
                # Use the method that handles attachments
                answer = agent.process_question_with_attachments(question_data)
            except Exception as e:
                answer = f"AGENT ERROR: {e}"
            cached_answers[task_id] = {"question": question_text, "answer": answer}

            processing_status["progress"] = position
            if progress_callback:
                progress_callback(position, len(cached_questions))
    except Exception as e:
        print(f"Error in generate_answers_async: {e}")
    finally:
        processing_status["is_processing"] = False
def start_answer_generation(model_choice: str):
    """
    Start the answer generation process in a separate thread.

    *model_choice* is the label picked in the UI; unknown labels fall back to
    the Llama 3.1 8B model. Returns a status message. Nothing is started when
    a run is already in progress or no questions are cached.
    """
    if processing_status["is_processing"]:
        return "Answer generation is already in progress."
    if not cached_questions:
        return "No questions available. Please fetch questions first."

    # UI label -> model id
    default_model = "meta-llama/Llama-3.1-8B-Instruct"
    selected_model = {
        "Llama 3.1 8B": "meta-llama/Llama-3.1-8B-Instruct",
        "Llama 3.3 70B": "meta-llama/Llama-3.3-70B-Instruct",
        "Mistral 7B": "mistralai/Mistral-7B-Instruct-v0.3",
    }.get(model_choice, default_model)

    # Daemon thread: it must not keep the process alive on shutdown.
    worker = threading.Thread(
        target=generate_answers_async,
        args=(selected_model,),
        daemon=True,
    )
    worker.start()
    return f"Answer generation started using {model_choice}. Check progress."
def get_generation_progress():
    """
    Get the current progress of answer generation.

    Always returns a ``(status message, table)`` pair so the result matches the
    two UI outputs this function is wired to. ``table`` is a DataFrame of the
    generated answers (question preview of 100 characters, answer preview of
    200) once generation has finished with at least one answer, and None in
    every other state.
    """
    is_processing = processing_status["is_processing"]
    progress = processing_status["progress"]

    if not is_processing and progress == 0:
        return "Not started", None

    if is_processing:
        total = processing_status["total"]
        return f"Generating answers... {progress}/{total} completed", None

    # Generation completed
    if not cached_answers:
        return "Answer generation completed but no answers were generated.", None

    display_data = [
        {
            "Task ID": task_id,
            "Question": data["question"][:100] + "..." if len(data["question"]) > 100 else data["question"],
            "Generated Answer": data["answer"][:200] + "..." if len(data["answer"]) > 200 else data["answer"],
        }
        for task_id, data in cached_answers.items()
    ]
    df = pd.DataFrame(display_data)
    status_msg = f"Answer generation completed! {len(cached_answers)} answers ready for submission."
    return status_msg, df
def submit_cached_answers(profile: gr.OAuthProfile | None):
    """
    Submit the cached answers to the evaluation API.

    Args:
        profile: The logged-in Hugging Face profile, or None when nobody is
            logged in.

    Returns:
        ``(status message, results DataFrame)``. The DataFrame lists every
        submitted task with its question and answer; it is None when nothing
        was submitted or the submission failed.
    """
    global cached_answers

    if not profile:
        return "Please log in to Hugging Face first.", None
    if not cached_answers:
        return "No cached answers available. Please generate answers first.", None

    username = profile.username
    space_id = os.getenv("SPACE_ID")
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main" if space_id else "Unknown"

    # Prepare submission payload
    answers_payload = [
        {"task_id": task_id, "submitted_answer": data["answer"]}
        for task_id, data in cached_answers.items()
    ]
    submission_data = {
        "username": username.strip(),
        "agent_code": agent_code,
        "answers": answers_payload,
    }

    # Submit to API
    submit_url = f"{DEFAULT_API_URL}/submit"
    print(f"Submitting {len(answers_payload)} answers to: {submit_url}")

    try:
        response = requests.post(submit_url, json=submission_data, timeout=60)
        response.raise_for_status()
        result_data = response.json()

        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )

        # Create results DataFrame
        results_log = [
            {
                "Task ID": task_id,
                "Question": data["question"],
                "Submitted Answer": data["answer"],
            }
            for task_id, data in cached_answers.items()
        ]
        return final_status, pd.DataFrame(results_log)
    except requests.exceptions.HTTPError as e:
        error_detail = f"Server responded with status {e.response.status_code}."
        try:
            error_json = e.response.json()
            error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
        except (ValueError, AttributeError):
            # Body is not JSON (ValueError) or not a JSON object
            # (AttributeError on .get): show the raw text instead.
            error_detail += f" Response: {e.response.text[:500]}"
        return f"Submission Failed: {error_detail}", None
    except requests.exceptions.Timeout:
        return "Submission Failed: The request timed out.", None
    except Exception as e:
        return f"Submission Failed: {e}", None
def clear_cache():
    """
    Clear all cached data.

    Rebinds the module-level answer cache, question cache and progress record
    to fresh empty values. Returns ``(status message, None)``.
    """
    global cached_answers, cached_questions, processing_status

    cached_answers, cached_questions = {}, []
    processing_status = dict(is_processing=False, progress=0, total=0)
    return "Cache cleared successfully.", None
def test_media_processing(image_files, audio_files, question):
    """
    Test the media processing functionality with uploaded files.

    *image_files* and *audio_files* are sequences of uploaded-file objects
    whose ``.name`` attribute is used as the local path (or a falsy value when
    nothing was uploaded). A missing question is replaced by a generic one.
    Returns the agent's answer, or an "Error processing media: ..." string.
    """
    prompt = question if question else "What can you tell me about the uploaded media?"
    agent = IntelligentAgent(debug=True)

    # Reduce the uploaded-file objects to plain path lists (None when absent).
    image_paths = None
    if image_files:
        image_paths = [upload.name for upload in image_files]
    audio_paths = None
    if audio_files:
        audio_paths = [upload.name for upload in audio_files]

    try:
        return agent(prompt, image_files=image_paths, audio_files=audio_paths)
    except Exception as e:
        return f"Error processing media: {e}"
# --- Enhanced Gradio Interface ---
# Three-step workflow: fetch questions -> generate answers -> submit results.
with gr.Blocks(title="Intelligent Agent with Media Processing") as demo:
    gr.Markdown("# Intelligent Agent with Conditional Search and Media Processing")
    gr.Markdown("This agent can process images and audio files, uses an LLM to decide when search is needed, optimizing for both accuracy and efficiency.")

    with gr.Row():
        gr.LoginButton()
        clear_btn = gr.Button("Clear Cache", variant="secondary")

    with gr.Tab("Step 1: Fetch Questions"):
        gr.Markdown("### Fetch Questions from API")
        fetch_btn = gr.Button("Fetch Questions", variant="primary")
        fetch_status = gr.Textbox(label="Fetch Status", lines=2, interactive=False)
        # Second output of fetch_questions / clear_cache. It has to exist
        # before the click handlers below reference it.
        questions_table = gr.DataFrame(label="Available Questions", wrap=True)
        fetch_btn.click(
            fn=fetch_questions,
            outputs=[fetch_status, questions_table]
        )

    with gr.Tab("Step 2: Generate Answers"):
        gr.Markdown("### Generate Answers with Intelligent Search Decision")
        with gr.Row():
            model_choice = gr.Dropdown(
                choices=["Llama 3.1 8B", "Llama 3.3 70B", "Mistral 7B"],
                value="Llama 3.1 8B",
                label="Select Model"
            )
            generate_btn = gr.Button("Start Answer Generation", variant="primary")
            refresh_btn = gr.Button("Refresh Progress", variant="secondary")
        generation_status = gr.Textbox(label="Generation Status", lines=2, interactive=False)
        answers_table = gr.DataFrame(label="Generated Answers", wrap=True)
        generate_btn.click(
            fn=start_answer_generation,
            inputs=[model_choice],
            outputs=generation_status
        )
        refresh_btn.click(
            fn=get_generation_progress,
            outputs=[generation_status, answers_table]
        )

    with gr.Tab("Step 3: Submit Results"):
        gr.Markdown("### Submit Generated Answers")
        submit_btn = gr.Button("Submit Answers", variant="primary")
        submit_status = gr.Textbox(label="Submission Status", lines=4, interactive=False)
        results_table = gr.DataFrame(label="Submission Results", wrap=True)
        submit_btn.click(
            fn=submit_cached_answers,
            outputs=[submit_status, results_table]
        )

    # Clear cache functionality: resets the fetch status and empties the questions table.
    clear_btn.click(
        fn=clear_cache,
        outputs=[fetch_status, questions_table]
    )

if __name__ == "__main__":
    demo.launch()