# dawid-lorek's picture
# Update agent.py
# e70ca94 verified
# raw
# history blame
# 4.67 kB
# agent.py
import os
import requests
from smolagents import LiteLLMModel, CodeAgent, tool, DuckDuckGoSearchTool, SpeechToTextTool, VisitWebpageTool
import speech_recognition as sr
from pydub import AudioSegment
from PIL import Image
# Base URL of the API endpoint (adjust if yours differs).
# download_question_file appends "/files/{task_id}" to this value.
api_url = "https://agents-course-unit4-scoring.hf.space"
# ==== Custom tools to plug into the agent ====
@tool
def download_question_file(task_id: str, file_name: str = "", save_dir: str = ".") -> str:
    """
    Downloads the file associated with a given task ID and saves it to disk.
    The file is fetched from "{api_url}/files/{task_id}" and written inside save_dir.

    Args:
        task_id (str): Unique question/task identifier.
        file_name (str): Optional fallback file name, used when the response does not name the file.
        save_dir (str): Directory to save the file in (created if missing).

    Returns:
        str: Path to the saved file, or an error message starting with "HTTP error", "Network error" or "File error".
    """
    # Local import keeps the tool self-contained; stdlib only.
    from email.message import Message

    url = f"{api_url}/files/{task_id}"
    try:
        resp = requests.get(url, timeout=15)
        resp.raise_for_status()
    except requests.exceptions.HTTPError as e:
        return f"HTTP error: {e.response.status_code}"
    except Exception as e:
        # Any other failure while fetching is reported to the caller as text.
        return f"Network error: {e}"

    # Let the stdlib header parser extract the filename parameter: it copes
    # with quoted and unquoted values as well as additional parameters,
    # unlike splitting on the literal 'filename="'.
    header_name = ""
    content_disposition = resp.headers.get("Content-Disposition", "")
    if content_disposition:
        parsed = Message()
        parsed["Content-Disposition"] = content_disposition
        header_name = parsed.get_filename() or ""

    # Preference order: name from the response, caller-provided name, then a
    # name derived from the task id. Every candidate is reduced to its final
    # path component so that a value such as "../../x" cannot make the write
    # land outside save_dir.
    safe_name = "download.dat"
    for candidate in (header_name, file_name, f"{task_id}.dat"):
        base = os.path.basename(candidate.replace("\\", "/"))
        if base not in ("", ".", ".."):
            safe_name = base
            break

    file_path = os.path.join(save_dir, safe_name)
    try:
        os.makedirs(save_dir, exist_ok=True)
        with open(file_path, "wb") as f:
            f.write(resp.content)
    except OSError as e:
        # Report disk problems the same way as network problems: as text.
        return f"File error: {e}"
    return file_path
@tool
def read_image(image_path: str) -> Image.Image:
    """
    Loads image from disk.

    Args:
        image_path (str): Path to the image file.

    Returns:
        The image.
    """
    # `Image` is the PIL.Image module imported at file level; the object
    # returned by Image.open is an instance of the Image.Image class, which is
    # what the return annotation names.
    return Image.open(image_path)
@tool
def audio_to_text(audio_path: str) -> str:
    """
    Converts audio (mp3/wav) to text using Google Speech Recognition.
    MP3 input is first converted to a WAV file written next to the original.

    Args:
        audio_path (str): Path to the audio file.

    Returns:
        str: Recognized text.
    """
    mp3_suffix = ".mp3"
    if audio_path.lower().endswith(mp3_suffix):
        # Swap only the trailing extension. str.replace would also rewrite
        # ".mp3" occurring earlier in the path (e.g. in a directory name), and
        # a case-sensitive check would skip files named "*.MP3".
        source_file = audio_path[: -len(mp3_suffix)] + ".wav"
        sound = AudioSegment.from_mp3(audio_path)
        sound.export(source_file, format="wav")
    else:
        source_file = audio_path

    recognizer = sr.Recognizer()
    with sr.AudioFile(source_file) as source:
        # Read the whole file into memory before sending it for recognition.
        audio = recognizer.record(source)
    return recognizer.recognize_google(audio)
@tool
def extract_text_from_image(image_path: str) -> str:
    """
    Extract text from image using pytesseract (OCR).

    Args:
        image_path: Path to the image file.

    Returns:
        Extracted text or error message.
    """
    try:
        # Imported lazily so that a missing OCR dependency becomes an error
        # message for the caller instead of breaking module import.
        import pytesseract
        from PIL import Image

        # The context manager closes the underlying file once OCR is done,
        # instead of leaving the handle open until garbage collection.
        with Image.open(image_path) as image:
            return pytesseract.image_to_string(image)
    except ImportError:
        return "Error: pytesseract is not installed."
    except Exception as e:
        # Best-effort tool: any failure is returned as text, not raised.
        return f"Error extracting text from image: {str(e)}"
# ==== AGENT ====
class GaiaAgent:
    """Callable wrapper around a smolagents ``CodeAgent`` equipped with this file's tools.

    Calling an instance with a question string runs the agent and returns its
    answer as a string; failures are reported as a fixed error message
    rather than raised.
    """

    def __init__(self, model=None, max_steps=8):
        """Build the underlying ``CodeAgent``.

        Args:
            model: Model object handed to ``CodeAgent``. When ``None``, a
                ``LiteLLMModel`` for ``"gpt-4o"`` is created using the
                ``OPENAI_API_KEY`` environment variable.
            max_steps: Passed through to ``CodeAgent`` as ``max_steps``.
        """
        # If no model was passed in, default to OpenAI GPT-4o (or another model).
        if model is None:
            api_key = os.getenv("OPENAI_API_KEY", "")
            model = LiteLLMModel(
                model_id="gpt-4o",  # Change to your own model if needed
                api_key=api_key,
            )
        self.gaia_agent = CodeAgent(
            model=model,
            tools=[
                DuckDuckGoSearchTool(),
                download_question_file,
                read_image,
                audio_to_text,
                extract_text_from_image,
                VisitWebpageTool(),
                SpeechToTextTool(),
            ],
            additional_authorized_imports=["pandas", "numpy", "math", "statistics", "scipy"],
            max_steps=max_steps,
        )
        # Additional prompt configuration can be added here if desired.

    def __call__(self, question: str) -> str:
        """Run the agent on ``question`` and return its answer as text.

        Args:
            question: The question to answer.

        Returns:
            The agent's answer converted to ``str``, or a fixed message when
            the agent is not set or raised while running.
        """
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        if not self.gaia_agent:
            return "The agent is not properly initialized. Please check your API keys and configuration."
        try:
            # The agent may return a non-string value (for example a number).
            # Normalise to str first: slicing or len() on such a value would
            # raise and the real answer would be replaced by the error message.
            answer = str(self.gaia_agent.run(question))
        except Exception as e:
            print(f"Error processing question: {e}")
            return "An error occurred while processing your question. Please check the agent logs for details."
        # Log at most the first 50 characters of the answer.
        preview = f"{answer[:50]}..." if len(answer) > 50 else answer
        print(f"Agent generated answer: {preview}")
        return answer