File size: 4,668 Bytes
e70ca94 fde864c e70ca94 1d8f1ea ed680f1 e70ca94 33e3bfc e70ca94 33e3bfc e70ca94 33e3bfc e70ca94 33e3bfc e70ca94 ed680f1 e70ca94 33e3bfc e70ca94 33e3bfc e70ca94 33e3bfc e70ca94 33e3bfc e70ca94 ed680f1 e70ca94 33e3bfc e70ca94 33e3bfc e70ca94 33e3bfc e70ca94 33e3bfc e70ca94 fde864c e70ca94 33e3bfc e70ca94 ed680f1 e70ca94 ed680f1 e70ca94 ed680f1 e70ca94 7ec5a35 e70ca94 4044d5c ed680f1 e70ca94 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# agent.py
import os
import requests
from smolagents import LiteLLMModel, CodeAgent, tool, DuckDuckGoSearchTool, SpeechToTextTool, VisitWebpageTool
import speech_recognition as sr
from pydub import AudioSegment
from PIL import Image
# Set the API endpoint (adjust if yours differs)
api_url = "https://agents-course-unit4-scoring.hf.space"
# ==== Custom tools to plug into the agent ====
@tool
def download_question_file(task_id: str, file_name: str = "", save_dir: str = ".") -> str:
    """
    Downloads the file associated with a given task ID and saves it to disk.

    Args:
        task_id (str): Unique question/task identifier.
        file_name (str): Optional file name, used when the server does not provide one.
        save_dir (str): Directory to save.

    Returns:
        str: Path to the saved file, or error.
    """
    # Imported locally so the tool body carries its own dependency.
    from email.message import Message

    url = f"{api_url}/files/{task_id}"
    try:
        resp = requests.get(url, timeout=15)
        resp.raise_for_status()
    except requests.exceptions.HTTPError as e:
        return f"HTTP error: {e.response.status_code}"
    except Exception as e:
        return f"Network error: {e}"

    # Let the stdlib header parser handle quoted, unquoted and RFC 2231
    # (filename*=...) forms instead of splitting on the literal 'filename="'.
    server_name = ""
    content_disposition = resp.headers.get("Content-Disposition", "")
    if content_disposition:
        header = Message()
        header["Content-Disposition"] = content_disposition
        server_name = header.get_filename() or ""

    # The header is remote input: keep only the final path component so the
    # file cannot be written outside save_dir.
    server_name = os.path.basename(server_name.replace("\\", "/"))
    if server_name in (".", ".."):
        server_name = ""

    # Same precedence as before: server name, then caller's name, then a default.
    filename = server_name or file_name or f"{task_id}.dat"

    os.makedirs(save_dir, exist_ok=True)
    file_path = os.path.join(save_dir, filename)
    with open(file_path, "wb") as f:
        f.write(resp.content)
    return file_path
@tool
def read_image(image_path: str) -> Image.Image:
    """
    Loads image from disk.

    Args:
        image_path (str): Path to the image file.

    Returns:
        The image.
    """
    # `Image` here is the PIL.Image module; the object returned by
    # Image.open() is an instance of the Image.Image class, which is what
    # the return annotation must name.
    return Image.open(image_path)
@tool
def audio_to_text(audio_path: str) -> str:
    """
    Converts audio (mp3/wav) to text using Google Speech Recognition.

    Args:
        audio_path (str): Path to the audio file.

    Returns:
        str: Recognized text.
    """
    # Split off the real extension rather than using str.replace(), which
    # would also rewrite ".mp3" appearing elsewhere in the path (for example
    # in a directory name). The comparison is case-insensitive so ".MP3"
    # files are converted too.
    base, ext = os.path.splitext(audio_path)
    if ext.lower() == ".mp3":
        # MP3 input is first converted to a WAV file written next to the original.
        source_file = base + ".wav"
        sound = AudioSegment.from_mp3(audio_path)
        sound.export(source_file, format="wav")
    else:
        source_file = audio_path

    recognizer = sr.Recognizer()
    with sr.AudioFile(source_file) as source:
        # Read the whole file into memory before sending it for recognition.
        audio = recognizer.record(source)
    return recognizer.recognize_google(audio)
@tool
def extract_text_from_image(image_path: str) -> str:
    """
    Extract text from image using pytesseract (OCR).

    Args:
        image_path: Path to the image file.

    Returns:
        Extracted text or error message.
    """
    try:
        # Imported lazily so a missing OCR dependency is reported as a
        # message instead of breaking import of the whole module.
        import pytesseract
        from PIL import Image

        # The context manager closes the image once OCR is done; the
        # original left it open.
        with Image.open(image_path) as image:
            return pytesseract.image_to_string(image)
    except ImportError:
        return "Error: pytesseract is not installed."
    except Exception as e:
        return f"Error extracting text from image: {str(e)}"
# ==== AGENT ====
class GaiaAgent:
    """Callable wrapper around a smolagents ``CodeAgent`` with search, file, image and audio tools."""

    def __init__(self, model=None, max_steps=8):
        """
        Build the underlying ``CodeAgent``.

        Args:
            model: Model object handed to ``CodeAgent``. When ``None``, a
                ``LiteLLMModel`` for ``gpt-4o`` is created using the
                ``OPENAI_API_KEY`` environment variable.
            max_steps: Passed through to ``CodeAgent`` as ``max_steps``.
        """
        # If no model was passed in, default to OpenAI GPT-4o (or another model)
        if model is None:
            api_key = os.getenv("OPENAI_API_KEY", "")
            model = LiteLLMModel(
                model_id="gpt-4o",  # Change to your own model if needed
                api_key=api_key,
            )
        self.gaia_agent = CodeAgent(
            model=model,
            tools=[
                DuckDuckGoSearchTool(),
                download_question_file,
                read_image,
                audio_to_text,
                extract_text_from_image,
                VisitWebpageTool(),
                SpeechToTextTool(),
            ],
            additional_authorized_imports=["pandas", "numpy", "math", "statistics", "scipy"],
            max_steps=max_steps,
        )
        # You can add extra prompt configuration here if you want.

    def __call__(self, question: str) -> str:
        """
        Run the agent on one question and return its answer as text.

        Args:
            question: The question to pass to the agent.

        Returns:
            The agent's answer converted to ``str``, or a fixed error message
            when the agent is missing or the run raises.
        """
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        if not self.gaia_agent:
            return "The agent is not properly initialized. Please check your API keys and configuration."

        try:
            # run() is not guaranteed to return a string (a numeric answer
            # comes back as a number). Converting first keeps the len() and
            # slice below from raising and turning a good answer into the
            # generic error message.
            answer = str(self.gaia_agent.run(question))
        except Exception as e:
            print(f"Error processing question: {e}")
            return "An error occurred while processing your question. Please check the agent logs for details."

        preview = f"{answer[:50]}..." if len(answer) > 50 else answer
        print(f"Agent generated answer: {preview}")
        return answer