Spaces:
Runtime error
Runtime error
| from fastapi import File, Form, HTTPException, Body, UploadFile | |
| from fastapi.responses import StreamingResponse | |
| import io | |
| from numpy import clip | |
| import soundfile as sf | |
| from pydantic import BaseModel, Field | |
| from fastapi.responses import FileResponse | |
| from modules.synthesize_audio import synthesize_audio | |
| from modules.normalization import text_normalize | |
| from modules import generate_audio as generate | |
| from typing import List, Literal, Optional, Union | |
| import pyrubberband as pyrb | |
| from modules.api import utils as api_utils | |
| from modules.api.Api import APIManager | |
| from modules.speaker import speaker_mgr | |
| from modules.data import styles_mgr | |
| import numpy as np | |
class AudioSpeechRequest(BaseModel):
    # JSON body schema consumed by `openai_speech_api`.

    # Text to synthesize.
    input: str
    model: str = "chattts-4w"
    voice: str = "female2"
    response_format: Literal["mp3", "wav"] = "mp3"
    speed: float = Field(
        default=1,
        ge=0.1,
        le=10,
        description="Speed of the audio",
    )
    seed: int = 42
    temperature: float = 0.3
    style: str = ""

    # Batch synthesis switch: a value of 1 or less means batching is not used.
    # Enabling batch synthesis splits the text into sentences automatically.
    batch_size: int = Field(
        default=1,
        ge=1,
        le=20,
        description="Batch size",
    )
    spliter_threshold: float = Field(
        default=100,
        ge=10,
        le=1024,
        description="Threshold for sentence spliter",
    )
async def openai_speech_api(
    request: AudioSpeechRequest = Body(
        ..., description="JSON body with model, input text, and voice"
    )
):
    """Synthesize speech from text and stream it back as mp3 or wav.

    Args:
        request: Parsed JSON body. ``request.model`` is accepted but not used.

    Returns:
        A ``StreamingResponse`` over the encoded audio, labelled
        ``audio/mp3`` or ``audio/wav`` to match ``response_format``.

    Raises:
        HTTPException: 400 when the input text is empty, the voice is unknown
            to ``speaker_mgr``, or the style lookup raises; 500 when
            normalization, synthesis or encoding fails.
    """
    input_text = request.input
    voice = request.voice
    style = request.style
    response_format = request.response_format
    batch_size = request.batch_size
    spliter_threshold = request.spliter_threshold
    # Defensive guard: keep the stretch rate inside [0.1, 10].
    speed = clip(request.speed, 0.1, 10)

    if not input_text:
        raise HTTPException(status_code=400, detail="Input text is required.")
    if speaker_mgr.get_speaker(voice) is None:
        raise HTTPException(status_code=400, detail="Invalid voice.")
    if style:
        try:
            styles_mgr.find_item_by_name(style)
        except Exception as e:
            # Any failure while resolving the style is reported as a client
            # error; system-exiting exceptions are no longer swallowed.
            raise HTTPException(status_code=400, detail="Invalid style.") from e

    try:
        # Normalize the text
        text = text_normalize(input_text, is_end=True)

        # Calculate speaker and style based on input voice
        params = api_utils.calc_spk_style(spk=voice, style=style)
        spk = params.get("spk", -1)
        # Values present in `params` take precedence over the request.
        # NOTE(review): because of `or`, a request seed/temperature of 0
        # falls back to 42 / 0.3 -- confirm whether that is intended.
        seed = params.get("seed", request.seed or 42)
        temperature = params.get("temperature", request.temperature or 0.3)
        prompt1 = params.get("prompt1", "")
        prompt2 = params.get("prompt2", "")
        prefix = params.get("prefix", "")

        # Generate audio
        sample_rate, audio_data = synthesize_audio(
            text,
            temperature=temperature,
            top_P=0.7,
            top_K=20,
            spk=spk,
            infer_seed=seed,
            batch_size=batch_size,
            spliter_threshold=spliter_threshold,
            prompt1=prompt1,
            prompt2=prompt2,
            prefix=prefix,
        )

        if speed != 1:
            audio_data = pyrb.time_stretch(audio_data, sample_rate, speed)

        # Convert audio data to wav format
        buffer = io.BytesIO()
        sf.write(buffer, audio_data, sample_rate, format="wav")
        buffer.seek(0)

        if response_format == "mp3":
            # Convert wav to mp3
            buffer = api_utils.wav_to_mp3(buffer)
            media_type = "audio/mp3"
        else:
            # The payload is still wav here, so label it as such.
            media_type = "audio/wav"

        return StreamingResponse(buffer, media_type=media_type)
    except HTTPException:
        # Preserve status codes raised on purpose further down the stack.
        raise
    except Exception as e:
        import logging

        logging.exception(e)
        raise HTTPException(status_code=500, detail=str(e)) from e
class TranscribeSegment(BaseModel):
    # Element type of `TranscriptionsVerboseResponse.segments`.

    # Position fields.
    id: int
    seek: float
    start: float
    end: float

    # Segment content.
    text: str
    tokens: list[int]

    # Numeric fields reported alongside the segment.
    temperature: float
    avg_logprob: float
    compression_ratio: float
    no_speech_prob: float
class TranscriptionsVerboseResponse(BaseModel):
    # Response schema: overall fields plus a list of `TranscribeSegment`.

    task: str
    language: str
    duration: float
    text: str

    # Per-segment entries.
    segments: list[TranscribeSegment]
def setup(app: APIManager):
    """Register ``openai_speech_api`` as ``POST /v1/audio/speech`` on ``app``."""
    # Endpoint description passed to `app.post`; runtime text, kept verbatim.
    speech_description = """
openai api document:
[https://platform.openai.com/docs/guides/text-to-speech](https://platform.openai.com/docs/guides/text-to-speech)
以下属性为本系统自定义属性,不在openai文档中:
- batch_size: 是否开启batch合成,小于等于1表示不使用batch (不推荐)
- spliter_threshold: 开启batch合成时,句子分割的阈值
- style: 风格
> model 可填任意值
"""
    register_speech = app.post(
        "/v1/audio/speech",
        response_class=FileResponse,
        description=speech_description,
    )
    register_speech(openai_speech_api)
async def transcribe(
    file: UploadFile = File(...),
    model: str = Form(...),
    language: Optional[str] = Form(None),
    prompt: Optional[str] = Form(None),
    response_format: str = Form("json"),
    temperature: float = Form(0),
    timestamp_granularities: List[str] = Form(["segment"]),
):
    """Placeholder transcription handler.

    None of the form parameters are read yet; the function only returns
    the result of ``api_utils.success_response`` with a fixed message.
    """
    # TODO: Implement transcribe
    placeholder = api_utils.success_response("not implemented yet")
    return placeholder