feat: support model preloading (#66)

Fedir Zadniprovskyi committed · Commit 3a14175 · Parent: e9aef91

Files changed:
- faster_whisper_server/config.py +23 -2
- faster_whisper_server/main.py +12 -3
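Since `Config` extends pydantic-settings' `BaseSettings` (see the config.py diff below), the new `preload_models` option can presumably be supplied through the environment like the project's other settings. A hedged sketch of that usage: the variable name `PRELOAD_MODELS` assumes pydantic-settings' default mapping of the upper-cased field name with no prefix, and pydantic-settings JSON-decodes list-valued fields read from the environment.

import os

# Hypothetical usage: request one model to be preloaded at startup.
# List-valued settings are passed as JSON strings in the environment.
os.environ["PRELOAD_MODELS"] = '["Systran/faster-whisper-medium.en"]'

# Instantiating the app's Config after this point would pick the value
# up, subject to the validator added in this commit.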
faster_whisper_server/config.py
CHANGED
@@ -1,6 +1,7 @@
 import enum
+from typing import Self
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
 SAMPLES_PER_SECOND = 16000
@@ -151,7 +152,9 @@ class WhisperConfig(BaseModel):
 
     model: str = Field(default="Systran/faster-whisper-medium.en")
     """
-    Huggingface model to use for transcription. Note: the model must support being run using CTranslate2.
+    Default Huggingface model to use for transcription. Note: the model must support being run using CTranslate2.
+    This model will be used if no model is specified in the request.
+
     Models created by authors of `faster-whisper` can be found at https://huggingface.co/Systran
     You can find other supported models at https://huggingface.co/models?p=2&sort=trending&search=ctranslate2 and https://huggingface.co/models?sort=trending&search=ct2
     """
@@ -199,6 +202,16 @@ class Config(BaseSettings):
     """
    Maximum number of models that can be loaded at a time.
     """
+    preload_models: list[str] = Field(
+        default_factory=list,
+        examples=[
+            ["Systran/faster-whisper-medium.en"],
+            ["Systran/faster-whisper-medium.en", "Systran/faster-whisper-small.en"],
+        ],
+    )
+    """
+    List of models to preload on startup. Shouldn't be greater than `max_models`. By default, a model is first loaded on the first request that uses it.
+    """  # noqa: E501
     max_no_data_seconds: float = 1.0
     """
     Max duration to wait for the next audio chunk before transcription is finalized and the connection is closed.
@@ -218,5 +231,13 @@ class Config(BaseSettings):
     Should be greater than `max_inactivity_seconds`
     """
 
+    @model_validator(mode="after")
+    def ensure_preloaded_models_is_lte_max_models(self) -> Self:
+        if len(self.preload_models) > self.max_models:
+            raise ValueError(
+                f"Number of preloaded models ({len(self.preload_models)}) is greater than max_models ({self.max_models})"  # noqa: E501
+            )
+        return self
+
 
 config = Config()
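A quick illustration of the new validator's effect. This is a minimal sketch, not the real class: it assumes a stripped-down model with only the two relevant fields and a `max_models` default of 1, whereas the actual `Config` is a `BaseSettings` subclass with many more options.

from typing import Self

from pydantic import BaseModel, Field, ValidationError, model_validator


class ConfigSketch(BaseModel):
    # Hypothetical stand-ins for the real settings fields.
    max_models: int = 1
    preload_models: list[str] = Field(default_factory=list)

    @model_validator(mode="after")
    def ensure_preloaded_models_is_lte_max_models(self) -> Self:
        # Runs after field validation; rejects configs that ask to
        # preload more models than the cache is allowed to hold.
        if len(self.preload_models) > self.max_models:
            raise ValueError(
                f"Number of preloaded models ({len(self.preload_models)}) "
                f"is greater than max_models ({self.max_models})"
            )
        return self


ConfigSketch(preload_models=["Systran/faster-whisper-medium.en"])  # OK

try:
    ConfigSketch(
        preload_models=[
            "Systran/faster-whisper-medium.en",
            "Systran/faster-whisper-small.en",
        ],
    )
except ValidationError as exc:
    # pydantic wraps the ValueError raised by the validator.
    print(exc)

A `mode="after"` validator receives the fully constructed instance, so it can compare fields against each other, which a per-field validator can't do as directly.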
faster_whisper_server/main.py
CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 from collections import OrderedDict
+from contextlib import asynccontextmanager
 from io import BytesIO
 import time
 from typing import TYPE_CHECKING, Annotated, Literal
@@ -45,7 +46,7 @@ from faster_whisper_server.server_models import (
 from faster_whisper_server.transcriber import audio_transcriber
 
 if TYPE_CHECKING:
-    from collections.abc import Generator, Iterable
+    from collections.abc import AsyncGenerator, Generator, Iterable
 
     from faster_whisper.transcribe import TranscriptionInfo
     from huggingface_hub.hf_api import ModelInfo
@@ -63,7 +64,7 @@ def load_model(model_name: str) -> WhisperModel:
         del loaded_models[oldest_model_name]
     logger.debug(f"Loading {model_name}...")
     start = time.perf_counter()
-    # NOTE: will raise an exception if the model name isn't valid
+    # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
    whisper = WhisperModel(
        model_name,
        device=config.whisper.inference_device,
@@ -81,7 +82,15 @@ def load_model(model_name: str) -> WhisperModel:
 
 logger.debug(f"Config: {config}")
 
-app = FastAPI()
+
+@asynccontextmanager
+async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
+    for model_name in config.preload_models:
+        load_model(model_name)
+    yield
+
+
+app = FastAPI(lifespan=lifespan)
 
 if config.allow_origins is not None:
     app.add_middleware(
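For context on the main.py change: FastAPI's `lifespan` parameter takes an async context manager; everything before the `yield` runs once at startup, and anything after it runs at shutdown. A self-contained sketch of the pattern, with `load_model` stubbed out (the real function builds a `WhisperModel` and, as the surrounding diff shows, caches models in an `OrderedDict`, evicting the oldest once `max_models` is exceeded):

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI

# Hypothetical stand-in for config.preload_models.
PRELOAD_MODELS = ["Systran/faster-whisper-medium.en"]


def load_model(model_name: str) -> None:
    # Stub: the real function constructs a WhisperModel and caches it.
    print(f"Loading {model_name}...")


@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
    # Startup: eagerly load each configured model before the server
    # begins accepting requests.
    for model_name in PRELOAD_MODELS:
        load_model(model_name)
    yield
    # Shutdown: code here would run when the app stops (unused).


app = FastAPI(lifespan=lifespan)

The payoff is that the first request against a preloaded model skips the model-loading latency; with an empty `preload_models`, the previous behavior of loading on first request still applies.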