Fedir Zadniprovskyi committed

Commit 35eafc3 · Parent(s): 23a3cae
feat: model unloading
Browse files

- pyproject.toml +1 -1
- src/faster_whisper_server/config.py +9 -16
- src/faster_whisper_server/dependencies.py +1 -1
- src/faster_whisper_server/model_manager.py +114 -30
- src/faster_whisper_server/routers/misc.py +10 -8
- src/faster_whisper_server/routers/stt.py +48 -48
- tests/model_manager_test.py +122 -0
- uv.lock +14 -0
    	
pyproject.toml CHANGED

@@ -19,10 +19,10 @@ dependencies = [
 client = [
     "keyboard>=0.13.5",
 ]
-# NOTE: when installing `dev` group, all other groups should also be installed
 dev = [
     "anyio>=4.4.0",
     "basedpyright>=1.18.0",
+    "pytest-antilru>=2.0.0",
     "pytest-asyncio>=0.24.0",
     "pytest-xdist>=3.6.1",
     "pytest>=8.3.3",
    	
src/faster_whisper_server/config.py CHANGED

@@ -1,7 +1,6 @@
 import enum
-from typing import Self
 
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
 SAMPLES_PER_SECOND = 16000
@@ -163,6 +162,12 @@ class WhisperConfig(BaseModel):
     compute_type: Quantization = Field(default=Quantization.DEFAULT)
     cpu_threads: int = 0
     num_workers: int = 1
+    ttl: int = Field(default=300, ge=-1)
+    """
+    Time in seconds until the model is unloaded if it is not being used.
+    -1: Never unload the model.
+    0: Unload the model immediately after usage.
+    """
 
 
 class Config(BaseSettings):
@@ -198,10 +203,6 @@ class Config(BaseSettings):
     """
     default_response_format: ResponseFormat = ResponseFormat.JSON
     whisper: WhisperConfig = WhisperConfig()
-    max_models: int = 1
-    """
-    Maximum number of models that can be loaded at a time.
-    """
     preload_models: list[str] = Field(
         default_factory=list,
         examples=[
@@ -210,8 +211,8 @@ class Config(BaseSettings):
         ],
     )
     """
-    List of models to preload on startup.
-    """
+    List of models to preload on startup. By default, the model is first loaded on first request.
+    """
     max_no_data_seconds: float = 1.0
     """
     Max duration to wait for the next audio chunk before transcription is finilized and connection is closed.
@@ -230,11 +231,3 @@ class Config(BaseSettings):
     Controls how many latest seconds of audio are being passed through VAD.
     Should be greater than `max_inactivity_seconds`
     """
-
-    @model_validator(mode="after")
-    def ensure_preloaded_models_is_lte_max_models(self) -> Self:
-        if len(self.preload_models) > self.max_models:
-            raise ValueError(
-                f"Number of preloaded models ({len(self.preload_models)}) is greater than max_models ({self.max_models})"  # noqa: E501
-            )
-        return self
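Because `WhisperConfig` is a nested `pydantic-settings` model, the new `ttl` field can be driven from the environment with the `WHISPER__` prefix, as the tests later in this commit do. A small sketch (values illustrative):

    # Sketch: configuring the new per-model TTL through the environment.
    import os

    os.environ["WHISPER__TTL"] = "300"   # unload an idle model after 300 seconds
    # os.environ["WHISPER__TTL"] = "-1"  # never unload
    # os.environ["WHISPER__TTL"] = "0"   # unload immediately after each use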
    	
src/faster_whisper_server/dependencies.py CHANGED

@@ -18,7 +18,7 @@ ConfigDependency = Annotated[Config, Depends(get_config)]
 @lru_cache
 def get_model_manager() -> ModelManager:
     config = get_config()  # HACK
-    return ModelManager(config)
+    return ModelManager(config.whisper)
 
 
 ModelManagerDependency = Annotated[ModelManager, Depends(get_model_manager)]
    	
src/faster_whisper_server/model_manager.py CHANGED

@@ -3,48 +3,132 @@ from __future__ import annotations
 from collections import OrderedDict
 import gc
 import logging
+import threading
 import time
 from typing import TYPE_CHECKING
 
 from faster_whisper import WhisperModel
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from faster_whisper_server.config import (
-        Config,
+        WhisperConfig,
     )
 
 logger = logging.getLogger(__name__)
 
+# TODO: enable concurrent model downloads
+
+
+class SelfDisposingWhisperModel:
+    def __init__(
+        self,
+        model_id: str,
+        whisper_config: WhisperConfig,
+        *,
+        on_unload: Callable[[str], None] | None = None,
+    ) -> None:
+        self.model_id = model_id
+        self.whisper_config = whisper_config
+        self.on_unload = on_unload
+
+        self.ref_count: int = 0
+        self.rlock = threading.RLock()
+        self.expire_timer: threading.Timer | None = None
+        self.whisper: WhisperModel | None = None
+
+    def unload(self) -> None:
+        with self.rlock:
+            if self.whisper is None:
+                raise ValueError(f"Model {self.model_id} is not loaded. {self.ref_count=}")
+            if self.ref_count > 0:
+                raise ValueError(f"Model {self.model_id} is still in use. {self.ref_count=}")
+            if self.expire_timer:
+                self.expire_timer.cancel()
+            self.whisper = None
+            # WARN: ~300 MB of memory will still be held by the model. See https://github.com/SYSTRAN/faster-whisper/issues/992
+            gc.collect()
+            logger.info(f"Model {self.model_id} unloaded")
+            if self.on_unload is not None:
+                self.on_unload(self.model_id)
+
+    def _load(self) -> None:
+        with self.rlock:
+            assert self.whisper is None
+            logger.debug(f"Loading model {self.model_id}")
+            start = time.perf_counter()
+            self.whisper = WhisperModel(
+                self.model_id,
+                device=self.whisper_config.inference_device,
+                device_index=self.whisper_config.device_index,
+                compute_type=self.whisper_config.compute_type,
+                cpu_threads=self.whisper_config.cpu_threads,
+                num_workers=self.whisper_config.num_workers,
+            )
+            logger.info(f"Model {self.model_id} loaded in {time.perf_counter() - start:.2f}s")
+
+    def _increment_ref(self) -> None:
+        with self.rlock:
+            self.ref_count += 1
+            if self.expire_timer:
+                logger.debug(f"Model was set to expire in {self.expire_timer.interval}s, cancelling")
+                self.expire_timer.cancel()
+            logger.debug(f"Incremented ref count for {self.model_id}, {self.ref_count=}")
+
+    def _decrement_ref(self) -> None:
+        with self.rlock:
+            self.ref_count -= 1
+            logger.debug(f"Decremented ref count for {self.model_id}, {self.ref_count=}")
+            if self.ref_count <= 0:
+                if self.whisper_config.ttl > 0:
+                    logger.info(f"Model {self.model_id} is idle, scheduling offload in {self.whisper_config.ttl}s")
+                    self.expire_timer = threading.Timer(self.whisper_config.ttl, self.unload)
+                    self.expire_timer.start()
+                elif self.whisper_config.ttl == 0:
+                    logger.info(f"Model {self.model_id} is idle, unloading immediately")
+                    self.unload()
+                else:
+                    logger.info(f"Model {self.model_id} is idle, not unloading")
+
+    def __enter__(self) -> WhisperModel:
+        with self.rlock:
+            if self.whisper is None:
+                self._load()
+            self._increment_ref()
+            assert self.whisper is not None
+            return self.whisper
+
+    def __exit__(self, *_args) -> None:  # noqa: ANN002
+        self._decrement_ref()
+
 
 class ModelManager:
-    def __init__(self, config: Config) -> None:
-        self.config = config
-        self.loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
+    def __init__(self, whisper_config: WhisperConfig) -> None:
+        self.whisper_config = whisper_config
+        self.loaded_models: OrderedDict[str, SelfDisposingWhisperModel] = OrderedDict()
+        self._lock = threading.Lock()
 
-    def load_model(self, model_name: str) -> WhisperModel:
-        if model_name in self.loaded_models:
-            logger.debug(f"{model_name} model already loaded")
-            return self.loaded_models[model_name]
-        if len(self.loaded_models) >= self.config.max_models:
-            oldest_model_name = next(iter(self.loaded_models))
-            logger.info(
-                f"Max models ({self.config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
-            )
-            del self.loaded_models[oldest_model_name]
-            gc.collect()
-        logger.debug(f"Loading {model_name}...")
-        start = time.perf_counter()
-        # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
-        whisper = WhisperModel(
-            model_name,
-            device=self.config.whisper.inference_device,
-            device_index=self.config.whisper.device_index,
-            compute_type=self.config.whisper.compute_type,
-            cpu_threads=self.config.whisper.cpu_threads,
-            num_workers=self.config.whisper.num_workers,
-        )
-        logger.info(
-            f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {self.config.whisper.inference_device}({self.config.whisper.compute_type}) will be used for inference."  # noqa: E501
-        )
-        self.loaded_models[model_name] = whisper
-        return whisper
+    def _handle_model_unload(self, model_name: str) -> None:
+        with self._lock:
+            if model_name in self.loaded_models:
+                del self.loaded_models[model_name]
+
+    def unload_model(self, model_name: str) -> None:
+        with self._lock:
+            model = self.loaded_models.get(model_name)
+            if model is None:
+                raise KeyError(f"Model {model_name} not found")
+            self.loaded_models[model_name].unload()
+
+    def load_model(self, model_name: str) -> SelfDisposingWhisperModel:
+        with self._lock:
+            if model_name in self.loaded_models:
+                logger.debug(f"{model_name} model already loaded")
+                return self.loaded_models[model_name]
+            self.loaded_models[model_name] = SelfDisposingWhisperModel(
+                model_name,
+                self.whisper_config,
+                on_unload=self._handle_model_unload,
+            )
+            return self.loaded_models[model_name]
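`SelfDisposingWhisperModel` is a reference-counted handle: `__enter__` lazily loads the weights and cancels any pending expiry timer, and `__exit__` decrements the count and, once it reaches zero, schedules or performs the unload according to `ttl`. A minimal usage sketch (not part of the commit; the model ID and TTL are illustrative), mirroring how the routers below consume it:

    # Usage sketch: the context manager pins the model while it is in use.
    from faster_whisper_server.config import WhisperConfig
    from faster_whisper_server.model_manager import ModelManager

    manager = ModelManager(WhisperConfig(ttl=300))
    with manager.load_model("Systran/faster-whisper-tiny.en") as whisper:
        # ref_count is now 1 and any pending expiry timer was cancelled
        segments, info = whisper.transcribe("audio.wav")
    # ref_count dropped back to 0; a 300 s threading.Timer now ticks toward unload()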
    	
src/faster_whisper_server/routers/misc.py CHANGED

@@ -1,7 +1,5 @@
 from __future__ import annotations
 
-import gc
-
 from fastapi import (
     APIRouter,
     Response,
@@ -42,15 +40,19 @@ def get_running_models(
 def load_model_route(model_manager: ModelManagerDependency, model_name: str) -> Response:
     if model_name in model_manager.loaded_models:
         return Response(status_code=409, content="Model already loaded")
-    model_manager.load_model(model_name)
+    with model_manager.load_model(model_name):
+        pass
     return Response(status_code=201)
 
 
 @router.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
 def stop_running_model(model_manager: ModelManagerDependency, model_name: str) -> Response:
-    model = model_manager.loaded_models.get(model_name)
-    if model is not None:
-        del model_manager.loaded_models[model_name]
-        gc.collect()
+    try:
+        model_manager.unload_model(model_name)
         return Response(status_code=204)
-    return Response(status_code=404, content="Model not found")
+    except (KeyError, ValueError) as e:
+        match e:
+            case KeyError():
+                return Response(status_code=404, content="Model not found")
+            case ValueError():
+                return Response(status_code=409, content=str(e))
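With `unload_model` raising `KeyError` for unknown models and `ValueError` for in-use ones, the delete route now maps those to 404 and 409 respectively. A sketch of exercising the routes with `httpx` (base URL and model ID are illustrative):

    # Sketch: exercising the updated /api/ps routes.
    import httpx

    client = httpx.Client(base_url="http://localhost:8000")
    model = "Systran/faster-whisper-tiny.en"
    client.post(f"/api/ps/{model}")    # 201 Created, or 409 if already loaded
    client.get("/api/ps")              # lists currently loaded models
    client.delete(f"/api/ps/{model}")  # 204; 404 if unknown; 409 if still in use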
    	
src/faster_whisper_server/routers/stt.py CHANGED

@@ -142,20 +142,20 @@ def translate_file(
         model = config.whisper.model
     if response_format is None:
         response_format = config.default_response_format
-    whisper = model_manager.load_model(model)
-    segments, transcription_info = whisper.transcribe(
-        file.file,
-        task=Task.TRANSLATE,
-        initial_prompt=prompt,
-        temperature=temperature,
-        vad_filter=vad_filter,
-    )
-    segments = TranscriptionSegment.from_faster_whisper_segments(segments)
-
-    if stream:
-        return segments_to_streaming_response(segments, transcription_info, response_format)
-    else:
-        return segments_to_response(segments, transcription_info, response_format)
+    with model_manager.load_model(model) as whisper:
+        segments, transcription_info = whisper.transcribe(
+            file.file,
+            task=Task.TRANSLATE,
+            initial_prompt=prompt,
+            temperature=temperature,
+            vad_filter=vad_filter,
+        )
+        segments = TranscriptionSegment.from_faster_whisper_segments(segments)
+
+        if stream:
+            return segments_to_streaming_response(segments, transcription_info, response_format)
+        else:
+            return segments_to_response(segments, transcription_info, response_format)
 
 
 # HACK: Since Form() doesn't support `alias`, we need to use a workaround.
@@ -206,23 +206,23 @@ def transcribe_file(
         logger.warning(
             "It only makes sense to provide `timestamp_granularities[]` when `response_format` is set to `verbose_json`. See https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities."  # noqa: E501
         )
-    whisper = model_manager.load_model(model)
-    segments, transcription_info = whisper.transcribe(
-        file.file,
-        task=Task.TRANSCRIBE,
-        language=language,
-        initial_prompt=prompt,
-        word_timestamps="word" in timestamp_granularities,
-        temperature=temperature,
-        vad_filter=vad_filter,
-        hotwords=hotwords,
-    )
-    segments = TranscriptionSegment.from_faster_whisper_segments(segments)
-
-    if stream:
-        return segments_to_streaming_response(segments, transcription_info, response_format)
-    else:
-        return segments_to_response(segments, transcription_info, response_format)
+    with model_manager.load_model(model) as whisper:
+        segments, transcription_info = whisper.transcribe(
+            file.file,
+            task=Task.TRANSCRIBE,
+            language=language,
+            initial_prompt=prompt,
+            word_timestamps="word" in timestamp_granularities,
+            temperature=temperature,
+            vad_filter=vad_filter,
+            hotwords=hotwords,
+        )
+        segments = TranscriptionSegment.from_faster_whisper_segments(segments)
+
+        if stream:
+            return segments_to_streaming_response(segments, transcription_info, response_format)
+        else:
+            return segments_to_response(segments, transcription_info, response_format)
 
 
 async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
@@ -280,24 +280,24 @@ async def transcribe_stream(
         "vad_filter": vad_filter,
         "condition_on_previous_text": False,
     }
-    whisper = model_manager.load_model(model)
-    asr = FasterWhisperASR(whisper, **transcribe_opts)
-    audio_stream = AudioStream()
-    async with asyncio.TaskGroup() as tg:
-        tg.create_task(audio_receiver(ws, audio_stream))
-        async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
-            logger.debug(f"Sending transcription: {transcription.text}")
-            if ws.client_state == WebSocketState.DISCONNECTED:
-                break
+    with model_manager.load_model(model) as whisper:
+        asr = FasterWhisperASR(whisper, **transcribe_opts)
+        audio_stream = AudioStream()
+        async with asyncio.TaskGroup() as tg:
+            tg.create_task(audio_receiver(ws, audio_stream))
+            async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
+                logger.debug(f"Sending transcription: {transcription.text}")
+                if ws.client_state == WebSocketState.DISCONNECTED:
+                    break
 
-            if response_format == ResponseFormat.TEXT:
-                await ws.send_text(transcription.text)
-            elif response_format == ResponseFormat.JSON:
-                await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump())
-            elif response_format == ResponseFormat.VERBOSE_JSON:
-                await ws.send_json(
-                    CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump()
-                )
+                if response_format == ResponseFormat.TEXT:
+                    await ws.send_text(transcription.text)
+                elif response_format == ResponseFormat.JSON:
+                    await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump())
+                elif response_format == ResponseFormat.VERBOSE_JSON:
+                    await ws.send_json(
+                        CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump()
+                    )
 
     if ws.client_state != WebSocketState.DISCONNECTED:
         logger.info("Closing the connection.")
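Wrapping inference in `with model_manager.load_model(model) as whisper:` keeps the reference count above zero for the duration of the request, so a concurrent unload attempt fails with 409 instead of freeing the model mid-transcription (`test_model_cant_be_unloaded_when_used` below verifies this), and the TTL timer restarts only when the block exits. A compact standalone sketch of that ref-count/timer mechanic, independent of Whisper (all names illustrative):

    # Standalone sketch: a context manager that pins a resource and schedules
    # disposal only after the last user exits.
    import threading

    class SelfDisposing:
        def __init__(self, ttl: float) -> None:
            self.ttl = ttl
            self.refs = 0
            self.timer: threading.Timer | None = None
            self.lock = threading.RLock()

        def __enter__(self) -> "SelfDisposing":
            with self.lock:
                if self.timer:
                    self.timer.cancel()  # new usage arrived; abort the pending unload
                self.refs += 1
                return self

        def __exit__(self, *_args) -> None:
            with self.lock:
                self.refs -= 1
                if self.refs <= 0:  # last user left; start the idle countdown
                    self.timer = threading.Timer(self.ttl, lambda: print("unloading"))
                    self.timer.start()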
    	
tests/model_manager_test.py ADDED

@@ -0,0 +1,122 @@
+import asyncio
+import os
+
+import anyio
+from httpx import ASGITransport, AsyncClient
+import pytest
+
+from faster_whisper_server.main import create_app
+
+
+@pytest.mark.asyncio
+async def test_model_unloaded_after_ttl() -> None:
+    ttl = 5
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+        await aclient.post(f"/api/ps/{model}")
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+        await asyncio.sleep(ttl + 1)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+
+
+@pytest.mark.asyncio
+async def test_ttl_resets_after_usage() -> None:
+    ttl = 5
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        await aclient.post(f"/api/ps/{model}")
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+        await asyncio.sleep(ttl - 2)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+
+        async with await anyio.open_file("audio.wav", "rb") as f:
+            data = await f.read()
+        res = (
+            await aclient.post(
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+            )
+        ).json()
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+        await asyncio.sleep(ttl - 2)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+
+        await asyncio.sleep(3)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+
+        # test the model can be used again after being unloaded
+        # this just ensures the model can be loaded again after being unloaded
+        res = (
+            await aclient.post(
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+            )
+        ).json()
+
+
+@pytest.mark.asyncio
+async def test_model_cant_be_unloaded_when_used() -> None:
+    ttl = 0
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        async with await anyio.open_file("audio.wav", "rb") as f:
+            data = await f.read()
+
+        task = asyncio.create_task(
+            aclient.post(
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+            )
+        )
+        await asyncio.sleep(0.01)
+        res = await aclient.delete(f"/api/ps/{model}")
+        assert res.status_code == 409
+
+        await task
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+
+
+@pytest.mark.asyncio
+async def test_model_cant_be_loaded_twice() -> None:
+    ttl = -1
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["ENABLE_UI"] = "false"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        res = await aclient.post(f"/api/ps/{model}")
+        assert res.status_code == 201
+        res = await aclient.post(f"/api/ps/{model}")
+        assert res.status_code == 409
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+
+
+@pytest.mark.asyncio
+async def test_model_is_unloaded_after_request_when_ttl_is_zero() -> None:
+    ttl = 0
+    os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        async with await anyio.open_file("audio.wav", "rb") as f:
+            data = await f.read()
+        res = await aclient.post(
+            "/v1/audio/transcriptions",
+            files={"file": ("audio.wav", data, "audio/wav")},
+            data={"model": "Systran/faster-whisper-tiny.en"},
+        )
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
    	
uv.lock CHANGED

@@ -293,6 +293,7 @@ dev = [
     { name = "anyio" },
     { name = "basedpyright" },
     { name = "pytest" },
+    { name = "pytest-antilru" },
     { name = "pytest-asyncio" },
     { name = "pytest-xdist" },
     { name = "ruff" },
@@ -322,6 +323,7 @@ requires-dist = [
     { name = "pydantic", specifier = ">=2.9.0" },
     { name = "pydantic-settings", specifier = ">=2.5.2" },
     { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.3.3" },
+    { name = "pytest-antilru", marker = "extra == 'dev'", specifier = ">=2.0.0" },
     { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.24.0" },
     { name = "pytest-xdist", marker = "extra == 'dev'", specifier = ">=3.6.1" },
     { name = "python-multipart", specifier = ">=0.0.10" },
@@ -3482,6 +3484,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/6b/77/7440a06a8ead44c7757a64362dd22df5760f9b12dc5f11b6188cd2fc27a0/pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2", size = 342341 },
 ]
 
+[[package]]
+name = "pytest-antilru"
+version = "2.0.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pytest" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/c6/01/0b5ef3f143f335b5cb1c1e8e6497769dfb48aed5a791b5dfd119151e2b15/pytest_antilru-2.0.0.tar.gz", hash = "sha256:48cff342648b6a1ce4e5398cf203966905d546b3f2bee7bb55d7cb3ec87a85fb", size = 5569 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/23/f0/fc9f5aaaf2818a7d7f795e99fcf59719dd6ec5f98005e642e1efd63ad2a4/pytest_antilru-2.0.0-py3-none-any.whl", hash = "sha256:cf1d97db0e7b17ef568c1f0bf4c89b8748053fe07546f4eb2558bebf64c1ad33", size = 6301 },
+]
+
 [[package]]
 name = "pytest-asyncio"
 version = "0.24.0"