Freddolin committed on
Commit
13755f8
·
verified ·
1 Parent(s): f931a9d

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +34 -22
agent.py CHANGED
@@ -1,38 +1,50 @@
 
 
1
  from transformers import pipeline
2
- from tools.search_tool import search_duckduckgo
3
  from tools.asr_tool import transcribe_audio
4
  from tools.excel_tool import analyze_excel
 
5
 
6
class GaiaAgent:
    """Answer a question by routing it to a search, audio, spreadsheet, or LLM handler.

    Routing is decided purely from the text of the question:

    * contains ``http`` or ``wikipedia``  -> ``search_duckduckgo``
    * ends with ``.mp3``                  -> ``transcribe_audio``
    * ends with ``.xlsx`` or ``.csv``     -> ``analyze_excel``
    * anything else                       -> the local text-generation pipeline
    """

    def __init__(self):
        """Load the local Mistral instruct model as a CPU text-generation pipeline."""
        self.llm = pipeline(
            "text-generation",
            model="mistralai/Mistral-7B-Instruct-v0.2",
            device="cpu",
        )

    def __call__(self, question: str):
        """Answer ``question`` and return ``(answer, trace)``.

        Args:
            question: The question text, or a path/URL that the routing rules
                above recognise.

        Returns:
            A tuple ``(answer, trace)`` where ``answer`` is the stripped handler
            output and ``trace`` is a human-readable log of the route taken.
        """
        trace = ""
        # Normalise once: case-insensitive matching, and surrounding whitespace
        # must not defeat the file-extension checks.
        question_lower = question.strip().lower()

        if "http" in question_lower or "wikipedia" in question_lower:
            trace += "Detected URL or web reference. Performing search...\n"
            # NOTE(review): the tool helpers are assumed to return str (the
            # result is concatenated and stripped below) - confirm in tools/.
            result = search_duckduckgo(question)
            trace += result
            return result.strip(), trace

        if question_lower.endswith(".mp3"):
            trace += "Detected audio reference. Transcribing audio...\n"
            result = transcribe_audio(question)
            trace += result
            return result.strip(), trace

        if question_lower.endswith((".xlsx", ".csv")):
            trace += "Detected Excel file. Analyzing...\n"
            result = analyze_excel(question)
            trace += result
            return result.strip(), trace

        trace += "General question. Using local model...\n"
        output = self.llm(question, max_new_tokens=256, do_sample=False)[0]["generated_text"]
        # A text-generation pipeline returns the prompt followed by the
        # completion by default; drop the echoed prompt so only the model's
        # answer is returned.
        if output.startswith(question):
            output = output[len(question):]
        trace += output
        return output.strip(), trace
37
 
38
 
 
1
+ import os
2
+ import mimetypes
3
  from transformers import pipeline
 
4
  from tools.asr_tool import transcribe_audio
5
  from tools.excel_tool import analyze_excel
6
+ from tools.search_tool import search_duckduckgo
7
 
8
class GaiaAgent:
    """Answer a question by routing it to a search, audio, spreadsheet, or LLM handler.

    Routing is decided purely from the text of the question:

    * contains ``http`` or ``www.``       -> ``search_duckduckgo``
    * ends with ``.mp3`` or ``.wav``      -> ``transcribe_audio``
    * ends with ``.xlsx`` or ``.xls``     -> ``analyze_excel``
    * anything else                       -> the local text-generation pipeline
    """

    # Suffix tuples are passed straight to ``str.endswith``.
    _AUDIO_SUFFIXES = (".mp3", ".wav")
    _EXCEL_SUFFIXES = (".xlsx", ".xls")

    def __init__(self):
        """Load the Mistral instruct model as a CPU text-generation pipeline.

        Raises:
            ValueError: If the ``HUGGINGFACEHUB_API_TOKEN`` environment
                variable is unset or empty.
        """
        token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
        if not token:
            raise ValueError("Missing HUGGINGFACEHUB_API_TOKEN environment variable.")

        self.llm = pipeline(
            "text-generation",
            model="mistralai/Mistral-7B-Instruct-v0.2",
            token=token,
            device="cpu",
            max_new_tokens=512,
            do_sample=False,
        )

    def __call__(self, question: str):
        """Answer ``question`` and return ``(answer, trace_log)``.

        Args:
            question: The question text, or a path/URL that the routing rules
                above recognise.

        Returns:
            A tuple ``(answer, trace_log)``. For the tool routes ``answer`` is
            whatever the tool returned, unmodified; for the LLM route it is the
            stripped completion. ``trace_log`` is a human-readable log of the
            route taken.
        """
        trace_log = ""
        # Normalise once so every routing check is case-insensitive and
        # surrounding whitespace cannot defeat the file-extension checks.
        normalized = question.strip().lower()

        if "http" in normalized or "www." in normalized:
            trace_log += "Detected URL or web reference. Performing search...\n"
            result = search_duckduckgo(question)
            trace_log += f"Search result: {result}\n"
            return result, trace_log

        if normalized.endswith(self._AUDIO_SUFFIXES):
            trace_log += "Detected audio file. Performing transcription...\n"
            result = transcribe_audio(question)
            trace_log += f"Transcription result: {result}\n"
            return result, trace_log

        if normalized.endswith(self._EXCEL_SUFFIXES):
            trace_log += "Detected Excel file. Performing analysis...\n"
            result = analyze_excel(question)
            trace_log += f"Excel analysis result: {result}\n"
            return result, trace_log

        trace_log += "General question. Using local model...\n"
        response = self.llm(question)[0]["generated_text"]
        # A text-generation pipeline returns the prompt followed by the
        # completion by default; drop the echoed prompt so only the model's
        # answer is logged and returned.
        if response.startswith(question):
            response = response[len(question):]
        trace_log += f"LLM response: {response}\n"

        return response.strip(), trace_log
49
 
50