File size: 1,675 Bytes
8ec2046
 
 
 
 
 
8ab7180
8ec2046
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
544c19c
8ec2046
9b9cb9f
8ec2046
 
 
 
 
 
 
4c6d421
 
9b9cb9f
 
 
8ec2046
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
from typing import Dict, Any, List
import soundfile as sf
import io
import base64
from huggingface_hub import hf_hub_download


class EndpointHandler:
    """Inference-endpoint handler that synthesizes speech with Chatterbox TTS.

    Every request is conditioned on one fixed reference clip (the audio prompt)
    fetched from the Hugging Face Hub when the handler is constructed.
    """

    # Reference voice clip used as ``audio_prompt_path`` for every request.
    AUDIO_PROMPT_REPO_ID = "aiplexdeveloper/chatterbox"
    AUDIO_PROMPT_FILENAME = "arjun_das_output_audio.mp3"
    # Values used when the request omits the corresponding generation knob.
    DEFAULT_EXAGGERATION = 0.3
    DEFAULT_CFG_WEIGHT = 0.5

    def __init__(self, path: str = "", device: str = "cuda"):
        """Load the TTS model and resolve the reference audio prompt once.

        Args:
            path: Not used by this handler; kept so the constructor matches the
                signature the serving toolkit calls (presumably the model
                directory — confirm against the deployment).
            device: Device string forwarded to ``ChatterboxTTS.from_pretrained``.

        Raises:
            RuntimeError: If the model cannot be loaded or the audio prompt
                cannot be downloaded.
        """
        try:
            self.model = ChatterboxTTS.from_pretrained(device=device)
        except Exception as e:
            raise RuntimeError(f"[ERROR] Failed to load model: {e}") from e
        try:
            # Resolve the prompt once at startup rather than hitting the Hub
            # on every request.
            self.audio_prompt_path = hf_hub_download(
                repo_id=self.AUDIO_PROMPT_REPO_ID,
                filename=self.AUDIO_PROMPT_FILENAME,
            )
        except Exception as e:
            raise RuntimeError(f"[ERROR] Failed to download audio prompt: {e}") from e

    @classmethod
    def _parse_inputs(cls, data: Dict[str, Any]) -> tuple:
        """Extract ``(text, exaggeration, cfg_weight)`` from a request payload.

        ``data["inputs"]`` may be the text itself or a dict with a ``"text"``
        key and optional ``"exaggeration"`` / ``"cfg_weight"`` overrides.

        Raises:
            ValueError: If ``inputs`` has an unsupported type, ``text`` is not
                a non-empty string, or a knob cannot be converted to float.
        """
        inputs = data.get("inputs", {})
        if isinstance(inputs, str):
            # Accept the plain-string payload form, e.g. {"inputs": "hello"}.
            inputs = {"text": inputs}
        elif not isinstance(inputs, dict):
            raise ValueError(
                f"'inputs' must be a string or an object, got {type(inputs).__name__}"
            )

        text = inputs.get("text")
        if not isinstance(text, str) or not text.strip():
            raise ValueError("'text' must be a non-empty string")

        exaggeration = float(inputs.get("exaggeration", cls.DEFAULT_EXAGGERATION))
        cfg_weight = float(inputs.get("cfg_weight", cls.DEFAULT_CFG_WEIGHT))
        return text, exaggeration, cfg_weight

    @staticmethod
    def _audio_length_seconds(wav: Any, sample_rate: int) -> float:
        """Return the duration of ``wav`` in seconds.

        Samples are counted along the last axis, so the result is correct for
        any number of leading (e.g. channel) dimensions.
        """
        return wav.shape[-1] / sample_rate

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Synthesize speech for a single request.

        Args:
            data: Request payload; see ``_parse_inputs`` for the accepted shapes.

        Returns:
            A one-element list. On success it holds
            ``{"audio_base64": <base64 WAV>, "audio_length_seconds": <float>}``.
            On any failure it holds ``{"error": <message>}``. Errors are
            reported in the response rather than raised.
        """
        try:
            text, exaggeration, cfg_weight = self._parse_inputs(data)
            wav = self.model.generate(
                text,
                audio_prompt_path=self.audio_prompt_path,
                exaggeration=exaggeration,
                cfg_weight=cfg_weight,
            )

            buffer = io.BytesIO()
            # Transposed so frames come first for soundfile; assumes the model
            # returns (channels, samples) — TODO confirm.
            sf.write(buffer, wav.cpu().numpy().T, self.model.sr, format="WAV")
            audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

            audio_length_seconds = self._audio_length_seconds(wav, self.model.sr)
            return [
                {
                    "audio_base64": audio_base64,
                    "audio_length_seconds": audio_length_seconds,
                }
            ]
        except Exception as e:
            print(f"[ERROR] Inference failed: {e}")
            return [{"error": str(e)}]