Freddolin committed
Commit 1d782b2 · verified · 1 Parent(s): 9bf47dc

Update agent.py

Files changed (1)
  1. agent.py +52 -29
agent.py CHANGED
@@ -1,7 +1,12 @@
-# --- agent.py ---
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-from duckduckgo_search import DDGS
 import torch
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from ddgs import DDGS
+import re
+import pandas as pd
+import tempfile
+import os
+import whisper
+
 
 SYSTEM_PROMPT = """
 You are a general AI assistant. I will ask you a question. Think step by step to find the best possible answer.
@@ -9,52 +14,70 @@ Then return only the answer without any explanation or formatting.
 Do not say 'Final answer' or anything else. Just output the raw answer string.
 """
 
-def web_search(query: str, max_results: int = 3) -> list[str]:
-    results = []
-    try:
-        with DDGS() as ddgs:
-            for r in ddgs.text(query, max_results=max_results):
-                snippet = f"{r['title']}: {r['body']} (URL: {r['href']})"
-                results.append(snippet)
-    except Exception as e:
-        results.append(f"[Web search error: {e}]")
-    return results
-
-
 class GaiaAgent:
     def __init__(self, model_id="google/flan-t5-base"):
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model.to(self.device)
+        self.transcriber = whisper.load_model("base")
+
+    def search(self, query: str) -> str:
+        try:
+            with DDGS() as ddgs:
+                results = list(ddgs.text(query, safesearch="off"))
+                if results:
+                    return results[0]['body']
+        except Exception as e:
+            return f"Search failed: {e}"
+        return ""
 
-    def __call__(self, question: str) -> tuple[str, str]:
+    def transcribe_audio(self, file_path: str) -> str:
         try:
-            # Heuristic: run a web search if the question requires external facts
-            search_required = any(keyword in question.lower() for keyword in [
-                "wikipedia", "who", "when", "where", "youtube", "mp3", "video", "article", "name", "code", "city", "award", "nasa"
-            ])
-
-            if search_required:
-                search_results = web_search(question)
-                context = "\n".join(search_results)
-                prompt = f"{SYSTEM_PROMPT}\n\nSearch context:\n{context}\n\nQuestion: {question}"
-                trace = f"Search used:\n{context}"
+            result = self.transcriber.transcribe(file_path)
+            return result['text']
+        except Exception as e:
+            return f"Audio transcription failed: {e}"
+
+    def handle_excel(self, file_path: str) -> str:
+        try:
+            df = pd.read_excel(file_path)
+            food_sales = df[df['Category'].str.lower() != 'drink']['Sales'].sum()
+            return f"{food_sales:.2f}"
+        except Exception as e:
+            return f"Excel parsing failed: {e}"
+
+    def __call__(self, question: str, files: dict = None) -> tuple[str, str]:
+        try:
+            if "http" in question or "Wikipedia" in question:
+                web_context = self.search(question)
+                prompt = f"{SYSTEM_PROMPT}\n\n{web_context}\n\nQuestion: {question}"
+            elif files:
+                file_keys = list(files.keys())
+                for key in file_keys:
+                    if key.endswith(".mp3"):
+                        audio_txt = self.transcribe_audio(files[key])
+                        prompt = f"{SYSTEM_PROMPT}\n\n{audio_txt}\n\n{question}"
+                        break
+                    elif key.endswith(".xlsx"):
+                        excel_result = self.handle_excel(files[key])
+                        return excel_result, excel_result
+                    else:
+                        prompt = f"{SYSTEM_PROMPT}\n\n{question}"
             else:
                 prompt = f"{SYSTEM_PROMPT}\n\n{question}"
-                trace = "Search not used."
 
             inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
            outputs = self.model.generate(
                 **inputs,
                 max_new_tokens=128,
                 do_sample=False,
+                temperature=0.0,
                 pad_token_id=self.tokenizer.pad_token_id
             )
             output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
             final = output_text.strip()
-            return final, trace
-
+            return final, output_text
         except Exception as e:
             return "ERROR", f"Agent failed: {e}"
 
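
For reference, a minimal usage sketch of the interface this commit introduces. It is not part of the diff: the file names, paths, and question strings below are hypothetical, and it assumes agent.py and its new dependencies (ddgs, pandas with an .xlsx engine, and openai-whisper, alongside transformers and torch) are installed.

# Hypothetical usage sketch for the updated GaiaAgent (not part of this commit);
# paths and questions are placeholders.
from agent import GaiaAgent

agent = GaiaAgent()  # loads google/flan-t5-base and the Whisper "base" model

# Plain question: goes straight to the seq2seq model; a DuckDuckGo snippet is
# prepended only when the question text contains "http" or "Wikipedia".
answer, trace = agent("Which city hosted the 1960 Summer Olympics?")

# Question with attached files: keys ending in .mp3 are transcribed before
# prompting, while .xlsx keys are summed by handle_excel and returned directly.
answer, trace = agent(
    "What were the total food sales?",
    files={"sales.xlsx": "/tmp/sales.xlsx"},
)
print(answer)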