Freddolin committed on
Commit
058b8cf
·
verified ·
1 Parent(s): 94aaa31

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +33 -27
agent.py CHANGED
@@ -1,40 +1,46 @@
 
1
  from transformers import pipeline
2
  from tools.asr_tool import transcribe_audio
3
  from tools.excel_tool import analyze_excel
4
  from tools.search_tool import search_duckduckgo
5
- import mimetypes
6
 
7
  class GaiaAgent:
8
  def __init__(self):
9
  print("Loading model...")
10
- self.qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
 
 
 
 
11
 
12
- def __call__(self, question: str):
13
- trace = ""
 
 
 
 
 
 
 
 
 
14
 
15
- # Handle audio
16
- if question.lower().strip().endswith(('.mp3', '.wav')):
17
- trace += "Audio detected. Running transcription...\n"
18
- text = transcribe_audio(question.strip())
19
- trace += f"Transcribed text: {text}\n"
20
- answer = self.qa_pipeline(text, max_new_tokens=64)[0]['generated_text']
21
- return answer.strip(), trace
22
 
23
- # Handle Excel
24
- if question.lower().strip().endswith(('.xls', '.xlsx')):
25
- trace += "Excel detected. Running analysis...\n"
26
- answer = analyze_excel(question.strip())
27
- trace += f"Extracted value: {answer}\n"
28
- return answer.strip(), trace
 
 
 
29
 
30
- # Handle web search
31
- if any(keyword in question.lower() for keyword in ["wikipedia", "video", "youtube", "article"]):
32
- trace += "Performing DuckDuckGo search...\n"
33
- summary = search_duckduckgo(question)
34
- trace += f"Summary from search: {summary}\n"
35
- answer = self.qa_pipeline(summary + "\n" + question, max_new_tokens=64)[0]['generated_text']
36
- return answer.strip(), trace
37
 
38
- trace += "General question. Using local model...\n"
39
- answer = self.qa_pipeline(question, max_new_tokens=64)[0]['generated_text']
40
- return answer.strip(), trace
 
1
+ import mimetypes
2
  from transformers import pipeline
3
  from tools.asr_tool import transcribe_audio
4
  from tools.excel_tool import analyze_excel
5
  from tools.search_tool import search_duckduckgo
 
6
 
7
  class GaiaAgent:
8
  def __init__(self):
9
  print("Loading model...")
10
+ self.llm = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.2", max_new_tokens=512, device="cpu")
11
+
12
+ def __call__(self, question: str, files: list = None):
13
+ trace = []
14
+ context = ""
15
 
16
+ if files:
17
+ for file in files:
18
+ mime, _ = mimetypes.guess_type(file.name)
19
+ if mime and mime.startswith("audio"):
20
+ transcription = transcribe_audio(file.name)
21
+ trace.append(f"Transcribed audio: {transcription}")
22
+ context += f"\nTranscription: {transcription}"
23
+ elif mime and ("spreadsheet" in mime or file.name.endswith(".xlsx")):
24
+ result = analyze_excel(file.name)
25
+ trace.append(f"Excel analysis: {result}")
26
+ context += f"\nSpreadsheet data: {result}"
27
 
28
+ if "http" in question or "Wikipedia" in question or "YouTube" in question or "search" in question.lower():
29
+ trace.append("Performing DuckDuckGo search...")
30
+ search_result = search_duckduckgo(question)
31
+ trace.append(f"Summary from search: {search_result}")
32
+ context += f"\nSearch Result: {search_result}"
 
 
33
 
34
+ # Include the original question
35
+ prompt = f"""
36
+ Answer the question based on the context below.
37
+ Context: {context}
38
+ Question: {question}
39
+ Answer:
40
+ """
41
+ response = self.llm(prompt)[0]['generated_text'].split("Answer:")[-1].strip()
42
+ trace.append(response)
43
 
44
+ return response, "\n".join(trace)
 
 
 
 
 
 
45
 
46
+