Freddolin commited on
Commit
2477b72
·
verified ·
1 Parent(s): d0e25da

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +23 -28
agent.py CHANGED
@@ -7,11 +7,8 @@ import tempfile
7
  import os
8
  import whisper
9
 
10
-
11
  SYSTEM_PROMPT = """
12
- You are a general AI assistant. I will ask you a question. Think step by step to find the best possible answer.
13
- Then return only the answer without any explanation or formatting.
14
- Do not say 'Final answer' or anything else. Just output the raw answer string.
15
  """
16
 
17
  class GaiaAgent:
@@ -29,44 +26,42 @@ class GaiaAgent:
29
  if results:
30
  return results[0]['body']
31
  except Exception as e:
32
- return f"Search failed: {e}"
33
  return ""
34
 
35
  def transcribe_audio(self, file_path: str) -> str:
36
  try:
37
  result = self.transcriber.transcribe(file_path)
38
  return result['text']
39
- except Exception as e:
40
- return f"Audio transcription failed: {e}"
41
 
42
  def handle_excel(self, file_path: str) -> str:
43
  try:
44
  df = pd.read_excel(file_path)
45
- food_sales = df[df['Category'].str.lower() != 'drink']['Sales'].sum()
46
- return f"{food_sales:.2f}"
47
- except Exception as e:
48
- return f"Excel parsing failed: {e}"
 
 
 
49
 
50
  def __call__(self, question: str, files: dict = None) -> tuple[str, str]:
51
  try:
52
- if "http" in question or "Wikipedia" in question:
53
- web_context = self.search(question)
54
- prompt = f"{SYSTEM_PROMPT}\n\n{web_context}\n\nQuestion: {question}"
55
- elif files:
56
- file_keys = list(files.keys())
57
- for key in file_keys:
58
- if key.endswith(".mp3"):
59
- audio_txt = self.transcribe_audio(files[key])
60
- prompt = f"{SYSTEM_PROMPT}\n\n{audio_txt}\n\n{question}"
61
  break
62
- elif key.endswith(".xlsx"):
63
- excel_result = self.handle_excel(files[key])
64
- return excel_result, excel_result
65
- else:
66
- prompt = f"{SYSTEM_PROMPT}\n\n{question}"
67
- else:
68
- prompt = f"{SYSTEM_PROMPT}\n\n{question}"
69
 
 
70
  inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
71
  outputs = self.model.generate(
72
  **inputs,
@@ -77,7 +72,7 @@ class GaiaAgent:
77
  )
78
  output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
79
  final = output_text.strip()
80
- return final, output_text
81
  except Exception as e:
82
  return "ERROR", f"Agent failed: {e}"
83
 
 
7
  import os
8
  import whisper
9
 
 
10
  SYSTEM_PROMPT = """
11
+ You are a helpful AI assistant. Think step by step to solve the problem. If the question requires reasoning, perform it. If it refers to a search or file, use the result provided. At the end, return ONLY the final answer string. No explanations.
 
 
12
  """
13
 
14
  class GaiaAgent:
 
26
  if results:
27
  return results[0]['body']
28
  except Exception as e:
29
+ return ""
30
  return ""
31
 
32
  def transcribe_audio(self, file_path: str) -> str:
33
  try:
34
  result = self.transcriber.transcribe(file_path)
35
  return result['text']
36
+ except Exception:
37
+ return ""
38
 
39
  def handle_excel(self, file_path: str) -> str:
40
  try:
41
  df = pd.read_excel(file_path)
42
+ df.columns = [col.lower() for col in df.columns]
43
+ if 'category' in df.columns and 'sales' in df.columns:
44
+ food_sales = df[df['category'].str.lower() != 'drink']['sales'].sum()
45
+ return f"{food_sales:.2f}"
46
+ except Exception:
47
+ return ""
48
+ return ""
49
 
50
  def __call__(self, question: str, files: dict = None) -> tuple[str, str]:
51
  try:
52
+ context = ""
53
+ if files:
54
+ for filename, filepath in files.items():
55
+ if filename.endswith(".mp3"):
56
+ context = self.transcribe_audio(filepath)
 
 
 
 
57
  break
58
+ elif filename.endswith(".xlsx"):
59
+ excel_result = self.handle_excel(filepath)
60
+ return excel_result.strip(), excel_result.strip()
61
+ elif "http" in question.lower() or "wikipedia" in question.lower():
62
+ context = self.search(question)
 
 
63
 
64
+ prompt = f"{SYSTEM_PROMPT}\n\n{context}\n\nQuestion: {question.strip()}"
65
  inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
66
  outputs = self.model.generate(
67
  **inputs,
 
72
  )
73
  output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
74
  final = output_text.strip()
75
+ return final, final
76
  except Exception as e:
77
  return "ERROR", f"Agent failed: {e}"
78