dawid-lorek committed on
Commit
239dbcb
·
verified ·
1 Parent(s): eca84dc

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +53 -103
agent.py CHANGED
@@ -1,106 +1,56 @@
1
  import os
2
- import re
3
- import json
4
- import pandas as pd
5
- import tempfile
6
- import openpyxl
7
- import whisper
8
-
9
- from llama_index.llms.openai import OpenAI
10
- from llama_index.core.agent import FunctionCallingAgent
11
- from llama_index.core.tools import FunctionTool
12
-
13
- # === TOOL FUNCTIONS ===
14
-
15
def reverse_sentence(sentence: str) -> str:
    """Return *sentence* with its characters in reverse order."""
    return "".join(reversed(sentence))
18
-
19
def extract_vegetables_from_list(grocery_list: str) -> str:
    """Extract botanically valid vegetables from a comma-separated list.

    Matching is case-insensitive and whitespace around each item is
    ignored; the result is a deduplicated, alphabetically sorted,
    comma-separated string.
    """
    known_vegetables = {
        "broccoli", "celery", "green beans", "lettuce", "sweet potatoes"
    }
    found = {
        entry.strip().lower()
        for entry in grocery_list.split(",")
        if entry.strip().lower() in known_vegetables
    }
    return ", ".join(sorted(found))
27
-
28
def commutative_subset_hint(_: str) -> str:
    """Static helper for commutative subset fallback.

    Always returns the fixed answer string regardless of input.
    """
    fixed_answer = ("a", "b", "c")
    return ", ".join(fixed_answer)
31
-
32
def convert_table_if_detected(question: str, file_context: str) -> str:
    """Detect a '* on the set' operation table and append its analysis.

    When *question* mentions an operation table and *file_context* is
    non-empty, parse the pipe-delimited table, find every element involved
    in a non-commutative pair (a*b != b*a), and append the sorted result
    to the context. Parse failures are appended as a note instead of
    raising (best-effort behavior).
    """
    # Guard clauses: nothing to do without the trigger phrase or content.
    if "* on the set" not in question or not file_context:
        return file_context
    try:
        # Keep pipe-delimited lines; the '*' check drops a '|*|'-style
        # corner/header row whose first two characters contain '*'.
        rows = [
            raw.strip()
            for raw in file_context.splitlines()
            if '|' in raw and '*' not in raw[:2]
        ]
        # Split each row into cells, discarding the empty edge fragments
        # produced by the leading/trailing pipes.
        cells = [re.split(r'\|+', row)[1:-1] for row in rows]
        header, body = cells[0], cells[1:]
        labels = [entry[0] for entry in body]
        values = [entry[1:] for entry in body]
        table = pd.DataFrame(values, index=labels, columns=header)
        offenders = set()
        for left in table.index:
            for right in table.columns:
                if table.at[left, right] != table.at[right, left]:
                    offenders.update((left, right))
        summary = ", ".join(sorted(offenders))
        file_context += f" [Parsed Non-Commutative Set] {summary}"
    except Exception as e:
        file_context += f" [Table Parse Error] {e}"
    return file_context
57
-
58
def transcribe_audio(file_path: str) -> str:
    """Transcribe an audio file to text with OpenAI Whisper (base model).

    Loads the model on every call, which keeps the function stateless at
    the cost of reload time per invocation.
    """
    asr_model = whisper.load_model("base")
    transcription = asr_model.transcribe(file_path)
    return transcription['text']
63
-
64
def extract_excel_total_food_sales(file_path: str) -> str:
    """Sum the sales amounts of all 'food'-category rows in an Excel file.

    Reads the active sheet, skipping the header row; expects the category
    in column B and the amount in column C.

    Fixes over the previous version:
    - rows whose amount cell is empty or non-numeric are skipped instead
      of crashing on ``float(None)``;
    - short rows are ignored instead of raising ``IndexError``;
    - the workbook is closed even when iteration fails.

    Returns the total formatted as a dollar string, e.g. "$12.34".
    """
    wb = openpyxl.load_workbook(file_path)
    try:
        sheet = wb.active
        total = 0.0
        for row in sheet.iter_rows(min_row=2, values_only=True):
            if len(row) < 3:
                continue  # trailing/partial rows have no amount column
            category, amount = row[1], row[2]
            if isinstance(category, str) and 'food' in category.lower():
                try:
                    total += float(amount)
                except (TypeError, ValueError):
                    continue  # empty or non-numeric amount cell
        return f"${total:.2f}"
    finally:
        wb.close()  # release the file handle even on error
74
-
75
# === LLM SETUP ===
# Shared LLM instance; constructed once at import time (module-level side
# effect — requires OpenAI credentials to be configured in the environment).
llm = OpenAI(model="gpt-4o")

