Spaces:
Sleeping
Sleeping
| import logging | |
| import os | |
| from celery import Celery | |
| from typing import Union, Callable | |
| from whisper import tokenizer | |
| import tqdm | |
| from .util.audio import load_audio | |
# Configure the root logger (timestamp, logger name, level, message).
# force=True replaces any handlers installed before this module was imported.
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s',
    force=True,
)
logger = logging.getLogger(__name__)
# Monkeypatch tqdm so that progress reported through `tqdm.tqdm` (as done by
# whisper's `transcribe` function) can be redirected to a callback.
class _TQDM(tqdm.tqdm):
    """tqdm subclass that can forward progress updates to a callback.

    When a progress function has been registered with
    ``set_progress_function``, every ``update`` call on bars created
    afterwards invokes it as ``progress_function(unit, total, current)``
    instead of advancing the real bar. Without a registered function the
    call is delegated to the original ``tqdm.update``.
    """

    # The original tqdm class, captured before `tqdm.tqdm` is replaced below.
    _tqdm = tqdm.tqdm
    # Class-wide callback picked up by bars at construction time.
    progress_function = None

    def __init__(self, *argv, total=0, unit="", **kwargs):
        """Create a bar, remembering ``total`` and ``unit`` for the callback.

        Args:
            *argv: Positional arguments forwarded to ``tqdm.tqdm``.
            total: Total amount of work, reported to the callback.
            unit: Unit label, reported to the callback.
            **kwargs: Remaining keyword arguments forwarded to ``tqdm.tqdm``.
        """
        logger.debug(f"Creating TQDM with total={total}, unit={unit}")
        self._total = total
        self._unit = unit
        self._progress = 0
        # Snapshot the callback: changing the class-level function later does
        # not affect bars that already exist.
        self.progress_function = _TQDM.progress_function
        # NOTE(review): `total` and `unit` are intercepted here and not passed
        # on to tqdm, so the fallback bar is rendered without them -- confirm
        # this is intentional.
        super().__init__(*argv, **kwargs)

    @staticmethod
    def set_progress_function(progress_function: Union[Callable[[str, int, int], None], None]):
        """Register (or clear, with ``None``) the callback for new bars.

        Declared as a static method so it behaves the same whether it is
        called on the class or on an instance.
        """
        logger.debug(f"Setting progress function to {progress_function}")
        _TQDM.progress_function = progress_function

    def update(self, progress=1):
        """Advance the bar by ``progress`` units.

        The default of 1 matches ``tqdm.update(n=1)`` so that callers using
        the bare ``bar.update()`` form keep working with the patched class.

        Returns:
            ``None`` when the callback handled the update, otherwise whatever
            the original ``tqdm.update`` returns.
        """
        logger.debug(f"Updating TQDM with progress={progress}")
        self._progress += progress
        if self.progress_function is not None:
            self.progress_function(self._unit, self._total, self._progress)
            return None
        return _TQDM._tqdm.update(self, progress)


tqdm.tqdm = _TQDM
# Pick the ASR backend from the environment. Any value other than
# "faster_whisper" selects the OpenAI whisper implementation.
ASR_ENGINE = os.environ.get("ASR_ENGINE", "faster_whisper")
if ASR_ENGINE != "faster_whisper":
    from .openai_whisper.core import load_model, transcribe as whisper_transcribe
else:
    from .faster_whisper.core import load_model, transcribe as whisper_transcribe
# Language codes known to whisper's tokenizer, in sorted order.
LANGUAGE_CODES = sorted(tokenizer.LANGUAGES)

# Model name taken from ASR_MODEL, defaulting to "small".
DEFAULT_MODEL_NAME = os.environ.get("ASR_MODEL", "small")

