Freddolin commited on
Commit
24816c1
·
verified ·
1 Parent(s): 0af2fce

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +5 -4
agent.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import mimetypes
3
  from transformers import pipeline
4
  from tools.asr_tool import transcribe_audio
5
  from tools.excel_tool import analyze_excel
@@ -14,7 +13,7 @@ class GaiaAgent:
14
  self.llm = pipeline(
15
  "text-generation",
16
  model="mistralai/Mistral-7B-Instruct-v0.2",
17
- token=token,
18
  device="cpu",
19
  max_new_tokens=512,
20
  do_sample=False
@@ -22,29 +21,31 @@ class GaiaAgent:
22
 
23
  def __call__(self, question: str):
24
  trace_log = ""
25
-
26
  if "http" in question or "www." in question:
27
  trace_log += "Detected URL or web reference. Performing search...\n"
28
  result = search_duckduckgo(question)
29
  trace_log += f"Search result: {result}\n"
30
  return result, trace_log
31
 
 
32
  if question.lower().endswith(".mp3") or question.lower().endswith(".wav"):
33
  trace_log += "Detected audio file. Performing transcription...\n"
34
  result = transcribe_audio(question)
35
  trace_log += f"Transcription result: {result}\n"
36
  return result, trace_log
37
 
 
38
  if question.lower().endswith(".xlsx") or question.lower().endswith(".xls"):
39
  trace_log += "Detected Excel file. Performing analysis...\n"
40
  result = analyze_excel(question)
41
  trace_log += f"Excel analysis result: {result}\n"
42
  return result, trace_log
43
 
 
44
  trace_log += "General question. Using local model...\n"
45
  response = self.llm(question)[0]["generated_text"]
46
  trace_log += f"LLM response: {response}\n"
47
-
48
  return response.strip(), trace_log
49
 
50
 
 
1
  import os
 
2
  from transformers import pipeline
3
  from tools.asr_tool import transcribe_audio
4
  from tools.excel_tool import analyze_excel
 
13
  self.llm = pipeline(
14
  "text-generation",
15
  model="mistralai/Mistral-7B-Instruct-v0.2",
16
+ use_auth_token=token,
17
  device="cpu",
18
  max_new_tokens=512,
19
  do_sample=False
 
21
 
22
  def __call__(self, question: str):
23
  trace_log = ""
24
+ # Search tool
25
  if "http" in question or "www." in question:
26
  trace_log += "Detected URL or web reference. Performing search...\n"
27
  result = search_duckduckgo(question)
28
  trace_log += f"Search result: {result}\n"
29
  return result, trace_log
30
 
31
+ # Audio tool
32
  if question.lower().endswith(".mp3") or question.lower().endswith(".wav"):
33
  trace_log += "Detected audio file. Performing transcription...\n"
34
  result = transcribe_audio(question)
35
  trace_log += f"Transcription result: {result}\n"
36
  return result, trace_log
37
 
38
+ # Excel tool
39
  if question.lower().endswith(".xlsx") or question.lower().endswith(".xls"):
40
  trace_log += "Detected Excel file. Performing analysis...\n"
41
  result = analyze_excel(question)
42
  trace_log += f"Excel analysis result: {result}\n"
43
  return result, trace_log
44
 
45
+ # LLM fallback
46
  trace_log += "General question. Using local model...\n"
47
  response = self.llm(question)[0]["generated_text"]
48
  trace_log += f"LLM response: {response}\n"
 
49
  return response.strip(), trace_log
50
 
51