|
|
|
import sys |
|
import numpy as np |
|
import librosa |
|
from functools import lru_cache |
|
import time |
|
import logging |
|
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR |
|
from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor |
|
|
|
# Logger shared by every helper in this module.
logger = logging.getLogger(__name__)


# Language codes accepted by create_tokenizer(); kept as a list of strings
# produced by splitting one comma-separated literal.
WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(",")
|
|
|
|
|
def create_tokenizer(lan):
    """Return a sentence splitter for ``lan`` with a MosesTokenizer-like ``split``.

    The backend is picked in this order:

    1. ``"uk"`` uses ``tokenize_uk``.
    2. Languages in the Moses list use ``MosesSentenceSplitter``.
    3. Everything else uses the ``wtpsplit`` WtP model; codes that wtpsplit
       does not know are passed to it as ``lang_code=None``.

    The tokenizer libraries are imported lazily so only the one actually
    needed has to be installed.

    Raises:
        AssertionError: if ``lan`` is not listed in ``WHISPER_LANG_CODES``.
    """
    assert lan in WHISPER_LANG_CODES, (
        "language must be Whisper's supported lang code: "
        + " ".join(WHISPER_LANG_CODES)
    )

    # --- Ukrainian: dedicated tokenizer -----------------------------------
    if lan == "uk":
        import tokenize_uk

        class UkrainianTokenizer:
            def split(self, text):
                return tokenize_uk.tokenize_sents(text)

        return UkrainianTokenizer()

    # --- Languages handled by the Moses sentence splitter -----------------
    moses_languages = "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
    if lan in moses_languages:
        from mosestokenizer import MosesSentenceSplitter

        return MosesSentenceSplitter(lan)

    # --- Fallback: wtpsplit ------------------------------------------------
    wtp_unsupported = "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
    if lan in wtp_unsupported:
        logger.debug(
            f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
        )
        wtp_lang_code = None
    else:
        wtp_lang_code = lan

    from wtpsplit import WtP

    wtp = WtP("wtp-canine-s-12l-no-adapters")

    class WtPtok:
        def split(self, sent):
            return wtp.split(sent, lang_code=wtp_lang_code)

    return WtPtok()
|
|
|
|
|
def backend_factory(args):
    """Instantiate the ASR backend selected by ``args`` plus an optional tokenizer.

    Reads ``args.backend``, ``args.lan``, ``args.model``,
    ``args.model_cache_dir``, ``args.model_dir``, ``args.task``,
    ``args.buffer_trimming`` and, if present, ``args.vad``.

    Returns:
        A ``(asr, tokenizer)`` tuple. ``tokenizer`` is ``None`` unless
        ``args.buffer_trimming == "sentence"``.
    """
    selected_backend = args.backend

    if selected_backend == "openai-api":
        logger.debug("Using OpenAI API.")
        asr = OpenaiApiASR(lan=args.lan)
    else:
        # Any backend name not listed here falls back to WhisperTimestampedASR.
        local_backends = {
            "faster-whisper": FasterWhisperASR,
            "mlx-whisper": MLXWhisper,
        }
        asr_cls = local_backends.get(selected_backend, WhisperTimestampedASR)

        model_size = args.model
        load_started = time.time()
        logger.info(f"Loading Whisper {model_size} model for language {args.lan}...")
        asr = asr_cls(
            modelsize=model_size,
            lan=args.lan,
            cache_dir=args.model_cache_dir,
            model_dir=args.model_dir,
        )
        load_finished = time.time()
        logger.info(f"done. It took {round(load_finished - load_started, 2)} seconds.")

    # ``vad`` is optional on the namespace, hence getattr with a default.
    if getattr(args, "vad", False):
        logger.info("Setting VAD filter")
        asr.use_vad()

    # When translating, the sentence tokenizer is created for English
    # rather than for the source language.
    source_language = args.lan
    if args.task == "translate":
        asr.set_translate_task()
        tgt_language = "en"
    else:
        tgt_language = source_language

    # A sentence tokenizer is only needed for sentence-based buffer trimming.
    tokenizer = (
        create_tokenizer(tgt_language)
        if args.buffer_trimming == "sentence"
        else None
    )
    return asr, tokenizer