# Task state names, keyed by the processing phase they describe.
STATES = dict(
    loading_model='LOADING_MODEL',
    encoding='ENCODING',
    transcribing='TRANSCRIBING',
)
# Celery application. Broker and result backend come from the environment and
# default to a Redis instance on localhost.
celery = Celery(__name__)
celery.conf.update(
    broker_connection_retry_on_startup=True,
    broker_url=os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379"),
    result_backend=os.environ.get("CELERY_RESULT_BACKEND", "redis://localhost:6379"),
    # Do not let the worker replace the root logger configured by this module.
    worker_hijack_root_logger=False,
    worker_redirect_stdouts_level="DEBUG",
)
def transcribe(
    self,
    audio_file_path: str,
    original_filename: str,
    asr_options: dict,
):
    """Transcribe an audio file and write the result to the job's output directory.

    ``self`` is expected to be the bound task instance: the function calls
    ``self.update_state(...)`` to publish progress and uses
    ``self.request.id`` as the job id.

    Args:
        self: Task context providing ``update_state`` and ``request.id``.
        audio_file_path: Path of the audio file to transcribe. The file is
            removed after a successful transcription.
        original_filename: Name the output file name is derived from.
        asr_options: Options forwarded to the ASR engine. Must contain
            ``"output"`` (also used as the output file extension);
            ``"model_name"`` and ``"encode"`` are optional.

    Returns:
        A dict with the keys ``output_filename``, ``output_path`` and
        ``url_path``.

    Raises:
        KeyError: If ``asr_options`` has no ``"output"`` entry.
    """
    logger.info(f"Transcribing {audio_file_path} with {asr_options}")
    output_format = asr_options["output"]

    with open(audio_file_path, "rb") as audio_file:
        # Route tqdm progress from the ASR engine into task state updates for
        # the duration of this job only.
        _TQDM.set_progress_function(update_progress(self))
        try:
            model_name = asr_options.get("model_name") or DEFAULT_MODEL_NAME
            logger.info(f"Loading model {model_name}")
            self.update_state(state=STATES["loading_model"], meta={"progress": {"units": "models", "total": 1, "current": 0}})
            load_model(model_name)

            logger.info(f"Loading audio from {audio_file_path}")
            self.update_state(state=STATES["encoding"], meta={"progress": {"units": "files", "total": 1, "current": 0}})
            audio_data = load_audio(audio_file, asr_options.get("encode", False))

            logger.info("Transcribing audio")
            self.update_state(state=STATES["transcribing"], meta={"progress": {"units": "files", "total": 1, "current": 0}})
            result = whisper_transcribe(audio_data, asr_options, output_format)
        finally:
            # Always unregister the callback, even when transcription fails.
            _TQDM.set_progress_function(None)

    logger.info("Transcription complete")
    # Remove the input only after its handle has been closed.
    os.remove(audio_file_path)

    # Drop characters that are not representable in Latin-1. The bytes must be
    # decoded with the same codec: the implicit UTF-8 default raises
    # UnicodeDecodeError for Latin-1 bytes such as 0xE9 ("é").
    safe_name = original_filename.encode("latin-1", "ignore").decode("latin-1")
    # Keep only the final path component so the name cannot point outside the
    # job's output directory.
    safe_name = os.path.basename(safe_name)
    filename = f"{safe_name}.{output_format}"

    # Read the job id from the bound task in both places rather than from the
    # module-level name `transcribe`.
    job_id = self.request.id
    output_directory = get_output_path(job_id)
    output_path = f"{output_directory}/{filename}"
    logger.info(f"Writing result to {output_path}")

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_directory, exist_ok=True)
    with open(output_path, "w") as f:
        f.write(result.read())

    url_path = f"{get_output_url_path(job_id)}/{filename}"
    return {
        "output_filename": filename,
        "output_path": output_path,
        "url_path": url_path,
    }
def get_output_path(job_id: str):
    """Return the output directory for a job.

    The base directory is taken from ``OUTPUT_DIRECTORY`` and defaults to
    ``<current working directory>/app/output``; ``job_id`` is appended as the
    last path component.
    """
    fallback_root = os.getcwd() + "/app/output"
    output_root = os.environ.get("OUTPUT_DIRECTORY", fallback_root)
    return "/".join((output_root, job_id))
def get_output_url_path(job_id: str):
    """Return the URL path for a job's output.

    The prefix is taken from ``OUTPUT_URL_PREFIX`` and defaults to
    ``/output``; ``job_id`` is appended as the last path segment.
    """
    url_prefix = os.environ.get("OUTPUT_URL_PREFIX", "/output")
    return "/".join((url_prefix, job_id))
def update_progress(context):
    """Build a progress callback bound to ``context``.

    The returned function takes ``(units, total, current)``, logs them and
    publishes them through ``context.update_state`` under the
    "transcribing" state.
    """
    def do_update(units, total, current):
        logger.info(f"Updating progress with units={units}, total={total}, current={current}")
        progress = {"units": units, "total": total, "current": current}
        context.update_state(state=STATES["transcribing"], meta={"progress": progress})

    return do_update