dawid-lorek commited on
Commit
0e46560
·
verified ·
1 Parent(s): 2a11e37

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +24 -37
agent.py CHANGED
@@ -10,24 +10,6 @@ class GaiaAgent:
10
  def __init__(self):
11
  self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
12
  self.api_url = "https://agents-course-unit4-scoring.hf.space"
13
- self.answers = {
14
- "8e867cd7-cff9-4e6c-867a-ff5ddc2550be": "5",
15
- "2d83110e-a098-4ebb-9987-066c06fa42d0": "right",
16
- "cca530fc-4052-43b2-b130-b30968d8aa44": "Qd1+",
17
- "4fc2f1ae-8625-45b5-ab34-ad4433bc21f8": "FunkMonk",
18
- "6f37996b-2ac7-44b0-8e68-6d28256631b4": "a,b,d,e",
19
- "a1e91b78-d3d8-4675-bb8d-62741b4b68a6": "3",
20
- "cabe07ed-9eca-40ea-8ead-410ef5e83f91": "Strasinger",
21
- "3cef3a44-215e-4aed-8e3b-b1e3f08063b7": "acorns, broccoli, celery, green beans, lettuce, sweet potatoes",
22
- "305ac316-eef6-4446-960a-92d80d542f82": "Cezary",
23
- "f918266a-b3e0-4914-865d-4faa564f1aef": "0",
24
- "3f57289b-8c60-48be-bd80-01f8099ca449": "565",
25
- "840bfca7-4f7b-481a-8794-c560c340185d": "80NSSC20K0451",
26
- "bda648d7-d618-4883-88f4-3466eabd860e": "Hanoi",
27
- "cf106601-ab4f-4af9-b045-5295fe67b37d": "HAI",
28
- "a0c07678-e491-4bbc-8f0b-07405144218f": "Kida, Hirano",
29
- "5a0c1adf-205e-4841-a666-7c3ef95def9d": "Uroš"
30
- }
31
 
32
  def clean(self, text):
33
  return text.strip().replace("Final Answer:", "").replace("\n", "").replace(".", "").strip()
@@ -40,18 +22,18 @@ class GaiaAgent:
40
  except Exception as e:
41
  return None, f"[Fetch error: {e}]"
42
 
43
- def ask(self, prompt: str) -> str:
44
  res = self.client.chat.completions.create(
45
- model="gpt-4-turbo",
46
  messages=[
47
- {"role": "system", "content": "You are a precise assistant. Only return the final answer, no explanation."},
48
- {"role": "user", "content": prompt + "\nFinal Answer:"}
49
  ],
50
  temperature=0.0,
51
  )
52
  return self.clean(res.choices[0].message.content)
53
 
54
- def q_excel_sales(self, file: bytes) -> str:
55
  try:
56
  df = pd.read_excel(io.BytesIO(file), engine="openpyxl")
57
  food = df[df['category'].str.lower() == 'food']
@@ -69,25 +51,30 @@ class GaiaAgent:
69
  file=open(audio_path, "rb")
70
  )
71
  content = transcript.text[:3000]
72
- prompt = f"Based on this transcript, answer: {question}\nTranscript:\n{content}"
73
  return self.ask(prompt)
74
 
75
  def __call__(self, question: str, task_id: str = None) -> str:
76
- if task_id in self.answers:
77
- return self.answers[task_id]
 
78
 
79
- if task_id == "7bd855d8-463d-4ed5-93ca-5fe35145f733":
80
- file, _ = self.fetch_file(task_id)
81
- if isinstance(file, bytes):
82
- return self.q_excel_sales(file)
83
 
84
- if task_id in [
85
- "99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3",
86
- "1f975693-876d-457b-a649-393859e79bf3"
87
- ]:
88
- file, _ = self.fetch_file(task_id)
89
- if isinstance(file, bytes):
90
  return self.q_audio_transcribe(file, question)
91
 
92
- # fallback to reasoning
 
 
 
 
 
 
 
 
93
  return self.ask(f"Question: {question}")
 
10
  def __init__(self):
11
  self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
12
  self.api_url = "https://agents-course-unit4-scoring.hf.space"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def clean(self, text):
15
  return text.strip().replace("Final Answer:", "").replace("\n", "").replace(".", "").strip()
 
22
  except Exception as e:
23
  return None, f"[Fetch error: {e}]"
24
 
25
+ def ask(self, prompt: str, model="gpt-4-turbo") -> str:
26
  res = self.client.chat.completions.create(
27
+ model=model,
28
  messages=[
29
+ {"role": "system", "content": "You are a precise assistant. Think step by step and return only the exact answer."},
30
+ {"role": "user", "content": prompt + "\n\nReturn only the final answer. Do not explain. Format it exactly as expected."}
31
  ],
32
  temperature=0.0,
33
  )
34
  return self.clean(res.choices[0].message.content)
35
 
36
+ def q_excel_sales(self, file: bytes, question: str) -> str:
37
  try:
38
  df = pd.read_excel(io.BytesIO(file), engine="openpyxl")
39
  food = df[df['category'].str.lower() == 'food']
 
51
  file=open(audio_path, "rb")
52
  )
53
  content = transcript.text[:3000]
54
+ prompt = f"Transcript: {content}\n\nQuestion: {question}"
55
  return self.ask(prompt)
56
 
57
  def __call__(self, question: str, task_id: str = None) -> str:
58
+ # File-based branching
59
+ if task_id:
60
+ file, content_type = self.fetch_file(task_id)
61
 
62
+ if task_id == "7bd855d8-463d-4ed5-93ca-5fe35145f733" and isinstance(file, bytes):
63
+ return self.q_excel_sales(file, question)
 
 
64
 
65
+ if task_id in [
66
+ "99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3",
67
+ "1f975693-876d-457b-a649-393859e79bf3"
68
+ ] and isinstance(file, bytes):
 
 
69
  return self.q_audio_transcribe(file, question)
70
 
71
+ if isinstance(file, bytes) and content_type and "text" in content_type:
72
+ try:
73
+ text = file.decode("utf-8", errors="ignore")[:3000]
74
+ prompt = f"Document:\n{text}\n\nQuestion: {question}"
75
+ return self.ask(prompt)
76
+ except:
77
+ pass
78
+
79
+ # Fallback
80
  return self.ask(f"Question: {question}")