|
import mimetypes |
|
from transformers import pipeline |
|
from tools.asr_tool import transcribe_audio |
|
from tools.excel_tool import analyze_excel |
|
from tools.search_tool import search_duckduckgo |
|
|
|
class GaiaAgent:
    """Question-answering agent that enriches an LLM prompt with tool output.

    Attached audio files are transcribed, spreadsheets are analysed, and
    questions that appear to need outside information trigger a DuckDuckGo
    search.  Everything gathered is passed to a local text-generation
    pipeline as context for the final answer.
    """

    # Substrings (matched case-insensitively) that make a question trigger a
    # web search.
    SEARCH_KEYWORDS = ("http", "wikipedia", "youtube", "search")

    def __init__(
        self,
        model: str = "mistralai/Mistral-7B-Instruct-v0.2",
        device: str = "cpu",
        max_new_tokens: int = 512,
    ):
        """Load the text-generation pipeline used to produce answers.

        Args:
            model: Model id passed to ``transformers.pipeline``.
            device: Device passed to the pipeline (default ``"cpu"``).
            max_new_tokens: Generation length limit forwarded to the pipeline.
        """
        print("Loading model...")
        self.llm = pipeline(
            "text-generation",
            model=model,
            max_new_tokens=max_new_tokens,
            device=device,
        )

    def __call__(self, question: str, files: "list | None" = None):
        """Answer *question*, using any attached *files* as extra context.

        Args:
            question: The question to answer.
            files: Optional iterable of attachments.  Each item may be a path
                (``str`` / ``os.PathLike``) or an object whose ``.name``
                attribute holds the path (e.g. an uploaded temp file).

        Returns:
            A ``(response, trace)`` tuple: the model's answer text and a
            newline-joined log of the tool steps that were taken.
        """
        trace = []
        context_parts = self._gather_file_context(files, trace)

        if self._needs_search(question):
            trace.append("Performing DuckDuckGo search...")
            search_result = search_duckduckgo(question)
            trace.append(f"Summary from search: {search_result}")
            context_parts.append(f"\nSearch Result: {search_result}")

        context = "".join(context_parts)
        prompt = (
            "Answer the question based on the context below.\n\n"
            f"Context: {context}\n\n"
            f"Question: {question}\n\n"
            "Answer:"
        )

        generated = self.llm(prompt)[0]["generated_text"]
        # The pipeline's output includes the prompt text by default, so keep
        # only what follows the final "Answer:" marker.
        response = generated.split("Answer:")[-1].strip()
        trace.append(response)

        return response, "\n".join(trace)

    @staticmethod
    def _file_path(file) -> str:
        """Return the filesystem path for an attachment."""
        import os

        # Test PathLike before falling back to ``.name``: ``pathlib.Path`` also
        # has a ``.name`` attribute, but it is only the basename.
        if isinstance(file, (str, os.PathLike)):
            return os.fspath(file)
        return file.name

    @staticmethod
    def _file_kind(path: str) -> "str | None":
        """Classify *path* as ``"audio"``, ``"spreadsheet"`` or ``None``."""
        mime, _ = mimetypes.guess_type(path)
        if mime and mime.startswith("audio"):
            return "audio"
        # The extension check must not depend on ``mime``: when the platform's
        # MIME database lacks ``.xlsx``, ``guess_type`` returns None.
        if (mime and "spreadsheet" in mime) or path.lower().endswith(".xlsx"):
            return "spreadsheet"
        return None

    def _gather_file_context(self, files, trace: list) -> list:
        """Run the matching tool on each attachment.

        Appends a step description to *trace* for every file and returns the
        context fragments to include in the prompt.
        """
        parts = []
        for file in files or ():
            path = self._file_path(file)
            kind = self._file_kind(path)
            if kind == "audio":
                transcription = transcribe_audio(path)
                trace.append(f"Transcribed audio: {transcription}")
                parts.append(f"\nTranscription: {transcription}")
            elif kind == "spreadsheet":
                result = analyze_excel(path)
                trace.append(f"Excel analysis: {result}")
                parts.append(f"\nSpreadsheet data: {result}")
            else:
                # Record ignored attachments instead of dropping them silently.
                trace.append(f"Skipped unsupported file: {path}")
        return parts

    @classmethod
    def _needs_search(cls, question: str) -> bool:
        """Return True if *question* contains any search-trigger keyword."""
        lowered = question.lower()
        return any(keyword in lowered for keyword in cls.SEARCH_KEYWORDS)
|
|
|
|