File size: 1,675 Bytes
8ec2046 8ab7180 8ec2046 544c19c 8ec2046 9b9cb9f 8ec2046 4c6d421 9b9cb9f 8ec2046 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import base64
import io
import logging
from typing import Any, Dict, List

import soundfile as sf
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download
class EndpointHandler:
    """Text-to-speech request handler backed by Chatterbox TTS.

    Each call synthesises the requested text using a voice-prompt clip taken
    from the Hugging Face Hub and returns the result as a base64-encoded WAV.
    """

    # Hub location of the clip passed to the model as ``audio_prompt_path``.
    AUDIO_PROMPT_REPO_ID = "aiplexdeveloper/chatterbox"
    AUDIO_PROMPT_FILENAME = "arjun_das_output_audio.mp3"
    # Generation settings used when the request does not supply them.
    DEFAULT_EXAGGERATION = 0.3
    DEFAULT_CFG_WEIGHT = 0.5

    def __init__(self, path: str = ""):
        """Load the Chatterbox model on the ``cuda`` device.

        Args:
            path: Not used; the model is loaded with
                ``ChatterboxTTS.from_pretrained``. Kept so callers that pass a
                directory keep working.

        Raises:
            RuntimeError: If the model fails to load (original error chained).
        """
        # Local path of the voice-prompt clip; resolved lazily on first request.
        self._audio_prompt_path = None
        try:
            self.model = ChatterboxTTS.from_pretrained(device="cuda")
        except Exception as e:
            raise RuntimeError(f"[ERROR] Failed to load model: {e}") from e

    @classmethod
    def _parse_inputs(cls, data: Dict[str, Any]):
        """Extract ``(text, exaggeration, cfg_weight)`` from a request payload.

        ``data["inputs"]`` may be a dict with a ``text`` key plus optional
        ``exaggeration`` / ``cfg_weight`` keys, or a bare string that is used
        as the text with default settings.

        Raises:
            ValueError: If ``inputs`` is neither a dict nor a string, ``text``
                is missing or not a string, or a setting is not numeric.
            TypeError: If a setting is present but not convertible (e.g. null).
        """
        inputs = data.get("inputs", {})
        if isinstance(inputs, str):
            inputs = {"text": inputs}
        if not isinstance(inputs, dict):
            raise ValueError("'inputs' must be an object or a string")
        text = inputs.get("text")
        if not isinstance(text, str):
            raise ValueError("'inputs.text' is required and must be a string")
        # Clients may send numbers as JSON strings; normalise before generation.
        exaggeration = float(inputs.get("exaggeration", cls.DEFAULT_EXAGGERATION))
        cfg_weight = float(inputs.get("cfg_weight", cls.DEFAULT_CFG_WEIGHT))
        return text, exaggeration, cfg_weight

    def _get_audio_prompt_path(self) -> str:
        """Return the local path of the voice-prompt clip, downloading it once.

        The resolved path is cached on the instance so later requests do not
        repeat the Hub lookup that ``hf_hub_download`` may perform per call.
        """
        path = getattr(self, "_audio_prompt_path", None)
        if path is None:
            path = hf_hub_download(
                repo_id=self.AUDIO_PROMPT_REPO_ID,
                filename=self.AUDIO_PROMPT_FILENAME,
            )
            self._audio_prompt_path = path
        return path

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Synthesise speech for one request.

        Args:
            data: Request payload; see ``_parse_inputs`` for accepted shapes.

        Returns:
            A one-element list holding
            ``{"audio_base64": str, "audio_length_seconds": float}`` on success
            or ``{"error": message}`` on any failure.
        """
        log = logging.getLogger(__name__)
        try:
            text, exaggeration, cfg_weight = self._parse_inputs(data)
            log.debug("exaggeration=%s cfg_weight=%s", exaggeration, cfg_weight)
            wav = self.model.generate(
                text,
                audio_prompt_path=self._get_audio_prompt_path(),
                exaggeration=exaggeration,
                cfg_weight=cfg_weight,
            )
            sample_rate = self.model.sr
            buffer = io.BytesIO()
            # The transpose turns the model output (presumably channels x samples,
            # e.g. [1, N]) into the frames x channels layout soundfile writes.
            sf.write(buffer, wav.cpu().numpy().T, sample_rate, format="WAV")
            audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
            # Samples lie on the last axis; unlike len(wav.squeeze()) this is
            # also correct for multi-channel output and single-sample clips.
            audio_length_seconds = wav.shape[-1] / sample_rate
            return [
                {
                    "audio_base64": audio_base64,
                    "audio_length_seconds": audio_length_seconds,
                }
            ]
        except Exception as e:
            # Deliberate best-effort boundary: report the failure in the
            # response instead of raising, but keep the traceback in the logs.
            log.exception("[ERROR] Inference failed: %s", e)
            return [{"error": str(e)}]
|