Freddolin commited on
Commit
c36be6c
·
verified ·
1 Parent(s): c3c803a

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +26 -31
agent.py CHANGED
@@ -1,43 +1,38 @@
1
  from transformers import pipeline
 
2
  from tools.asr_tool import transcribe_audio
3
  from tools.excel_tool import analyze_excel
4
- from tools.search_tool import search_duckduckgo
5
- import mimetypes
6
-
7
 
8
  class GaiaAgent:
9
  def __init__(self):
10
- print("Loading model...")
11
- self.model = pipeline(
12
- "text2text-generation",
13
- model="MBZUAI/LaMini-Flan-T5-783M",
14
- tokenizer="MBZUAI/LaMini-Flan-T5-783M"
15
- )
16
- print("Model loaded.")
17
 
18
- def __call__(self, query):
19
  trace = ""
20
- final_answer = ""
21
-
22
- # Försök identifiera om det är en filreferens
23
- if isinstance(query, str) and (query.endswith(".mp3") or query.endswith(".wav")):
24
- trace = "Detected audio file. Transcribing..."
25
- final_answer = transcribe_audio(query)
26
-
27
- elif isinstance(query, str) and (query.endswith(".xls") or query.endswith(".xlsx")):
28
- trace = "Detected Excel file. Analyzing..."
29
- final_answer = analyze_excel(query)
30
-
31
- elif "http" in query:
32
- trace = "Detected URL or web reference. Performing search..."
33
- final_answer = search_duckduckgo(query)
 
 
 
 
 
34
 
35
  else:
36
- trace = "General question. Using local model..."
37
- output = self.model(query, max_new_tokens=128)
38
- final_answer = output[0]["generated_text"].strip()
39
-
40
- return final_answer, trace
41
-
42
 
43
 
 
1
  from transformers import pipeline
2
+ from tools.search_tool import search_duckduckgo
3
  from tools.asr_tool import transcribe_audio
4
  from tools.excel_tool import analyze_excel
 
 
 
5
 
6
  class GaiaAgent:
7
  def __init__(self):
8
+ self.llm = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.2", device="cpu")
 
 
 
 
 
 
9
 
10
+ def __call__(self, question: str):
11
  trace = ""
12
+ question_lower = question.lower()
13
+
14
+ if "http" in question_lower or "wikipedia" in question_lower:
15
+ trace += "Detected URL or web reference. Performing search...\n"
16
+ result = search_duckduckgo(question)
17
+ trace += result
18
+ return result.strip(), trace
19
+
20
+ elif question_lower.endswith(".mp3"):
21
+ trace += "Detected audio reference. Transcribing audio...\n"
22
+ result = transcribe_audio(question)
23
+ trace += result
24
+ return result.strip(), trace
25
+
26
+ elif question_lower.endswith(".xlsx") or question_lower.endswith(".csv"):
27
+ trace += "Detected Excel file. Analyzing...\n"
28
+ result = analyze_excel(question)
29
+ trace += result
30
+ return result.strip(), trace
31
 
32
  else:
33
+ trace += "General question. Using local model...\n"
34
+ output = self.llm(question, max_new_tokens=256, do_sample=False)[0]["generated_text"]
35
+ trace += output
36
+ return output.strip(), trace
 
 
37
 
38