Spaces:
Sleeping
Sleeping
| import urllib | |
| import os | |
| from typing import List | |
| from urllib.parse import urlparse | |
| import json5 | |
| import torch | |
| from tqdm import tqdm | |
| from src.conversion.hf_converter import convert_hf_whisper | |
class ModelConfig:
    """Describes one selectable model and how to obtain a local path/name for it."""

    def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
        """
        Initialize a model configuration.
        name: Name of the model. Also used as the base file name for downloaded
            or converted copies stored under the download root directory.
        url: Model source. For "whisper" models this may be a name known to
            whisper._MODELS, a file:// URL, an http(s):// URL or a local path.
            For "huggingface" models it is passed to convert_hf_whisper.
        path: Path to the model file. If set, download_url returns it unchanged.
        type: Type of model (case-insensitive): "whisper" (alias "w") or
            "huggingface" (alias "hf"). None is treated as "whisper".
        """
        self.name = name
        self.url = url
        self.path = path
        self.type = type

    def download_url(self, root_dir: str):
        """Resolve the model to a usable path, downloading/converting if needed.

        The result is stored in ``self.path``; later calls return it directly.

        root_dir: Directory for downloaded or converted files. When None,
            ~/.cache/whisper is used. Created on demand before writing.
        Returns the resolved path (or a built-in Whisper model name).
        Raises ValueError for an unknown model type, and RuntimeError when the
        download target exists but is not a regular file.
        """
        # Already resolved (either configured explicitly or by an earlier call).
        if self.path is not None:
            return self.path

        if root_dir is None:
            root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")

        model_type = self.type.lower() if self.type is not None else "whisper"

        if model_type in ["huggingface", "hf"]:
            destination_target = os.path.join(root_dir, self.name + ".pt")
            # Convert from HuggingFace format to Whisper format once, then reuse the file.
            if os.path.exists(destination_target):
                print(f"File {destination_target} already exists, skipping conversion")
            else:
                os.makedirs(root_dir, exist_ok=True)
                print("Saving HuggingFace model in Whisper format to " + destination_target)
                convert_hf_whisper(self.url, destination_target)
            self.path = destination_target

        elif model_type in ["whisper", "w"]:
            # Imported lazily: only this branch needs whisper (to recognise built-in names).
            import whisper

            if self.url in whisper._MODELS:
                # No need to download anything - Whisper will handle it
                self.path = self.url
            elif self.url.startswith("file://"):
                self.path = urlparse(self.url).path
            elif self.url.startswith(("http://", "https://")):
                extension = self._url_extension(self.url)
                download_target = os.path.join(root_dir, self.name + extension)

                if os.path.exists(download_target) and not os.path.isfile(download_target):
                    raise RuntimeError(f"{download_target} exists and is not a regular file")

                if not os.path.isfile(download_target):
                    os.makedirs(root_dir, exist_ok=True)
                    self._download_file(self.url, download_target)
                else:
                    print(f"File {download_target} already exists, skipping download")
                self.path = download_target
            else:
                # Must be a local file
                self.path = self.url
        else:
            raise ValueError(f"Unknown model type {model_type}")

        return self.path

    @staticmethod
    def _url_extension(url: str) -> str:
        """Return the file extension of the URL's path, ignoring query and fragment."""
        return os.path.splitext(urlparse(url).path)[1]

    def _download_file(self, url: str, destination: str):
        """Stream ``url`` to ``destination`` while showing a tqdm progress bar.

        Data is written to ``destination + ".part"`` and moved into place only
        after the transfer finishes. An interrupted download therefore never
        leaves a truncated file that download_url would later treat as complete.
        """
        # `import urllib` alone does not guarantee the urllib.request submodule is loaded.
        import urllib.request

        partial = destination + ".part"
        try:
            with urllib.request.urlopen(url) as source, open(partial, "wb") as output:
                content_length = source.info().get("Content-Length")
                # Servers may omit Content-Length; tqdm accepts total=None (no known total).
                total = int(content_length) if content_length is not None else None
                with tqdm(
                    total=total,
                    ncols=80,
                    unit="iB",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as progress:
                    while True:
                        buffer = source.read(8192)
                        if not buffer:
                            break
                        output.write(buffer)
                        progress.update(len(buffer))
            os.replace(partial, destination)
        except BaseException:
            # Best-effort removal of the incomplete file, then propagate the original error.
            try:
                os.remove(partial)
            except OSError:
                pass
            raise
class ApplicationConfig:
    """Application-wide settings: available models, web UI, VAD and decoding options.

    Normally built from a JSON5 file via ``parse_file`` or ``create_default``.
    """

    def __init__(self, models: "List[ModelConfig]" = None, input_audio_max_duration: int = 600,
                 share: bool = False, server_name: str = None, server_port: int = 7860, delete_uploaded_files: bool = True,
                 default_model_name: str = "medium", default_vad: str = "silero-vad",
                 vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
                 auto_parallel: bool = False, output_dir: str = None,
                 model_dir: str = None, device: str = None,
                 verbose: bool = True, task: str = "transcribe", language: str = None,
                 vad_merge_window: float = 5, vad_max_merge_size: float = 30,
                 vad_padding: float = 1, vad_prompt_window: float = 3,
                 temperature: float = 0, best_of: int = 5, beam_size: int = 5,
                 patience: float = None, length_penalty: float = None,
                 suppress_tokens: str = "-1", initial_prompt: str = None,
                 condition_on_previous_text: bool = True, fp16: bool = True,
                 temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
                 logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6):
        """Store every setting as an attribute of the same name.

        models: Model configurations offered to the user. None (the default)
            gives each instance its own empty list; a shared mutable default
            would leak models between instances.
        device: Torch device name. When None it becomes "cuda" if
            torch.cuda.is_available() returns True, otherwise "cpu".
        All other parameters are stored unchanged.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.models = models if models is not None else []

        # WebUI settings
        self.input_audio_max_duration = input_audio_max_duration
        self.share = share
        self.server_name = server_name
        self.server_port = server_port
        self.delete_uploaded_files = delete_uploaded_files
        self.default_model_name = default_model_name
        self.default_vad = default_vad
        self.vad_parallel_devices = vad_parallel_devices
        self.vad_cpu_cores = vad_cpu_cores
        self.vad_process_timeout = vad_process_timeout
        self.auto_parallel = auto_parallel
        self.output_dir = output_dir

        self.model_dir = model_dir
        self.device = device
        self.verbose = verbose
        self.task = task
        self.language = language
        self.vad_merge_window = vad_merge_window
        self.vad_max_merge_size = vad_max_merge_size
        self.vad_padding = vad_padding
        self.vad_prompt_window = vad_prompt_window
        self.temperature = temperature
        self.best_of = best_of
        self.beam_size = beam_size
        self.patience = patience
        self.length_penalty = length_penalty
        self.suppress_tokens = suppress_tokens
        self.initial_prompt = initial_prompt
        self.condition_on_previous_text = condition_on_previous_text
        self.fp16 = fp16
        self.temperature_increment_on_fallback = temperature_increment_on_fallback
        self.compression_ratio_threshold = compression_ratio_threshold
        self.logprob_threshold = logprob_threshold
        self.no_speech_threshold = no_speech_threshold

    def get_model_names(self):
        """Return the ``name`` of every configured model, in order."""
        return [model.name for model in self.models]

    def update(self, **new_values):
        """Return a copy of this config with the given attributes overridden.

        The original instance is left untouched. Values are applied with
        setattr, so keys that are not constructor parameters become extra
        attributes. The copy is made without re-running __init__, so such
        extra attributes survive later update() calls instead of raising
        TypeError.
        """
        import copy

        result = copy.copy(self)
        # Give the copy its own models list so mutating one config's list doesn't affect the other.
        if isinstance(self.models, list):
            result.models = list(self.models)
        for key, value in new_values.items():
            setattr(result, key, value)
        return result

    @staticmethod
    def create_default(**kwargs):
        """Load the config named by $WHISPER_WEBUI_CONFIG (default "config.json5").

        Any keyword arguments are applied on top via update().
        """
        app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))

        # Update with kwargs
        if kwargs:
            app_config = app_config.update(**kwargs)
        return app_config

    @staticmethod
    def parse_file(config_path: str):
        """Build an ApplicationConfig from a JSON5 file.

        The optional top-level "models" list is turned into ModelConfig objects.
        Every other top-level key is passed to the constructor as a keyword
        argument, so unknown keys raise TypeError.
        """
        import json5

        # Decode as UTF-8 explicitly instead of relying on the platform default encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            data = json5.load(f)

        data_models = data.pop("models", [])
        models = [ModelConfig(**entry) for entry in data_models]
        return ApplicationConfig(models, **data)