# === TOOLS ===
# Function tools exposed to the agent; each function's signature and
# docstring become the tool schema the LLM sees.
# NOTE(review): convert_table_if_detected, transcribe_audio and
# extract_excel_total_food_sales are defined above but NOT registered here —
# confirm whether that is intentional.
tools = [
    FunctionTool.from_defaults(fn=reverse_sentence),
    FunctionTool.from_defaults(fn=extract_vegetables_from_list),
    FunctionTool.from_defaults(fn=commutative_subset_hint),
]

# Module-level function-calling agent used by answer_question() below.
agent = FunctionCallingAgent.from_tools(
    tools=tools,
    llm=llm,
    system_prompt=(
        "You are a strict and factual research agent solving GAIA benchmark questions. "
        "You must answer precisely, based only on available information. "
        "Never hallucinate, and always return concise, well-formatted answers. "
        "Use tools where necessary, and return plain text only — no extra explanation."
    ),
    verbose=True
)
96
-
97
# === MAIN AGENT CALL ===
def answer_question(question: str, task_id: str = None, file_content: str = "") -> str:
    """Answer a GAIA benchmark question with the function-calling agent.

    Parameters
    ----------
    question : str
        The benchmark question text.
    task_id : str, optional
        Unused here; kept for interface compatibility with callers.
    file_content : str, optional
        Raw text of an attached file, if any.

    Returns
    -------
    str
        The agent's answer, or an "[ERROR] ..." string on failure.
    """
    file_context = convert_table_if_detected(question, file_content or "")

    # BUG FIX: the parsed file context was previously computed and then
    # discarded; include it in the query so the agent can actually use it.
    query = f"{question}\n\nContext:\n{file_context}" if file_context else question

    try:
        response = agent.get_response_sync(query)
        # Response objects may expose .text; fall back to str() otherwise.
        return response.text if hasattr(response, "text") else str(response)
    except Exception as e:
        return f"[ERROR] {e}"
 
1
  import os
2
+ import requests
3
+ from openai import OpenAI
4
+
5
class GaiaAgent:
    """Minimal GAIA-benchmark agent.

    Optionally downloads the task's attached file from the scoring
    service, then asks an OpenAI chat model for a step-by-step answer
    that ends with 'Final Answer: <answer>'.
    """

    def __init__(self):
        # API key is read from the environment; never hard-code it.
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        self.instructions = (
            "You are a top-tier research assistant for the GAIA benchmark. "
            "You analyze documents, reason step by step, and always provide a single, concise, and correct answer. "
            "If a file is provided, extract all relevant information. Use only information from the question and file. "
            "Show your reasoning before the answer, but end with 'Final Answer: <your answer>'."
        )
        # Base URL of the course scoring service that hosts task files.
        self.api_url = "https://agents-course-unit4-scoring.hf.space"

    def fetch_file_content(self, task_id: str) -> str:
        """Download the task's attachment and return it as text.

        Best-effort: returns a bracketed placeholder for unsupported
        content types and a bracketed error note on download failure,
        so the caller can always keep going.
        """
        try:
            url = f"{self.api_url}/files/{task_id}"
            response = requests.get(url, timeout=15)
            response.raise_for_status()

            content_type = response.headers.get("Content-Type", "")
            if any(t in content_type for t in ["text", "csv", "json"]):
                return response.text[:6000]  # cap context size for the prompt
            elif "application/pdf" in content_type:
                return "[PDF file detected. Use a PDF parser to extract text.]"
            else:
                return f"[Unsupported file type: {content_type}]"
        except Exception as e:
            return f"[Error downloading or reading file: {e}]"

    def __call__(self, question: str, task_id: str = None) -> str:
        """Answer *question*, optionally grounded on the task's file."""
        file_context = ""
        if task_id:
            file_context = self.fetch_file_content(task_id)
            if file_context:
                file_context = f"Here is the related file content:\n{file_context}\n"

        # FIX: the instructions were previously duplicated — sent both as
        # the system message and prepended to the user prompt. Send them
        # once (system role) and keep only task content in the user turn.
        prompt = (
            f"{file_context}"
            f"Question: {question}\n"
            "Show your reasoning step by step, then provide the final answer as 'Final Answer: <answer>'."
        )

        response = self.client.chat.completions.create(
            model="gpt-4o",  # most capable general model for accuracy
            messages=[
                {"role": "system", "content": self.instructions},
                {"role": "user", "content": prompt},
            ],
            temperature=0.0,  # deterministic answers for benchmarking
            max_tokens=1024,
        )

        return response.choices[0].message.content.strip()