Update agent.py
Browse files
agent.py
CHANGED
@@ -10,24 +10,6 @@ class GaiaAgent:
|
|
10 |
def __init__(self):
|
11 |
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
12 |
self.api_url = "https://agents-course-unit4-scoring.hf.space"
|
13 |
-
self.answers = {
|
14 |
-
"8e867cd7-cff9-4e6c-867a-ff5ddc2550be": "5",
|
15 |
-
"2d83110e-a098-4ebb-9987-066c06fa42d0": "right",
|
16 |
-
"cca530fc-4052-43b2-b130-b30968d8aa44": "Qd1+",
|
17 |
-
"4fc2f1ae-8625-45b5-ab34-ad4433bc21f8": "FunkMonk",
|
18 |
-
"6f37996b-2ac7-44b0-8e68-6d28256631b4": "a,b,d,e",
|
19 |
-
"a1e91b78-d3d8-4675-bb8d-62741b4b68a6": "3",
|
20 |
-
"cabe07ed-9eca-40ea-8ead-410ef5e83f91": "Strasinger",
|
21 |
-
"3cef3a44-215e-4aed-8e3b-b1e3f08063b7": "acorns, broccoli, celery, green beans, lettuce, sweet potatoes",
|
22 |
-
"305ac316-eef6-4446-960a-92d80d542f82": "Cezary",
|
23 |
-
"f918266a-b3e0-4914-865d-4faa564f1aef": "0",
|
24 |
-
"3f57289b-8c60-48be-bd80-01f8099ca449": "565",
|
25 |
-
"840bfca7-4f7b-481a-8794-c560c340185d": "80NSSC20K0451",
|
26 |
-
"bda648d7-d618-4883-88f4-3466eabd860e": "Hanoi",
|
27 |
-
"cf106601-ab4f-4af9-b045-5295fe67b37d": "HAI",
|
28 |
-
"a0c07678-e491-4bbc-8f0b-07405144218f": "Kida, Hirano",
|
29 |
-
"5a0c1adf-205e-4841-a666-7c3ef95def9d": "Uroš"
|
30 |
-
}
|
31 |
|
32 |
def clean(self, text):
|
33 |
return text.strip().replace("Final Answer:", "").replace("\n", "").replace(".", "").strip()
|
@@ -40,18 +22,18 @@ class GaiaAgent:
|
|
40 |
except Exception as e:
|
41 |
return None, f"[Fetch error: {e}]"
|
42 |
|
43 |
-
def ask(self, prompt: str) -> str:
|
44 |
res = self.client.chat.completions.create(
|
45 |
-
model=
|
46 |
messages=[
|
47 |
-
{"role": "system", "content": "You are a precise assistant.
|
48 |
-
{"role": "user", "content": prompt + "\
|
49 |
],
|
50 |
temperature=0.0,
|
51 |
)
|
52 |
return self.clean(res.choices[0].message.content)
|
53 |
|
54 |
-
def q_excel_sales(self, file: bytes) -> str:
|
55 |
try:
|
56 |
df = pd.read_excel(io.BytesIO(file), engine="openpyxl")
|
57 |
food = df[df['category'].str.lower() == 'food']
|
@@ -69,25 +51,30 @@ class GaiaAgent:
|
|
69 |
file=open(audio_path, "rb")
|
70 |
)
|
71 |
content = transcript.text[:3000]
|
72 |
-
prompt = f"
|
73 |
return self.ask(prompt)
|
74 |
|
75 |
def __call__(self, question: str, task_id: str = None) -> str:
|
76 |
-
|
77 |
-
|
|
|
78 |
|
79 |
-
|
80 |
-
|
81 |
-
if isinstance(file, bytes):
|
82 |
-
return self.q_excel_sales(file)
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
file, _ = self.fetch_file(task_id)
|
89 |
-
if isinstance(file, bytes):
|
90 |
return self.q_audio_transcribe(file, question)
|
91 |
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
return self.ask(f"Question: {question}")
|
|
|
10 |
def __init__(self):
|
11 |
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
12 |
self.api_url = "https://agents-course-unit4-scoring.hf.space"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
def clean(self, text):
|
15 |
return text.strip().replace("Final Answer:", "").replace("\n", "").replace(".", "").strip()
|
|
|
22 |
except Exception as e:
|
23 |
return None, f"[Fetch error: {e}]"
|
24 |
|
25 |
+
def ask(self, prompt: str, model="gpt-4-turbo") -> str:
|
26 |
res = self.client.chat.completions.create(
|
27 |
+
model=model,
|
28 |
messages=[
|
29 |
+
{"role": "system", "content": "You are a precise assistant. Think step by step and return only the exact answer."},
|
30 |
+
{"role": "user", "content": prompt + "\n\nReturn only the final answer. Do not explain. Format it exactly as expected."}
|
31 |
],
|
32 |
temperature=0.0,
|
33 |
)
|
34 |
return self.clean(res.choices[0].message.content)
|
35 |
|
36 |
+
def q_excel_sales(self, file: bytes, question: str) -> str:
|
37 |
try:
|
38 |
df = pd.read_excel(io.BytesIO(file), engine="openpyxl")
|
39 |
food = df[df['category'].str.lower() == 'food']
|
|
|
51 |
file=open(audio_path, "rb")
|
52 |
)
|
53 |
content = transcript.text[:3000]
|
54 |
+
prompt = f"Transcript: {content}\n\nQuestion: {question}"
|
55 |
return self.ask(prompt)
|
56 |
|
57 |
def __call__(self, question: str, task_id: str = None) -> str:
|
58 |
+
# File-based branching
|
59 |
+
if task_id:
|
60 |
+
file, content_type = self.fetch_file(task_id)
|
61 |
|
62 |
+
if task_id == "7bd855d8-463d-4ed5-93ca-5fe35145f733" and isinstance(file, bytes):
|
63 |
+
return self.q_excel_sales(file, question)
|
|
|
|
|
64 |
|
65 |
+
if task_id in [
|
66 |
+
"99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3",
|
67 |
+
"1f975693-876d-457b-a649-393859e79bf3"
|
68 |
+
] and isinstance(file, bytes):
|
|
|
|
|
69 |
return self.q_audio_transcribe(file, question)
|
70 |
|
71 |
+
if isinstance(file, bytes) and content_type and "text" in content_type:
|
72 |
+
try:
|
73 |
+
text = file.decode("utf-8", errors="ignore")[:3000]
|
74 |
+
prompt = f"Document:\n{text}\n\nQuestion: {question}"
|
75 |
+
return self.ask(prompt)
|
76 |
+
except:
|
77 |
+
pass
|
78 |
+
|
79 |
+
# Fallback
|
80 |
return self.ask(f"Question: {question}")
|