import runpod
import base64
import faster_whisper
import tempfile
import logging
import torch
import sys
import requests
import os
import whisper_online
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler(sys.stdout)])

# Try to import the streaming wrapper module
try:
    logging.info("Attempting to import whisper_online.")
    from whisper_online import *
    logging.info("Successfully imported whisper_online.")
except ImportError as e:
    logging.error(f"Failed to import whisper_online: {e}", exc_info=True)
except Exception as e:
    logging.error(f"Unexpected error while importing whisper_online: {e}", exc_info=True)

if torch.cuda.is_available():
    logging.info("CUDA is available.")
else:
    logging.info("CUDA is not available. Using CPU.")

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name = 'ivrit-ai/faster-whisper-v2-d3-e3'
logging.info(f"Selected model name: {model_name}")
#model = faster_whisper.WhisperModel(model_name, device=device)
try:
    lan = 'he'
    logging.info(f"Attempting to initialize FasterWhisperASR on device: {device}")
    cache_dir = os.environ.get('XDG_CACHE_HOME', tempfile.gettempdir())
    logging.info(f"Using cache directory: {cache_dir}")
    model = whisper_online.FasterWhisperASR(lan=lan, modelsize=model_name, cache_dir=cache_dir, model_dir=None)
    logging.info("FasterWhisperASR model initialized successfully.")
except Exception as e:
    logging.error(f"Failed to initialize FasterWhisperASR model: {e}", exc_info=True)

# Maximum data size: 200MB
MAX_PAYLOAD_SIZE = 200 * 1024 * 1024

def download_file(url, max_size_bytes, output_filename, api_key=None):
    """
    Download a file from a given URL with a size limit and optional API key.

    Args:
        url (str): The URL of the file to download.
        max_size_bytes (int): Maximum allowed file size in bytes.
        output_filename (str): The name of the file to save the download as.
        api_key (str, optional): API key to be used as a bearer token.

    Returns:
        bool: True if the download was successful, False otherwise.
    """
    try:
        # Prepare headers
        headers = {}
        if api_key:
            headers['Authorization'] = f'Bearer {api_key}'

        # Send a GET request
        response = requests.get(url, stream=True, headers=headers)
        response.raise_for_status()  # Raises an HTTPError for bad responses

        # Check the declared file size, if the server provides one
        file_size = int(response.headers.get('Content-Length', 0))
        if file_size > max_size_bytes:
            print(f"File size ({file_size} bytes) exceeds the maximum allowed size ({max_size_bytes} bytes).")
            return False

        # Download and write the file, enforcing the size limit as we go
        downloaded_size = 0
        with open(output_filename, 'wb') as file:
            for chunk in response.iter_content(chunk_size=8192):
                downloaded_size += len(chunk)
                if downloaded_size > max_size_bytes:
                    print(f"Download stopped: size limit exceeded ({max_size_bytes} bytes).")
                    return False
                file.write(chunk)

        print(f"File downloaded successfully: {output_filename}")
        return True
    except requests.RequestException as e:
        print(f"Error downloading file: {e}")
        return False
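
# Illustrative usage of download_file (a sketch only; the URL and output path
# below are hypothetical placeholders, not values used by this worker):
#
#   ok = download_file(
#       url='https://example.com/audio.mp3',
#       max_size_bytes=MAX_PAYLOAD_SIZE,
#       output_filename='/tmp/audio.mp3',
#       api_key=None,  # or a bearer token string
#   )
#   if not ok:
#       print('Download failed or exceeded the size limit.')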

def transcribe(job):
    datatype = job['input'].get('type', None)
    if not datatype:
        return {"error": "datatype field not provided. Should be 'blob' or 'url'."}
    if datatype not in ['blob', 'url']:
        return {"error": f"datatype should be 'blob' or 'url', but is {datatype} instead."}

    # Get the API key from the job input
    api_key = job['input'].get('api_key', None)

    with tempfile.TemporaryDirectory() as d:
        audio_file = f'{d}/audio.mp3'

        if datatype == 'blob':
            mp3_bytes = base64.b64decode(job['input']['data'])
            with open(audio_file, 'wb') as f:
                f.write(mp3_bytes)
        elif datatype == 'url':
            success = download_file(job['input']['url'], MAX_PAYLOAD_SIZE, audio_file, api_key)
            if not success:
                return {"error": f"Error downloading data from {job['input']['url']}"}

        result = transcribe_core(audio_file)
        return {'result': result}
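
# Example job payloads accepted by transcribe() above and transcribe_whisper()
# below (illustrative only; the data and URL values are placeholders):
#
#   {'input': {'type': 'blob', 'data': '<base64-encoded mp3 bytes>', 'api_key': None}}
#   {'input': {'type': 'url', 'url': 'https://example.com/audio.mp3', 'api_key': '<optional bearer token>'}}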

def transcribe_core(audio_file):
    print('Transcribing...')

    ret = {'segments': []}

    # Note: this call expects a faster_whisper.WhisperModel-style API (see the
    # commented-out initialization above); whisper_online.FasterWhisperASR
    # exposes a different transcribe() signature, used in transcribe_core_whisper.
    segs, dummy = model.transcribe(audio_file, language='he', word_timestamps=True)

    for s in segs:
        words = []
        for w in s.words:
            words.append({'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability})

        seg = {'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text,
               'avg_logprob': s.avg_logprob, 'compression_ratio': s.compression_ratio,
               'no_speech_prob': s.no_speech_prob, 'words': words}
        print(seg)
        ret['segments'].append(seg)

    return ret
#runpod.serverless.start({"handler": transcribe})

def transcribe_whisper(job):
    logging.info("in transcribe_whisper")
    datatype = job['input'].get('type', None)
    if not datatype:
        return {"error": "datatype field not provided. Should be 'blob' or 'url'."}
    if datatype not in ['blob', 'url']:
        return {"error": f"datatype should be 'blob' or 'url', but is {datatype} instead."}

    # Get the API key from the job input
    api_key = job['input'].get('api_key', None)

    with tempfile.TemporaryDirectory() as d:
        audio_file = f'{d}/audio.mp3'

        if datatype == 'blob':
            mp3_bytes = base64.b64decode(job['input']['data'])
            with open(audio_file, 'wb') as f:
                f.write(mp3_bytes)
        elif datatype == 'url':
            success = download_file(job['input']['url'], MAX_PAYLOAD_SIZE, audio_file, api_key)
            if not success:
                return {"error": f"Error downloading data from {job['input']['url']}"}

        logging.info("Starting transcription process using transcribe_core_whisper.")
        result = transcribe_core_whisper(audio_file)
        logging.info("DONE: transcribe_whisper")
        return {'result': result}

def transcribe_core_whisper(audio_file):
    print('Transcribing...')

    ret = {'segments': []}

    try:
        logging.debug(f"Transcribing audio file: {audio_file}")
        # whisper_online.FasterWhisperASR.transcribe() returns the segments directly.
        segs = model.transcribe(audio_file, init_prompt="")
        logging.info("Transcription completed successfully.")

        for s in segs:
            words = []
            for w in s.words:
                words.append({'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability})

            seg = {'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text,
                   'avg_logprob': s.avg_logprob, 'compression_ratio': s.compression_ratio,
                   'no_speech_prob': s.no_speech_prob, 'words': words}
            print(seg)
            ret['segments'].append(seg)

        logging.debug(f"All segments processed. Final transcription result: {ret}")
    except Exception as e:
        # Log any exception that occurs during the transcription process
        logging.error(f"Error during transcribe_core_whisper: {e}", exc_info=True)
        return {"error": str(e)}

    # Return the final result
    logging.info("Transcription core function completed.")
    return ret

#runpod.serverless.start({"handler": transcribe_whisper})
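
# Minimal local smoke test (an illustrative sketch, not part of the RunPod flow;
# 'sample.mp3' is a hypothetical placeholder for a real local audio file):
if __name__ == '__main__':
    sample_path = 'sample.mp3'
    if os.path.exists(sample_path):
        with open(sample_path, 'rb') as f:
            encoded = base64.b64encode(f.read()).decode('ascii')
        job = {'input': {'type': 'blob', 'data': encoded}}
        print(transcribe_whisper(job))
    else:
        logging.info("No local sample file found; skipping smoke test.")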