dlaima commited on
Commit
86758ba
·
verified ·
1 Parent(s): 9571928

Update audio_transcriber.py

Browse files
Files changed (1) hide show
  1. audio_transcriber.py +29 -19
audio_transcriber.py CHANGED
@@ -1,35 +1,45 @@
1
  import os
2
  import requests
3
- import openai
4
  from smolagents import Tool
5
 
6
- openai.api_key = os.getenv("OPENAI_API_KEY")
7
-
8
  class AudioTranscriptionTool(Tool):
9
  name = "audio_transcriber"
10
- description = "Transcribe a given audio file in mp3 or wav format to text using Whisper."
11
  inputs = {
12
- "url": {
13
  "type": "string",
14
- "description": "URL to the audio file (.mp3 or .wav)"
15
  }
16
  }
17
  output_type = "string"
18
 
19
- def forward(self, url: str) -> str:
 
 
 
 
 
 
 
20
  try:
21
- # Download audio
22
- filename = "/tmp/audio_input.mp3"
23
- response = requests.get(url)
24
- with open(filename, "wb") as f:
25
- f.write(response.content)
26
 
27
- # Transcribe with Whisper
28
- with open(filename, "rb") as audio_file:
29
- transcript = openai.audio.transcriptions.create(
30
- model="whisper-1",
31
- file=audio_file
32
- )
33
- return transcript.text.strip()
 
 
 
 
 
 
 
 
 
34
  except Exception as e:
35
  return f"Error transcribing audio: {e}"
 
1
  import os
2
  import requests
 
3
  from smolagents import Tool
4
 
 
 
5
  class AudioTranscriptionTool(Tool):
6
  name = "audio_transcriber"
7
+ description = "Transcribe a given audio file in mp3 or wav format to text using Whisper via Hugging Face API."
8
  inputs = {
9
+ "file_path": {
10
  "type": "string",
11
+ "description": "Path to the audio file (must be .mp3 or .wav)"
12
  }
13
  }
14
  output_type = "string"
15
 
16
+ def __init__(self):
17
+ super().__init__()
18
+ self.api_url = "https://api-inference.huggingface.co/models/openai/whisper-large"
19
+ self.headers = {
20
+ "Authorization": f"Bearer {os.getenv('HF_API_TOKEN')}"
21
+ }
22
+
23
+ def forward(self, file_path: str) -> str:
24
  try:
25
+ with open(file_path, "rb") as audio_file:
26
+ audio_bytes = audio_file.read()
 
 
 
27
 
28
+ response = requests.post(
29
+ self.api_url,
30
+ headers=self.headers,
31
+ data=audio_bytes,
32
+ timeout=60
33
+ )
34
+ if response.status_code == 200:
35
+ result = response.json()
36
+ # The exact key depends on the model; usually 'text' for whisper
37
+ transcription = result.get("text", None)
38
+ if transcription:
39
+ return transcription.strip()
40
+ else:
41
+ return "Error: No transcription found in the response."
42
+ else:
43
+ return f"Error transcribing audio: {response.status_code} {response.text}"
44
  except Exception as e:
45
  return f"Error transcribing audio: {e}"