import os
import random

from dotenv import load_dotenv
from gradio_client import Client, handle_file
from pydub import AudioSegment

load_dotenv()

# Comma-separated pool of HF tokens used to spread ZeroGPU quota across calls.
ZEROGPU_TOKENS = os.getenv("ZEROGPU_TOKENS", "").split(",")

def get_zerogpu_token():
    """Return a random token from the ZeroGPU pool, falling back to HF_TOKEN."""
    if not ZEROGPU_TOKENS or ZEROGPU_TOKENS == [""]:
        return os.getenv("HF_TOKEN")
    return random.choice(ZEROGPU_TOKENS)

# Maps public model names to the provider/model fields expected by the router.
model_mapping = {
    "spark-tts": {
        "provider": "spark",
        "model": "spark-tts",
    },
    "cosyvoice-2.0": {
        "provider": "cosyvoice",
        "model": "cosyvoice_2_0",
    },
    "index-tts": {
        "provider": "bilibili",
        "model": "index-tts",
    },
    "maskgct": {
        "provider": "amphion",
        "model": "maskgct",
    },
    "gpt-sovits-v2-pro-plus": {
        "provider": "gpt-sovits",
        "model": "gpt-sovits-v2-pro-plus",
    },
}

url = "https://tts-agi-tts-router-v2.hf.space/tts"
headers = {
    "accept": "application/json",
    "Content-Type": "application/json",
    "Authorization": f'Bearer {os.getenv("HF_TOKEN")}',
}
data = {"text": "string", "provider": "string", "model": "string"}  # placeholder request body (unused below)
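
# Sketch of calling the router endpoint above directly (assumption: it
# accepts this JSON shape and returns a JSON response; the code below
# bypasses the router and drives each Space through gradio_client instead):
#
#     import requests
#     payload = {"text": "Hello there", **model_mapping["spark-tts"]}
#     resp = requests.post(url, headers=headers, json=payload)
#     resp.raise_for_status()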

def set_client_for_session(space: str, user_token=None):
    """Build a gradio_client.Client for a Space.

    When a user token is given, forward it as the X-IP-Token header so the
    Space attributes the request (and its ZeroGPU quota) to that user;
    otherwise authenticate with a token from our pool.
    """
    if user_token is None:
        return Client(space, hf_token=get_zerogpu_token())
    x_ip_token = user_token
    return Client(space, headers={"X-IP-Token": x_ip_token})

def predict_index_tts(text, user_token=None, reference_audio_path=None):
    client = set_client_for_session("kemuriririn/IndexTTS", user_token=user_token)
    if reference_audio_path:
        prompt = handle_file(reference_audio_path)
    else:
        raise ValueError("index-tts requires reference_audio_path")
    result = client.predict(
        prompt=prompt,
        text=text,
        api_name="/gen_single"
    )
    # /gen_single may return a dict-like component update; unwrap the file path.
    if not isinstance(result, str):
        result = result.get("value")
    print("index-tts result:", result)
    return result

def predict_spark_tts(text, user_token=None, reference_audio_path=None):
    client = set_client_for_session("thunnai/SparkTTS", user_token=user_token)
    prompt_wav = None
    if reference_audio_path:
        prompt_wav = handle_file(reference_audio_path)
    result = client.predict(
        text=text,
        prompt_text=text,
        prompt_wav_upload=prompt_wav,
        prompt_wav_record=prompt_wav,
        api_name="/voice_clone"
    )
    print("spark-tts result:", result)
    return result

def predict_cosyvoice_tts(text, user_token=None, reference_audio_path=None):
    client = set_client_for_session("kemuriririn/CosyVoice2-0.5B", user_token=user_token)
    if not reference_audio_path:
        raise ValueError("cosyvoice-2.0 requires reference_audio_path")
    prompt_wav = handle_file(reference_audio_path)
    # First transcribe the reference audio so it can serve as the prompt text.
    recog_result = client.predict(
        prompt_wav=prompt_wav,
        api_name="/prompt_wav_recognition"
    )
    print("cosyvoice-2.0 prompt_wav_recognition result:", recog_result)
    prompt_text = recog_result if isinstance(recog_result, str) else str(recog_result)
    result = client.predict(
        tts_text=text,
        prompt_text=prompt_text,
        prompt_wav_upload=prompt_wav,
        prompt_wav_record=prompt_wav,
        seed=0,
        stream=False,
        api_name="/generate_audio"
    )
    print("cosyvoice-2.0 result:", result)
    return result

def predict_maskgct(text, user_token=None, reference_audio_path=None):
    client = set_client_for_session("amphion/maskgct", user_token=user_token)
    if not reference_audio_path:
        raise ValueError("maskgct requires reference_audio_path")
    prompt_wav = handle_file(reference_audio_path)
    result = client.predict(
        prompt_wav=prompt_wav,
        target_text=text,
        target_len=-1,
        n_timesteps=25,
        api_name="/predict"
    )
    print("maskgct result:", result)
    return result

def predict_gpt_sovits_v2(text, user_token=None, reference_audio_path=None):
    client = set_client_for_session("kemuriririn/GPT-SoVITS-v2", user_token=user_token)
    if not reference_audio_path:
        raise ValueError("GPT-SoVITS-v2 requires reference_audio_path")
    result = client.predict(
        ref_wav_path=handle_file(reference_audio_path),
        prompt_text="",
        prompt_language="English",
        text=text,
        text_language="English",
        how_to_cut="Slice once every 4 sentences",
        top_k=15,
        top_p=1,
        temperature=1,
        ref_free=False,
        speed=1,
        if_freeze=False,
        inp_refs=[],
        api_name="/get_tts_wav"
    )
    print("gpt-sovits-v2 result:", result)
    return result

def normalize_audio_volume(audio_path):
    """Maximize the volume of an audio file and return the new file's path."""
    # Derive the output path from the input file name.
    file_name, ext = os.path.splitext(audio_path)
    normalized_path = f"{file_name}_normalized{ext}"
    # Read the audio file.
    sound = AudioSegment.from_file(audio_path)
    # Maximize the volume (peak normalization).
    normalized_sound = sound.normalize()
    # Save the processed audio.
    normalized_sound.export(normalized_path, format=ext.replace('.', ''))
    return normalized_path
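
# Note: AudioSegment.normalize() boosts the clip so its peak sits at 0 dBFS
# minus a headroom (0.1 dB by default). A sketch, if clipping on lossy
# re-encodes is a concern (assumption: 1 dB of headroom suffices):
#
#     gentler = AudioSegment.from_file(audio_path).normalize(headroom=1.0)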

def predict_tts(text, model, user_token=None, reference_audio_path=None):
    print(f"Predicting TTS for {model}, user_token: {user_token}, reference_audio_path: {reference_audio_path}")
    # Exceptions: special models that shouldn't be passed to the router
    if model == "index-tts":
        result = predict_index_tts(text, user_token, reference_audio_path)
    elif model == "spark-tts":
        result = predict_spark_tts(text, user_token, reference_audio_path)
    elif model == "cosyvoice-2.0":
        result = predict_cosyvoice_tts(text, user_token, reference_audio_path)
    elif model == "maskgct":
        result = predict_maskgct(text, user_token, reference_audio_path)
    elif model == "gpt-sovits-v2-pro-plus":
        result = predict_gpt_sovits_v2(text, user_token, reference_audio_path)
    else:
        raise ValueError(f"Model {model} not found")
    # Maximize the volume of the generated audio.
    normalized_result = normalize_audio_volume(result)
    return normalized_result

if __name__ == "__main__":
    pass
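    # Minimal smoke test (hypothetical text and reference clip; uncomment
    # and point at a real wav to exercise one backend end to end):
    #     out_path = predict_tts(
    #         "Hello from the arena!",
    #         model="index-tts",
    #         reference_audio_path="samples/reference.wav",  # hypothetical path
    #     )
    #     print("normalized output:", out_path)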