Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from modules import config | |
| from modules import generate_audio as generate | |
| from functools import lru_cache | |
| from typing import Callable | |
| from modules.api.Api import APIManager | |
| from modules.api.impl import ( | |
| base_api, | |
| tts_api, | |
| ssml_api, | |
| google_api, | |
| openai_api, | |
| refiner_api, | |
| ) | |
| torch._dynamo.config.cache_size_limit = 64 | |
| torch._dynamo.config.suppress_errors = True | |
| torch.set_float32_matmul_precision("high") | |
def create_api():
    """Build an ``APIManager`` and register every route module on it.

    Each module's ``setup`` is called with the manager, in this order:
    base, tts, ssml, google, openai, refiner.

    Returns:
        The configured ``APIManager`` instance.
    """
    manager = APIManager()
    route_modules = (
        base_api,
        tts_api,
        ssml_api,
        google_api,
        openai_api,
        refiner_api,
    )
    for route_module in route_modules:
        route_module.setup(manager)
    return manager
def _is_hashable(args, kwargs):
    """Return True when ``lru_cache`` could build a cache key from these arguments."""
    try:
        hash((args, tuple(kwargs.items())))
    except TypeError:
        return False
    return True


def conditional_cache(condition: Callable, maxsize=None):
    """Return a decorator that memoizes a call only when *condition* allows it.

    Calls for which ``condition(*args, **kwargs)`` is truthy are served from a
    ``functools.lru_cache``. All other calls go straight to the wrapped function.
    Calls whose arguments are unhashable are never cached; they are forwarded
    to the wrapped function instead of raising ``TypeError``.

    Args:
        condition: Predicate that receives the same arguments as the decorated
            function. A truthy result makes the call eligible for caching.
        maxsize: Maximum number of cached results, passed to
            ``functools.lru_cache``. ``None`` (the default) means unbounded.

    Returns:
        A decorator. The wrapped function additionally exposes ``cache_info()``
        and ``cache_clear()`` from the underlying cache.
    """
    from functools import wraps

    def decorator(func):
        # Only calls routed through this memoized helper are cached.
        @lru_cache(maxsize=maxsize)
        def cached_func(*args, **kwargs):
            return func(*args, **kwargs)

        @wraps(func)
        def wrapper(*args, **kwargs):
            if condition(*args, **kwargs) and _is_hashable(args, kwargs):
                return cached_func(*args, **kwargs)
            return func(*args, **kwargs)

        wrapper.cache_info = cached_func.cache_info
        wrapper.cache_clear = cached_func.cache_clear
        return wrapper

    return decorator


if __name__ == "__main__":
    import argparse

    import uvicorn

    parser = argparse.ArgumentParser(
        description="Start the FastAPI server with command line arguments"
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Host to run the server on"
    )
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to run the server on"
    )
    parser.add_argument(
        "--reload", action="store_true", help="Enable auto-reload for development"
    )
    parser.add_argument("--compile", action="store_true", help="Enable model compile")
    parser.add_argument(
        "--lru_size",
        type=int,
        default=64,
        help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
    )
    parser.add_argument(
        "--cors_origin",
        type=str,
        default="*",
        help="Allowed CORS origins, comma-separated. Use '*' to allow all origins.",
    )
    args = parser.parse_args()
    # Make the parsed CLI arguments available to the rest of the application.
    config.args = args

    if args.compile:
        print("Model compile is enabled")
        config.enable_model_compile = True

    def should_cache(*call_args, **call_kwargs):
        """Cache only calls that pass both spk_seed and infer_seed as keywords, neither equal to -1."""
        spk_seed = call_kwargs.get("spk_seed", -1)
        infer_seed = call_kwargs.get("infer_seed", -1)
        return spk_seed != -1 and infer_seed != -1

    if args.lru_size > 0:
        config.lru_size = args.lru_size
        # The cache is bounded by --lru_size.
        # NOTE(review): rebinding the module attribute only affects callers that
        # look up generate.generate_audio at call time; code that did
        # `from modules.generate_audio import generate_audio` before this point
        # keeps the uncached function — verify the API modules.
        generate.generate_audio = conditional_cache(
            should_cache, maxsize=args.lru_size
        )(generate.generate_audio)

    api = create_api()
    config.api = api

    # Accept one origin or a comma-separated list; "*" still allows all origins.
    cors_origins = [
        origin.strip() for origin in args.cors_origin.split(",") if origin.strip()
    ]
    if cors_origins:
        api.set_cors(allow_origins=cors_origins)

    # NOTE(review): uvicorn's reload mode expects the app as an import string
    # ("module:attr"), not an app object — confirm --reload actually works here.
    uvicorn.run(api.app, host=args.host, port=args.port, reload=args.reload)