|
|
|
def online_factory(args, asr, tokenizer, logfile=sys.stderr):
    """Build the online (streaming) processor that wraps ``asr``.

    ``VACOnlineASRProcessor`` is used when ``args.vac`` is truthy; it takes
    ``args.min_chunk_size`` as an extra leading argument. Otherwise a plain
    ``OnlineASRProcessor`` is created. Both receive the same ``logfile``,
    ``buffer_trimming`` and ``confidence_validation`` settings.
    """
    if args.vac:
        processor_cls = VACOnlineASRProcessor
        leading_args = (args.min_chunk_size, asr, tokenizer)
    else:
        processor_cls = OnlineASRProcessor
        leading_args = (asr, tokenizer)

    return processor_cls(
        *leading_args,
        logfile=logfile,
        buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
        confidence_validation=args.confidence_validation,
    )
|
|
|
def asr_factory(args, logfile=sys.stderr):
    """
    Build the ASR backend and its online processor from ``args``.

    Chains ``backend_factory`` and ``online_factory`` and returns
    ``(asr, online)``.
    """
    asr, tokenizer = backend_factory(args)
    return asr, online_factory(args, asr, tokenizer, logfile=logfile)
|
|
|
def _download_warmup_file(url, destination, timeout):
    """Download ``url`` to ``destination``; return True on success, else False.

    The payload is written to a sibling ``.part`` file and only renamed onto
    ``destination`` once the transfer has finished, so an interrupted download
    never leaves a truncated file that later calls would mistake for a valid
    cached sample.
    """
    import contextlib
    import os
    import socket
    import time
    import urllib.request

    partial_path = f"{destination}.{os.getpid()}.part"

    # urlretrieve() takes no timeout argument, so the process-wide socket
    # default is changed for the duration of the download and restored after.
    original_timeout = socket.getdefaulttimeout()
    socket.setdefaulttimeout(timeout)
    start_time = time.time()
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, destination)
    except OSError as e:
        # URLError, ContentTooShortError and socket.timeout are all OSError
        # subclasses; warmup is best-effort, so any of them only disables it.
        logger.warning(f"Download failed: {e}. Proceeding without warmup.")
        return False
    finally:
        socket.setdefaulttimeout(original_timeout)
        # Remove a leftover partial file (it is already gone after a
        # successful os.replace, hence the suppress).
        with contextlib.suppress(OSError):
            os.remove(partial_path)

    logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
    return True


def warmup_asr(asr, warmup_file=None, timeout=5):
    """
    Warmup the ASR model by transcribing a short audio file.

    Args:
        asr: object exposing ``transcribe(audio)``.
        warmup_file: path to an audio file. ``None`` means "use the JFK
            sample", which is downloaded into the system temp directory on
            first use. Any other falsy value (e.g. ``""``) disables warmup.
        timeout: socket timeout in seconds for the sample download.

    Returns:
        True if the model was warmed up, False if warmup was skipped or the
        audio could not be obtained or loaded. Errors raised by
        ``asr.transcribe`` itself are not caught.
    """
    import os
    import tempfile

    if warmup_file is None:
        jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
        warmup_file = os.path.join(tempfile.gettempdir(), "whisper_warmup_jfk.wav")

        # Also re-download when a previous run left an empty file behind.
        if not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
            logger.debug(f"Downloading warmup file from {jfk_url}")
            print(f"Downloading warmup file from {jfk_url}")
            if not _download_warmup_file(jfk_url, warmup_file, timeout):
                return False
    elif not warmup_file:
        # Explicitly disabled by the caller.
        return False

    if not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
        logger.warning(f"Warmup file {warmup_file} invalid or missing.")
        return False

    print(f"Warming up Whisper with {warmup_file}")
    try:
        # Imported inside the try so a missing librosa only disables warmup.
        import librosa

        audio, _ = librosa.load(warmup_file, sr=16000)
    except Exception as e:
        logger.warning(f"Failed to load audio file: {e}")
        return False

    asr.transcribe(audio)

    logger.info("Whisper is warmed up")
    return True
|
|
|
|