import base64
import io
from typing import Any, Dict, List

import soundfile as sf
from huggingface_hub import hf_hub_download

from chatterbox.tts import ChatterboxTTS


class EndpointHandler:
    def __init__(self, path: str = ""):
        try:
            # Load the Chatterbox TTS model onto the GPU once, at endpoint startup.
            self.model = ChatterboxTTS.from_pretrained(device="cuda")
        except Exception as e:
            raise RuntimeError(f"[ERROR] Failed to load model: {e}") from e
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        try:
            # Parse the request payload.
            inputs = data.get("inputs", {})
            text = inputs.get("text")
            if not text:
                return [{"error": "Missing required field 'inputs.text'."}]

            exaggeration = inputs.get("exaggeration", 0.3)
            cfg_weight = inputs.get("cfg_weight", 0.5)
            print(f"[INFO] exaggeration={exaggeration}, cfg_weight={cfg_weight}")

            # Fetch the reference voice clip used as the audio prompt;
            # hf_hub_download caches it locally, so only the first request hits the network.
            audio_prompt_path = hf_hub_download(
                repo_id="aiplexdeveloper/chatterbox",
                filename="arjun_das_output_audio.mp3",
            )
            wav = self.model.generate(
                text,
                audio_prompt_path=audio_prompt_path,
                exaggeration=exaggeration,
                cfg_weight=cfg_weight,
            )

            # Serialize the waveform to an in-memory WAV file, then base64-encode it
            # so it can travel inside the JSON response.
            buffer = io.BytesIO()
            sf.write(buffer, wav.cpu().numpy().T, self.model.sr, format="WAV")
            buffer.seek(0)
            audio_base64 = base64.b64encode(buffer.read()).decode("utf-8")

            # Duration in seconds = number of samples / sample rate.
            audio_length_seconds = wav.squeeze().shape[-1] / self.model.sr

            return [{"audio_base64": audio_base64, "audio_length_seconds": audio_length_seconds}]

        except Exception as e:
            # Surface inference errors to the caller instead of raising.
            print(f"[ERROR] Inference failed: {e}")
            return [{"error": str(e)}]
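

# A minimal local smoke test: a sketch assuming this file is deployed as a Hugging Face
# Inference Endpoints custom handler and that a CUDA GPU plus the model weights are
# available. The payload mirrors what __call__ expects; the output file name
# "sample.wav" and the example text are illustrative only.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler(
        {
            "inputs": {
                "text": "Hello from Chatterbox.",
                "exaggeration": 0.3,
                "cfg_weight": 0.5,
            }
        }
    )
    if "audio_base64" in result[0]:
        # Decode the base64 payload back into a playable WAV file.
        with open("sample.wav", "wb") as f:
            f.write(base64.b64decode(result[0]["audio_base64"]))
        print(f"Wrote sample.wav ({result[0]['audio_length_seconds']:.2f}s of audio)")
    else:
        print(f"[ERROR] {result[0]}")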