chatterbox / handler.py
aiplexdeveloper's picture
Update handler.py
f969673 verified
raw
history blame
1.59 kB
import base64
import io
import logging
from typing import Any, Dict, List

import soundfile as sf
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
class EndpointHandler:
    """Inference-endpoint handler that synthesizes speech with ChatterboxTTS.

    Each request's text is passed to ``ChatterboxTTS.generate`` together with
    the fixed audio prompt at ``AUDIO_PROMPT_PATH``. The generated waveform is
    returned as a base64-encoded WAV file.
    """

    # Audio prompt passed to ``generate`` on every request.
    # NOTE(review): this is a remote URL and it is handed to ``generate`` per
    # call. Confirm the model's audio loader accepts URLs, and consider
    # preparing the prompt once at load time instead of on every request.
    AUDIO_PROMPT_PATH = (
        "https://huggingface.co/aiplexdeveloper/chatterbox/resolve/main/"
        "arjun_das_output_audio.mp3"
    )
    # Values used when a request omits the corresponding field.
    DEFAULT_EXAGGERATION = 0.3
    DEFAULT_CFG_WEIGHT = 0.5

    _logger = logging.getLogger(__name__)

    def __init__(self, path: str = "", device: str = "cuda"):
        """Load the pretrained ChatterboxTTS model.

        Args:
            path: Model directory supplied by the endpoint runtime. Not used by
                this handler; weights come from ``ChatterboxTTS.from_pretrained``.
            device: Device passed to ``from_pretrained`` (default ``"cuda"``,
                matching the previous hard-coded value).

        Raises:
            RuntimeError: If the model fails to load; the original exception
                is chained as ``__cause__``.
        """
        try:
            self.model = ChatterboxTTS.from_pretrained(device=device)
        except Exception as e:
            raise RuntimeError(f"[ERROR] Failed to load model: {e}") from e

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Synthesize speech for a single request.

        Args:
            data: Request payload. ``data["inputs"]`` is either a dict with a
                required non-empty ``"text"`` string plus optional numeric
                ``"exaggeration"`` and ``"cfg_weight"``, or a bare string that
                is used as the text.

        Returns:
            ``[{"audio_base64": <base64-encoded WAV>}]`` on success, or
            ``[{"error": <message>}]`` if validation, generation or encoding
            fails. Failures are reported in the response, never raised.
        """
        try:
            text, exaggeration, cfg_weight = self._parse_inputs(data)
            self._logger.debug(
                "exaggeration=%s cfg_weight=%s", exaggeration, cfg_weight
            )
            wav = self.model.generate(
                text,
                audio_prompt_path=self.AUDIO_PROMPT_PATH,
                exaggeration=exaggeration,
                cfg_weight=cfg_weight,
            )
            return [{"audio_base64": self._encode_wav_base64(wav, self.model.sr)}]
        except Exception as e:
            # Endpoint boundary: report the failure to the client instead of
            # crashing the worker, but keep the full traceback in the logs.
            self._logger.exception("[ERROR] Inference failed: %s", e)
            return [{"error": str(e)}]

    @classmethod
    def _parse_inputs(cls, data: Dict[str, Any]):
        """Validate the request payload and return ``(text, exaggeration, cfg_weight)``.

        Raises:
            ValueError: If ``inputs`` has an unsupported type, ``text`` is
                missing/blank, or a numeric field cannot be parsed as a float.
            TypeError: If a numeric field is explicitly ``null`` or otherwise
                not convertible by ``float()``.
        """
        inputs = data.get("inputs")
        if inputs is None:
            inputs = {}
        elif isinstance(inputs, str):
            # Accept the common ``{"inputs": "some text"}`` payload shape.
            inputs = {"text": inputs}
        elif not isinstance(inputs, dict):
            raise ValueError("'inputs' must be an object or a string.")

        text = inputs.get("text")
        if not isinstance(text, str) or not text.strip():
            raise ValueError(
                "Missing required 'text' input: expected a non-empty string."
            )

        # JSON clients may send numbers as strings; normalise to float so the
        # model never receives a str.
        exaggeration = float(inputs.get("exaggeration", cls.DEFAULT_EXAGGERATION))
        cfg_weight = float(inputs.get("cfg_weight", cls.DEFAULT_CFG_WEIGHT))
        return text, exaggeration, cfg_weight

    @staticmethod
    def _encode_wav_base64(wav: Any, sample_rate: int) -> str:
        """Write ``wav`` to an in-memory WAV file and return it base64-encoded."""
        buffer = io.BytesIO()
        # Assumes ``wav`` is a tensor shaped (channels, samples) — TODO confirm.
        # The transpose gives soundfile the (frames, channels) layout it expects.
        sf.write(buffer, wav.cpu().numpy().T, sample_rate, format="WAV")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")