|
|
|
""" |
|
Updated FastAPI backend for GPT-SoVITS (*April 2025*) |
|
--------------------------------------------------- |
|
Changes compared with the previous version shipped on 30 Apr 2025 |
|
================================================================= |
|
1. **URL / S3 audio support** — `process_audio_path()` downloads `ref_audio_path` and |
|
each entry in `aux_ref_audio_paths` when they are HTTP(S) or S3 URLs, storing them |
|
as temporary files that are cleaned up afterwards. |
|
2. **CUDA memory hygiene** — `torch.cuda.empty_cache()` is invoked after each request |
|
(success *or* error) to release GPU memory. |
|
3. **Temporary‑file cleanup** — all files created by `process_audio_path()` are |
|
removed in `finally` blocks so they are guaranteed to disappear no matter how the |
|
request terminates. |
|
|
|
The public API surface (**end‑points and query parameters**) is unchanged. |
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
import argparse
import os
import signal
import subprocess
import sys
import tempfile
import traceback
import urllib.parse
from io import BytesIO
from typing import Generator, List, Tuple

import numpy as np
import requests
import soundfile as sf
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Query, Response
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
|
|
|
# Remember the directory the server was started from and make both it and the
# ``GPT_SoVITS`` directory beneath it importable for the project imports that
# follow.
NOW_DIR = os.getcwd()
sys.path.append(NOW_DIR)
sys.path.append(f"{NOW_DIR}/GPT_SoVITS")
|
|
|
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config |
|
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import ( |
|
get_method_names as get_cut_method_names, |
|
) |
|
from tools.i18n.i18n import I18nAuto |
|
|
|
|
|
|
|
|
|
|
|
# Project i18n helper, plus the names of the available text-segmentation
# methods (``_validate_request`` checks ``text_split_method`` against them).
i18n = I18nAuto()
cut_method_names = get_cut_method_names()

# Single source of truth for the inference config location so that the argparse
# default and the empty-value fallback below cannot drift apart (the fallback
# previously pointed at "GPT-SoVITS/..." with a hyphen, a different directory).
_DEFAULT_TTS_CONFIG = "GPT_SoVITS/configs/tts_infer.yaml"

parser = argparse.ArgumentParser(description="GPT‑SoVITS API")
parser.add_argument(
    "-c", "--tts_config", default=_DEFAULT_TTS_CONFIG, help="TTS‑infer config path"
)
parser.add_argument("-a", "--bind_addr", default="127.0.0.1", help="Bind address (default 127.0.0.1)")
parser.add_argument("-p", "--port", type=int, default=9880, help="Port (default 9880)")
args = parser.parse_args()

# An explicitly empty ``-c ""`` falls back to the default config path.
config_path = args.tts_config or _DEFAULT_TTS_CONFIG
PORT = args.port
# The literal string "None" on the command line is mapped to a real ``None`` host.
HOST = None if args.bind_addr == "None" else args.bind_addr
|
|
|
|
|
|
|
|
|
|
|
# Build the inference configuration from the selected file, echo it to stdout
# for operators, and create the single pipeline instance shared by all routes.
tts_config = TTS_Config(config_path)
print(tts_config)
TTS_PIPELINE = TTS(tts_config)

# FastAPI application object; routes are registered on it further down.
APP = FastAPI()

# Scratch directory (under the start-up directory) that receives reference
# audio downloaded from remote URLs.  Created eagerly; pre-existing is fine.
TEMP_DIR = os.path.join(NOW_DIR, "_tmp_audio")
os.makedirs(TEMP_DIR, exist_ok=True)
|
|
|
def _empty_cuda_cache() -> None: |
|
"""Release GPU memory if CUDA is available.""" |
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
|
|
|
|
def _download_to_temp(url: str) -> str:
    """Download *url* to a unique file inside ``TEMP_DIR`` and return the path.

    ``s3://bucket/key`` URLs are fetched with ``boto3`` (imported lazily, so it
    is only required when an S3 URL is actually used); every other URL is
    streamed with ``requests`` using a 30 second timeout.

    Args:
        url: HTTP(S) or S3 location of the file to fetch.

    Returns:
        Path of the freshly created local file.  The caller owns the file and
        is responsible for deleting it.

    Raises:
        Exception: Whatever ``boto3``, ``requests`` or the filesystem raise.
            The partially written file is removed before the error propagates.
    """
    parsed = urllib.parse.urlparse(url)
    basename = os.path.basename(parsed.path)
    # Keep the remote file's extension; default to ".wav" when the URL path
    # carries no file name at all.
    suffix = os.path.splitext(basename)[1] if basename else ".wav"

    # ``mkstemp`` creates a brand-new file, so two requests whose URLs share a
    # basename can no longer overwrite (or later delete) each other's download.
    fd, local_path = tempfile.mkstemp(prefix="audio_", suffix=suffix, dir=TEMP_DIR)
    os.close(fd)

    try:
        if url.startswith("s3://"):
            import importlib

            boto3 = importlib.import_module("boto3")
            s3_client = boto3.client("s3")
            # netloc is the bucket, the path (minus its leading slash) is the key.
            s3_client.download_file(parsed.netloc, parsed.path.lstrip("/"), local_path)
        else:
            with requests.get(url, stream=True, timeout=30) as r:
                r.raise_for_status()
                with open(local_path, "wb") as f_out:
                    for chunk in r.iter_content(chunk_size=8192):
                        f_out.write(chunk)
    except BaseException:
        # Do not leave a truncated file behind: nobody has the path yet, so no
        # later cleanup would ever find it.
        try:
            os.remove(local_path)
        except OSError:
            pass
        raise

    return local_path
|
|
|
|
|
def process_audio_path(audio_path: str | None) -> Tuple[str | None, bool]:
    """Resolve *audio_path* to something readable from the local filesystem.

    Returns:
        ``(path, is_temporary)``.  Empty values and anything that is not an
        ``http://``, ``https://`` or ``s3://`` URL are returned untouched with
        ``False``; remote URLs are downloaded and returned with ``True``.

    Raises:
        HTTPException: Status 400 when the download fails for any reason.
    """
    remote_prefixes = ("http://", "https://", "s3://")
    if not audio_path or not audio_path.startswith(remote_prefixes):
        return audio_path, False

    try:
        downloaded = _download_to_temp(audio_path)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Failed to download audio: {exc}") from exc
    return downloaded, True
|
|
|
|
|
|
|
|
|
|
|
|
|
def _pack_ogg(buf: BytesIO, data: np.ndarray, rate: int):
    """Write *data* into *buf* as single-channel OGG at *rate* Hz and return *buf*."""
    ogg_writer = sf.SoundFile(buf, mode="w", samplerate=rate, channels=1, format="ogg")
    with ogg_writer:
        ogg_writer.write(data)
    return buf
|
|
|
|
|
def _pack_raw(buf: BytesIO, data: np.ndarray, _rate: int): |
|
buf.write(data.tobytes()) |
|
return buf |
|
|
|
|
|
def _pack_wav(buf: BytesIO, data: np.ndarray, rate: int):
    """Write *data* into *buf* in WAV format at *rate* Hz and return *buf*."""
    sf.write(file=buf, data=data, samplerate=rate, format="wav")
    return buf
|
|
|
|
|
def _pack_aac(buf: BytesIO, data: np.ndarray, rate: int):
    """Encode *data* to ADTS-framed AAC by piping it through ``ffmpeg``.

    The bytes of *data* are fed to ffmpeg's stdin, declared as mono signed
    16-bit little-endian PCM at *rate* Hz (``-f s16le -ac 1``), and the encoded
    192 kbit/s stream read from stdout is appended to *buf*.
    NOTE(review): this presumes *data* is an int16 array - confirm with the
    pipeline's output dtype.

    Args:
        buf: Destination buffer; returned for chaining.
        data: PCM samples.
        rate: Sample rate in Hz.

    Returns:
        *buf*, with the encoded audio appended.

    Raises:
        FileNotFoundError: If the ``ffmpeg`` executable cannot be found.
        RuntimeError: If ffmpeg exits with a non-zero status.  Previously the
            exit status was ignored and an empty/truncated payload was
            returned as if encoding had succeeded.
    """
    command = [
        "ffmpeg",
        "-f",
        "s16le",
        "-ar",
        str(rate),
        "-ac",
        "1",
        "-i",
        "pipe:0",
        "-c:a",
        "aac",
        "-b:a",
        "192k",
        "-vn",
        "-f",
        "adts",
        "pipe:1",
    ]
    result = subprocess.run(command, input=data.tobytes(), capture_output=True)
    if result.returncode != 0:
        # Surface the tail of ffmpeg's diagnostics instead of discarding them.
        stderr_tail = result.stderr.decode("utf-8", errors="replace")[-500:]
        raise RuntimeError(
            f"ffmpeg AAC encoding failed (exit code {result.returncode}): {stderr_tail}"
        )
    buf.write(result.stdout)
    return buf
|
|
|
|
|
def _pack_audio(buf: BytesIO, data: np.ndarray, rate: int, media_type: str):
    """Encode *data* into *buf* using the packer for *media_type*.

    ``"ogg"``, ``"aac"`` and ``"wav"`` select their dedicated packers; any
    other value is written as raw bytes.  The returned buffer is rewound to
    position 0 so it can be read immediately.
    """
    if media_type == "ogg":
        packer = _pack_ogg
    elif media_type == "aac":
        packer = _pack_aac
    elif media_type == "wav":
        packer = _pack_wav
    else:
        packer = _pack_raw

    packed = packer(buf, data, rate)
    packed.seek(0)
    return packed
|
|
|
|
|
|
|
|
|
|
|
|
|
class TTSRequest(BaseModel):
    """JSON body accepted by ``POST /tts``.

    The model is converted to a plain dict, checked by ``_validate_request``
    and then handed as-is to ``TTS_PIPELINE.run``.
    """

    # ``text``, ``text_lang``, ``ref_audio_path`` and ``prompt_lang`` default to
    # ``None`` here but are rejected by ``_validate_request`` when empty.
    text: str | None = None
    text_lang: str | None = None
    # Local path, or an http(s):// / s3:// URL that is downloaded first.
    ref_audio_path: str | None = None
    # Optional extra references; each entry is resolved like ``ref_audio_path``.
    aux_ref_audio_paths: List[str] | None = None
    prompt_lang: str | None = None
    prompt_text: str = ""

    # Forwarded untouched to the pipeline.
    top_k: int = 5
    top_p: float = 1.0
    temperature: float = 1.0
    # Must be one of the names returned by ``get_cut_method_names()``.
    text_split_method: str = "cut5"
    batch_size: int = 1
    batch_threshold: float = 0.75
    split_bucket: bool = True
    speed_factor: float = 1.0
    fragment_interval: float = 0.3
    seed: int = -1

    # One of "wav", "raw", "ogg", "aac"; "ogg" additionally requires
    # ``streaming_mode`` to be true.
    media_type: str = "wav"
    streaming_mode: bool = False

    # Forwarded untouched to the pipeline.
    parallel_infer: bool = True
    repetition_penalty: float = 1.35
    sample_steps: int = 32
    super_sampling: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
def _validate_request(req: dict): |
|
if not req.get("text"): |
|
return "text is required" |
|
if not req.get("text_lang"): |
|
return "text_lang is required" |
|
if req["text_lang"].lower() not in tts_config.languages: |
|
return f"text_lang {req['text_lang']} not supported" |
|
if not req.get("prompt_lang"): |
|
return "prompt_lang is required" |
|
if req["prompt_lang"].lower() not in tts_config.languages: |
|
return f"prompt_lang {req['prompt_lang']} not supported" |
|
if not req.get("ref_audio_path"): |
|
return "ref_audio_path is required" |
|
mt = req.get("media_type", "wav") |
|
if mt not in {"wav", "raw", "ogg", "aac"}: |
|
return f"media_type {mt} not supported" |
|
if (not req.get("streaming_mode") and mt == "ogg"): |
|
return "ogg is only supported in streaming mode" |
|
if req.get("text_split_method", "cut5") not in cut_method_names: |
|
return f"text_split_method {req['text_split_method']} not supported" |
|
return None |
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _tts_handle(req: dict):
    """Validate *req*, run the TTS pipeline and build the HTTP response.

    Remote reference audio (``ref_audio_path`` and every entry of
    ``aux_ref_audio_paths``) is downloaded to temporary files first and *req*
    is updated in place to point at the local copies.

    Args:
        req: Request parameters; see ``_validate_request`` for required keys.

    Returns:
        * ``JSONResponse`` (400) when validation fails or anything raises.
        * ``StreamingResponse`` when ``streaming_mode`` is truthy.  For
          ``"wav"`` a WAV header is sent first and the chunks follow as raw
          bytes.
        * ``Response`` holding the first item produced by the pipeline
          otherwise.

    Temporary files are removed (and the CUDA cache emptied) in ``finally``
    blocks: immediately for non-streaming requests, and when the chunk
    generator finishes or is closed for streaming ones.
    """
    error = _validate_request(req)
    if error:
        return JSONResponse(status_code=400, content={"message": error})

    streaming_mode = req.get("streaming_mode", False)
    media_type = req.get("media_type", "wav")

    temp_files: List[str] = []
    # Becomes True once the streaming generator has taken over responsibility
    # for ``temp_files``; the pipeline still needs them while it is streaming.
    cleanup_deferred = False
    try:
        ref_path, is_tmp = process_audio_path(req["ref_audio_path"])
        req["ref_audio_path"] = ref_path
        if is_tmp:
            temp_files.append(ref_path)

        if req.get("aux_ref_audio_paths"):
            resolved: List[str] = []
            for aux_path in req["aux_ref_audio_paths"]:
                local_aux, aux_is_tmp = process_audio_path(aux_path)
                resolved.append(local_aux)
                if aux_is_tmp:
                    temp_files.append(local_aux)
            req["aux_ref_audio_paths"] = resolved

        generator = TTS_PIPELINE.run(req)

        if streaming_mode:

            async def _gen(gen: Generator, _media_type: str):
                # NOTE(review): ``gen`` is iterated synchronously inside this
                # async generator, so the event loop is blocked while each
                # chunk is produced.
                header_pending = _media_type == "wav"
                try:
                    for sr, chunk in gen:
                        if header_pending:
                            # Send the WAV header once, then stream the audio
                            # itself as raw bytes.
                            yield _wave_header_chunk(sample_rate=sr)
                            _media_type = "raw"
                            header_pending = False
                        yield _pack_audio(BytesIO(), chunk, sr, _media_type).getvalue()
                finally:
                    _cleanup(temp_files)

            response = StreamingResponse(_gen(generator, media_type), media_type=f"audio/{media_type}")
            cleanup_deferred = True
            return response

        sr, data = next(generator)
        payload = _pack_audio(BytesIO(), data, sr, media_type).getvalue()
        return Response(payload, media_type=f"audio/{media_type}")

    except Exception as exc:
        # Keep the stack trace in the server log; the client only gets the message.
        traceback.print_exc()
        return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(exc)})
    finally:
        # Runs on success, on handled errors and on BaseException alike.
        if not cleanup_deferred:
            _cleanup(temp_files)
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cleanup(temp_files: List[str]):
    """Delete every path in *temp_files*, then empty the CUDA cache.

    Removal is best-effort: a file that is already gone is ignored, and any
    other failure is printed as a warning without interrupting the loop.
    """
    for path in temp_files:
        try:
            os.unlink(path)
        except FileNotFoundError:
            # Nothing left to delete for this entry.
            continue
        except Exception as exc:
            print(f"[cleanup‑warning] {exc}")
    _empty_cuda_cache()
|
|
|
|
|
|
|
|
|
|
|
|
|
import wave |
|
|
|
def _wave_header_chunk(frame: bytes = b"", *, channels: int = 1, width: int = 2, sample_rate: int = 32_000): |
|
buf = BytesIO() |
|
with wave.open(buf, "wb") as wav: |
|
wav.setnchannels(channels) |
|
wav.setsampwidth(width) |
|
wav.setframerate(sample_rate) |
|
wav.writeframes(frame) |
|
buf.seek(0) |
|
return buf.read() |
|
|
|
|
|
|
|
|
|
|
|
|
|
@APP.get("/tts")
async def tts_get(
    text: str | None = None,
    text_lang: str | None = None,
    ref_audio_path: str | None = None,
    aux_ref_audio_paths: List[str] | None = Query(None),
    prompt_lang: str | None = None,
    prompt_text: str = "",
    top_k: int = 5,
    top_p: float = 1.0,
    temperature: float = 1.0,
    text_split_method: str = "cut5",
    batch_size: int = 1,
    batch_threshold: float = 0.75,
    split_bucket: bool = True,
    speed_factor: float = 1.0,
    fragment_interval: float = 0.3,
    seed: int = -1,
    media_type: str = "wav",
    streaming_mode: bool = False,
    parallel_infer: bool = True,
    repetition_penalty: float = 1.35,
    sample_steps: int = 32,
    super_sampling: bool = False,
):
    """Handle ``GET /tts``: same fields and defaults as ``TTSRequest``, as query parameters.

    The parameters are declared explicitly because FastAPI derives query
    parameters from the signature; a ``**query`` catch-all is not expanded and
    is instead exposed as one parameter literally named ``query``, so none of
    the real fields ever reached the handler.  Declaring the types also lets
    FastAPI convert the incoming strings (for example ``streaming_mode=false``
    to a real ``bool``).

    ``aux_ref_audio_paths`` may be repeated in the query string to pass
    several values.

    Returns:
        Whatever ``_tts_handle`` returns for the assembled request dict.
    """
    # When this coroutine is called directly (not through FastAPI) the default
    # is still the ``Query`` marker object rather than a list; treat it as unset.
    if not isinstance(aux_ref_audio_paths, list):
        aux_ref_audio_paths = None

    query = {
        "text": text,
        "text_lang": text_lang,
        "ref_audio_path": ref_audio_path,
        "aux_ref_audio_paths": aux_ref_audio_paths,
        "prompt_lang": prompt_lang,
        "prompt_text": prompt_text,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "text_split_method": text_split_method,
        "batch_size": batch_size,
        "batch_threshold": batch_threshold,
        "split_bucket": split_bucket,
        "speed_factor": speed_factor,
        "fragment_interval": fragment_interval,
        "seed": seed,
        "media_type": media_type,
        "streaming_mode": streaming_mode,
        "parallel_infer": parallel_infer,
        "repetition_penalty": repetition_penalty,
        "sample_steps": sample_steps,
        "super_sampling": super_sampling,
    }

    # Language codes are matched in lower case.
    for k in ("text_lang", "prompt_lang"):
        if query[k] is not None:
            query[k] = query[k].lower()
    return await _tts_handle(query)
|
|
|
|
|
@APP.post("/tts") |
|
async def tts_post(request: TTSRequest):
    """Handle ``POST /tts``: turn the body into a dict and delegate to ``_tts_handle``.

    Non-empty ``text_lang`` / ``prompt_lang`` values are lower-cased first.
    """
    payload = request.dict()
    for lang_key in ("text_lang", "prompt_lang"):
        lang_value = payload.get(lang_key)
        if lang_value:
            payload[lang_key] = lang_value.lower()
    return await _tts_handle(payload)
|
|
|
|
|
@APP.get("/control") |
|
async def control(command: str | None = None):
    """Handle ``GET /control``: restart or stop the server process.

    Args:
        command: ``"restart"`` re-executes the interpreter with the original
            command line; ``"exit"`` sends SIGTERM to this process.

    Returns:
        ``{"message": "ok"}`` after an ``"exit"`` request has been signalled.

    Raises:
        HTTPException: Status 400 when *command* is missing or not one of the
            two supported values.
    """
    if not command:
        raise HTTPException(status_code=400, detail="command is required")
    if command not in ("restart", "exit"):
        raise HTTPException(status_code=400, detail="unsupported command")

    if command == "restart":
        # Replaces the current process image; on success this never returns.
        os.execl(sys.executable, sys.executable, *sys.argv)
    else:
        os.kill(os.getpid(), signal.SIGTERM)
    return {"message": "ok"}
|
|
|
|
|
@APP.get("/set_refer_audio") |
|
async def set_refer_audio(refer_audio_path: str | None = None):
    """Handle ``GET /set_refer_audio``: install a new reference audio in the pipeline.

    Args:
        refer_audio_path: Local path, or an http(s):// / s3:// URL that is
            downloaded to a temporary file first.

    Returns:
        ``{"message": "success"}`` on success, or a 400 ``JSONResponse`` when
        the path is missing or the pipeline rejects the audio.  Pipeline
        errors are reported as 400 JSON, like the other ``set_*`` endpoints,
        instead of escaping as an unhandled server error.

    Raises:
        HTTPException: Status 400 when downloading a remote file fails
            (propagated unchanged from ``process_audio_path``).
    """
    if not refer_audio_path:
        return JSONResponse(status_code=400, content={"message": "refer_audio_path is required"})

    temp_files: List[str] = []
    try:
        local_path, is_tmp = process_audio_path(refer_audio_path)
        if is_tmp:
            temp_files.append(local_path)
        TTS_PIPELINE.set_ref_audio(local_path)
    except HTTPException:
        # Already a well-formed client error; let FastAPI render it.
        raise
    except Exception as exc:
        return JSONResponse(
            status_code=400,
            content={"message": "set refer audio failed", "Exception": str(exc)},
        )
    finally:
        # Shared helper: tolerant file removal plus CUDA cache release, so a
        # failing delete can no longer mask the real error.
        _cleanup(temp_files)
    return {"message": "success"}
|
|
|
|
|
@APP.get("/set_gpt_weights") |
|
async def set_gpt_weights(weights_path: str | None = None):
    """Handle ``GET /set_gpt_weights``: load new weights via ``init_t2s_weights``.

    Returns ``{"message": "success"}``, or a 400 ``JSONResponse`` when the path
    is missing or loading raises (the exception text becomes the message).
    """
    if not weights_path:
        return JSONResponse(status_code=400, content={"message": "gpt weight path is required"})

    try:
        TTS_PIPELINE.init_t2s_weights(weights_path)
    except Exception as exc:
        return JSONResponse(status_code=400, content={"message": str(exc)})
    return {"message": "success"}
|
|
|
|
|
@APP.get("/set_sovits_weights") |
|
async def set_sovits_weights(weights_path: str | None = None):
    """Handle ``GET /set_sovits_weights``: load new weights via ``init_vits_weights``.

    Returns ``{"message": "success"}``, or a 400 ``JSONResponse`` when the path
    is missing or loading raises (the exception text becomes the message).
    """
    if not weights_path:
        return JSONResponse(status_code=400, content={"message": "sovits weight path is required"})

    try:
        TTS_PIPELINE.init_vits_weights(weights_path)
    except Exception as exc:
        return JSONResponse(status_code=400, content={"message": str(exc)})
    return {"message": "success"}
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the API with a single worker on the host/port taken from the CLI.
    try:
        uvicorn.run(app=APP, host=HOST, port=PORT, workers=1)
    except Exception:
        traceback.print_exc()
        # Ask the process to terminate after a fatal server error.
        os.kill(os.getpid(), signal.SIGTERM)
        # Only reached if SIGTERM did not end the process; report failure
        # rather than a success status.
        sys.exit(1)
|
|