Freddolin committed on
Commit
e258602
·
verified ·
1 Parent(s): dad31a4

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +23 -10
agent.py CHANGED
@@ -1,11 +1,14 @@
1
  import torch
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 
 
 
 
 
3
  from ddgs import DDGS
4
- import re
5
  import pandas as pd
6
- import tempfile
7
  import os
8
- import whisper
9
 
10
  SYSTEM_PROMPT = """
11
  You are a helpful AI assistant. Think step by step to solve the problem. If the question requires reasoning, perform it. If it refers to a search or file, use the result provided. At the end, return ONLY the final answer string. No explanations.
@@ -17,7 +20,19 @@ class GaiaAgent:
17
  self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
18
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  self.model.to(self.device)
20
- self.transcriber = whisper.load_model("base")
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def search(self, query: str) -> str:
23
  try:
@@ -25,13 +40,13 @@ class GaiaAgent:
25
  results = list(ddgs.text(query, safesearch="off"))
26
  if results:
27
  return results[0]['body']
28
- except Exception as e:
29
  return ""
30
  return ""
31
 
32
  def transcribe_audio(self, file_path: str) -> str:
33
  try:
34
- result = self.transcriber.transcribe(file_path)
35
  return result['text']
36
  except Exception:
37
  return ""
@@ -52,7 +67,7 @@ class GaiaAgent:
52
  context = ""
53
  if files:
54
  for filename, filepath in files.items():
55
- if filename.endswith(".mp3"):
56
  context = self.transcribe_audio(filepath)
57
  break
58
  elif filename.endswith(".xlsx"):
@@ -67,7 +82,6 @@ class GaiaAgent:
67
  **inputs,
68
  max_new_tokens=128,
69
  do_sample=False,
70
- temperature=0.0,
71
  pad_token_id=self.tokenizer.pad_token_id
72
  )
73
  output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -75,5 +89,4 @@ class GaiaAgent:
75
  return final, final
76
  except Exception as e:
77
  return "ERROR", f"Agent failed: {e}"
78
-
79
 
 
1
  import torch
2
+ from transformers import (
3
+ AutoTokenizer,
4
+ AutoModelForSeq2SeqLM,
5
+ pipeline,
6
+ AutoProcessor,
7
+ AutoModelForSpeechSeq2Seq
8
+ )
9
  from ddgs import DDGS
 
10
  import pandas as pd
 
11
  import os
 
12
 
13
  SYSTEM_PROMPT = """
14
  You are a helpful AI assistant. Think step by step to solve the problem. If the question requires reasoning, perform it. If it refers to a search or file, use the result provided. At the end, return ONLY the final answer string. No explanations.
 
20
  self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
21
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  self.model.to(self.device)
23
+
24
+ # Whisper via HF
25
+ self.asr_model_id = "openai/whisper-small"
26
+ self.asr_processor = AutoProcessor.from_pretrained(self.asr_model_id)
27
+ self.asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(self.asr_model_id).to(self.device)
28
+ self.pipe = pipeline(
29
+ "automatic-speech-recognition",
30
+ model=self.asr_model,
31
+ tokenizer=self.asr_processor.tokenizer,
32
+ feature_extractor=self.asr_processor.feature_extractor,
33
+ return_timestamps=False,
34
+ device=0 if torch.cuda.is_available() else -1
35
+ )
36
 
37
  def search(self, query: str) -> str:
38
  try:
 
40
  results = list(ddgs.text(query, safesearch="off"))
41
  if results:
42
  return results[0]['body']
43
+ except Exception:
44
  return ""
45
  return ""
46
 
47
  def transcribe_audio(self, file_path: str) -> str:
48
  try:
49
+ result = self.pipe(file_path)
50
  return result['text']
51
  except Exception:
52
  return ""
 
67
  context = ""
68
  if files:
69
  for filename, filepath in files.items():
70
+ if filename.endswith(".mp3") or filename.endswith(".wav"):
71
  context = self.transcribe_audio(filepath)
72
  break
73
  elif filename.endswith(".xlsx"):
 
82
  **inputs,
83
  max_new_tokens=128,
84
  do_sample=False,
 
85
  pad_token_id=self.tokenizer.pad_token_id
86
  )
87
  output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
89
  return final, final
90
  except Exception as e:
91
  return "ERROR", f"Agent failed: {e}"
 
92