Freddolin commited on
Commit
ee62c26
·
verified ·
1 Parent(s): 987f2c6

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +38 -90
agent.py CHANGED
@@ -1,92 +1,40 @@
1
- import torch
2
- from transformers import (
3
- AutoTokenizer,
4
- AutoModelForSeq2SeqLM,
5
- pipeline,
6
- AutoProcessor,
7
- AutoModelForSpeechSeq2Seq
8
- )
9
- from duckduckgo_search import DDGS
10
- import pandas as pd
11
- import os
12
-
13
- SYSTEM_PROMPT = """
14
- You are a helpful AI assistant. Think step by step to solve the problem. If the question requires reasoning, perform it. If it refers to a search or file, use the result provided. At the end, return ONLY the final answer string. No explanations.
15
- """
16
 
17
  class GaiaAgent:
18
- def __init__(self, model_id="google/flan-t5-base"):
19
- self.tokenizer = AutoTokenizer.from_pretrained(model_id)
20
- self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
21
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
- self.model.to(self.device)
23
-
24
- # Whisper via HF
25
- self.asr_model_id = "openai/whisper-small"
26
- self.asr_processor = AutoProcessor.from_pretrained(self.asr_model_id)
27
- self.asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(self.asr_model_id).to(self.device)
28
- self.pipe = pipeline(
29
- "automatic-speech-recognition",
30
- model=self.asr_model,
31
- tokenizer=self.asr_processor.tokenizer,
32
- feature_extractor=self.asr_processor.feature_extractor,
33
- return_timestamps=False,
34
- device=0 if torch.cuda.is_available() else -1
35
- )
36
-
37
- def search(self, query: str) -> str:
38
- try:
39
- with DDGS() as ddgs:
40
- results = ddgs.text("your query", max_results=1)
41
- if results:
42
- return results[0]['body']
43
- except Exception:
44
- return ""
45
- return ""
46
-
47
- def transcribe_audio(self, file_path: str) -> str:
48
- try:
49
- result = self.pipe(file_path)
50
- return result['text']
51
- except Exception:
52
- return ""
53
-
54
- def handle_excel(self, file_path: str) -> str:
55
- try:
56
- df = pd.read_excel(file_path)
57
- df.columns = [col.lower() for col in df.columns]
58
- if 'category' in df.columns and 'sales' in df.columns:
59
- food_sales = df[df['category'].str.lower() != 'drink']['sales'].sum()
60
- return f"{food_sales:.2f}"
61
- except Exception:
62
- return ""
63
- return ""
64
-
65
- def __call__(self, question: str, files: dict = None) -> tuple[str, str]:
66
- try:
67
- context = ""
68
- if files:
69
- for filename, filepath in files.items():
70
- if filename.endswith(".mp3") or filename.endswith(".wav"):
71
- context = self.transcribe_audio(filepath)
72
- break
73
- elif filename.endswith(".xlsx"):
74
- excel_result = self.handle_excel(filepath)
75
- return excel_result.strip(), excel_result.strip()
76
- elif "http" in question.lower() or "wikipedia" in question.lower():
77
- context = self.search(question)
78
-
79
- prompt = f"{SYSTEM_PROMPT}\n\n{context}\n\nQuestion: {question.strip()}"
80
- inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
81
- outputs = self.model.generate(
82
- **inputs,
83
- max_new_tokens=128,
84
- do_sample=False,
85
- pad_token_id=self.tokenizer.pad_token_id
86
- )
87
- output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
88
- final = output_text.strip()
89
- return final, final
90
- except Exception as e:
91
- return "ERROR", f"Agent failed: {e}"
92
-
 
1
+ from transformers import pipeline
2
+ from tools.asr_tool import transcribe_audio
3
+ from tools.excel_tool import analyze_excel
4
+ from tools.search_tool import search_duckduckgo
5
+ import mimetypes
 
 
 
 
 
 
 
 
 
 
6
 
7
  class GaiaAgent:
8
+ def __init__(self):
9
+ print("Loading model...")
10
+ self.qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
11
+
12
+ def __call__(self, question: str):
13
+ trace = ""
14
+
15
+ # Handle audio
16
+ if question.lower().strip().endswith(('.mp3', '.wav')):
17
+ trace += "Audio detected. Running transcription...\n"
18
+ text = transcribe_audio(question.strip())
19
+ trace += f"Transcribed text: {text}\n"
20
+ answer = self.qa_pipeline(text, max_new_tokens=64)[0]['generated_text']
21
+ return answer.strip(), trace
22
+
23
+ # Handle Excel
24
+ if question.lower().strip().endswith(('.xls', '.xlsx')):
25
+ trace += "Excel detected. Running analysis...\n"
26
+ answer = analyze_excel(question.strip())
27
+ trace += f"Extracted value: {answer}\n"
28
+ return answer.strip(), trace
29
+
30
+ # Handle web search
31
+ if any(keyword in question.lower() for keyword in ["wikipedia", "video", "youtube", "article"]):
32
+ trace += "Performing DuckDuckGo search...\n"
33
+ summary = search_duckduckgo(question)
34
+ trace += f"Summary from search: {summary}\n"
35
+ answer = self.qa_pipeline(summary + "\n" + question, max_new_tokens=64)[0]['generated_text']
36
+ return answer.strip(), trace
37
+
38
+ trace += "General question. Using local model...\n"
39
+ answer = self.qa_pipeline(question, max_new_tokens=64)[0]['generated_text']
40
+ return answer.strip(), trace