import sys
import logging
import io
import math

import soundfile as sf

logger = logging.getLogger(__name__)


class ASRBase:

    # Character inserted between consecutive words when joining output:
    # " " for whisper_timestamped, "" for backends whose word tokens already
    # carry a leading space (e.g. faster-whisper).
    sep = " "

    def __init__(
        self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr
    ):
        self.logfile = logfile

        self.transcribe_kargs = {}
        if lan == "auto":
            self.original_language = None
        else:
            self.original_language = lan

        self.model = self.load_model(modelsize, cache_dir, model_dir)

    def load_model(self, modelsize, cache_dir, model_dir):
        raise NotImplementedError("must be implemented in the child class")

    def transcribe(self, audio, init_prompt=""):
        raise NotImplementedError("must be implemented in the child class")

    def use_vad(self):
        raise NotImplementedError("must be implemented in the child class")
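
    # Besides the three methods above, the concrete backends below also
    # implement: ts_words(result) -> list of (start, end, word) tuples,
    # segments_end_ts(result) -> list of segment end timestamps, and
    # set_translate_task() -> switch the backend from transcription to
    # translation.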


class WhisperTimestampedASR(ASRBase):
    """Uses the whisper_timestamped library as the backend. This was the first
    backend the code was tested on. It works, but it is slower than
    faster-whisper. On the other hand, its GPU installation can be easier.
    """

    sep = " "

    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
        import whisper
        import whisper_timestamped
        from whisper_timestamped import transcribe_timestamped

        self.transcribe_timestamped = transcribe_timestamped
        if model_dir is not None:
            logger.debug("ignoring model_dir, not implemented")
        return whisper.load_model(modelsize, download_root=cache_dir)

    def transcribe(self, audio, init_prompt=""):
        result = self.transcribe_timestamped(
            self.model,
            audio,
            language=self.original_language,
            initial_prompt=init_prompt,
            verbose=None,
            condition_on_previous_text=True,
            **self.transcribe_kargs,
        )
        return result

    def ts_words(self, r):
        # Convert the backend result into a list of (start, end, "word") tuples.
        o = []
        for s in r["segments"]:
            for w in s["words"]:
                t = (w["start"], w["end"], w["text"])
                o.append(t)
        return o

    def segments_end_ts(self, res):
        return [s["end"] for s in res["segments"]]

    def use_vad(self):
        self.transcribe_kargs["vad"] = True

    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"


class FasterWhisperASR(ASRBase):
    """Uses the faster-whisper library as the backend. Works much faster than
    whisper_timestamped, approximately 4 times faster in offline mode. For GPU,
    it requires an installation with a specific cuDNN version."""

    sep = ""

    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
        from faster_whisper import WhisperModel

        if model_dir is not None:
            logger.debug(
                f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used."
            )
            model_size_or_path = model_dir
        elif modelsize is not None:
            model_size_or_path = modelsize
        else:
            raise ValueError("modelsize or model_dir parameter must be set")

        model = WhisperModel(
            model_size_or_path,
            device="cuda",
            compute_type="float16",
            download_root=cache_dir,
        )
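
        # Note: device and compute_type are hard-coded for CUDA GPUs. On a
        # CPU-only machine one could instead pass device="cpu" and
        # compute_type="int8" (both standard WhisperModel parameters), at the
        # cost of considerably slower transcription.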

        return model

    def transcribe(self, audio, init_prompt=""):
        segments, info = self.model.transcribe(
            audio,
            language=self.original_language,
            initial_prompt=init_prompt,
            beam_size=5,
            word_timestamps=True,
            condition_on_previous_text=True,
            **self.transcribe_kargs,
        )
        return list(segments)

    def ts_words(self, segments):
        o = []
        for segment in segments:
            # Skip segments that are most likely non-speech.
            if segment.no_speech_prob > 0.9:
                continue
            for word in segment.words:
                o.append((word.start, word.end, word.word))
        return o

    def segments_end_ts(self, res):
        return [s.end for s in res]

    def use_vad(self):
        self.transcribe_kargs["vad_filter"] = True

    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"


class MLXWhisper(ASRBase):
    """
    Uses the MLX Whisper library as the backend, optimized for Apple Silicon.
    Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
    Significantly faster than faster-whisper (without CUDA) on Apple M1.
    """

    sep = ""

    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
        """
        Loads the MLX-compatible Whisper model.

        Args:
            modelsize (str, optional): The size or name of the Whisper model to load.
                If provided, it is translated to an MLX-compatible model path using
                the `translate_model_name` method.
                Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
            cache_dir (str, optional): Path to the directory for caching models.
                **Note**: This is not supported by MLX Whisper and will be ignored.
            model_dir (str, optional): Direct path to a custom model directory.
                If specified, it overrides the `modelsize` parameter.
        """
        from mlx_whisper.transcribe import ModelHolder, transcribe
        import mlx.core as mx

        if model_dir is not None:
            logger.debug(
                f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used."
            )
            model_size_or_path = model_dir
        elif modelsize is not None:
            model_size_or_path = self.translate_model_name(modelsize)
            logger.debug(
                f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used."
            )
        else:
            raise ValueError("modelsize or model_dir parameter must be set")

        self.model_size_or_path = model_size_or_path

        # Pre-load the model via ModelHolder (which caches it) so that the
        # first transcribe() call does not pay the model-loading latency.
        dtype = mx.float16
        ModelHolder.get_model(model_size_or_path, dtype)
        return transcribe

    def translate_model_name(self, model_name):
        """
        Translates a given model name to its corresponding MLX-compatible model path.

        Args:
            model_name (str): The name of the model to translate.

        Returns:
            str: The MLX-compatible model path.
        """
        model_mapping = {
            "tiny.en": "mlx-community/whisper-tiny.en-mlx",
            "tiny": "mlx-community/whisper-tiny-mlx",
            "base.en": "mlx-community/whisper-base.en-mlx",
            "base": "mlx-community/whisper-base-mlx",
            "small.en": "mlx-community/whisper-small.en-mlx",
            "small": "mlx-community/whisper-small-mlx",
            "medium.en": "mlx-community/whisper-medium.en-mlx",
            "medium": "mlx-community/whisper-medium-mlx",
            "large-v1": "mlx-community/whisper-large-v1-mlx",
            "large-v2": "mlx-community/whisper-large-v2-mlx",
            "large-v3": "mlx-community/whisper-large-v3-mlx",
            "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
            "large": "mlx-community/whisper-large-mlx",
        }

        mlx_model_path = model_mapping.get(model_name)
        if mlx_model_path:
            return mlx_model_path
        else:
            raise ValueError(
                f"Model name '{model_name}' is not recognized or not supported."
            )

    def transcribe(self, audio, init_prompt=""):
        if self.transcribe_kargs:
            logger.warning(
                "Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored."
            )
        segments = self.model(
            audio,
            language=self.original_language,
            initial_prompt=init_prompt,
            word_timestamps=True,
            condition_on_previous_text=True,
            path_or_hf_repo=self.model_size_or_path,
        )
        return segments.get("segments", [])

    def ts_words(self, segments):
        """
        Extracts timestamped words from the transcription segments, skipping
        words that belong to segments with a high no-speech probability.
        """
        return [
            (word["start"], word["end"], word["word"])
            for segment in segments
            for word in segment.get("words", [])
            if segment.get("no_speech_prob", 0) <= 0.9
        ]

    def segments_end_ts(self, res):
        return [s["end"] for s in res]

    def use_vad(self):
        # Note: MLX Whisper ignores this kwarg; transcribe() will log a warning.
        self.transcribe_kargs["vad_filter"] = True

    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"


class OpenaiApiASR(ASRBase):
    """Uses OpenAI's Whisper API for audio transcription."""

    def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
        self.logfile = logfile

        self.modelname = "whisper-1"
        self.original_language = None if lan == "auto" else lan
        self.response_format = "verbose_json"
        self.temperature = temperature

        self.load_model()

        self.use_vad_opt = False

        self.task = "transcribe"

    def load_model(self, *args, **kwargs):
        from openai import OpenAI

        self.client = OpenAI()

        # Cumulative seconds of audio sent to the API, used for cost logging.
        self.transcribed_seconds = 0

    def ts_words(self, segments):
        # Collect the (start, end) spans of segments that the API marked as
        # probable non-speech, then drop the words that fall inside them.
        no_speech_segments = []
        if self.use_vad_opt:
            for segment in segments.segments:
                if segment["no_speech_prob"] > 0.8:
                    no_speech_segments.append(
                        (segment.get("start"), segment.get("end"))
                    )

        o = []
        for word in segments.words:
            start = word.start
            end = word.end
            if any(s[0] <= start <= s[1] for s in no_speech_segments):
                continue
            o.append((start, end, word.word))
        return o

    def segments_end_ts(self, res):
        return [s.end for s in res.words]

    def transcribe(self, audio_data, prompt=None, *args, **kwargs):
        # Write the raw audio into an in-memory WAV file; the buffer's name
        # lets the API client infer the file format.
        buffer = io.BytesIO()
        buffer.name = "temp.wav"
        sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
        buffer.seek(0)

        self.transcribed_seconds += math.ceil(len(audio_data) / 16000)

        params = {
            "model": self.modelname,
            "file": buffer,
            "response_format": self.response_format,
            "temperature": self.temperature,
            "timestamp_granularities": ["word", "segment"],
        }
        if self.task != "translate" and self.original_language:
            params["language"] = self.original_language
        if prompt:
            params["prompt"] = prompt

        if self.task == "translate":
            proc = self.client.audio.translations
        else:
            proc = self.client.audio.transcriptions

        transcript = proc.create(**params)
        logger.debug(
            f"OpenAI API processed accumulated {self.transcribed_seconds} seconds"
        )

        return transcript

    def use_vad(self):
        self.use_vad_opt = True

    def set_translate_task(self):
        self.task = "translate"