File size: 7,910 Bytes
0bb0b13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd175ab
 
ac04834
0bb0b13
 
 
 
ac04834
 
 
 
 
0bb0b13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import atexit
import os
import tempfile

import gradio as gr
import httpx
from loguru import logger

# FastAPI endpoint URL - adjust this to match your actual endpoint.
# Read it defensively: os.getenv returns None when unset, and the original
# chained .rstrip('/') would then die with an opaque AttributeError at import
# time. Fail fast with an actionable message instead.
_api_url = os.getenv("API_URL")
if not _api_url:
    raise RuntimeError(
        "API_URL environment variable must be set to the TTS backend base URL "
        "(e.g. http://localhost:8000)"
    )
API_URL = _api_url.rstrip('/')  # normalize so f"{API_URL}/tts" never doubles the slash

# Shared httpx client with retries, bounded timeouts, and connection pooling.
# Reused across requests (closed at exit via atexit below in this module).
client = httpx.Client(
    timeout=httpx.Timeout(
        connect=10.0,    # connection timeout
        read=120.0,      # read timeout (generation can be slow)
        write=10.0,      # write timeout
        pool=None,       # pool timeout (no limit on waiting for a connection)
    ),
    limits=httpx.Limits(
        max_keepalive_connections=5,
        max_connections=10,
        keepalive_expiry=30.0
    ),
    transport=httpx.HTTPTransport(
        retries=3,  # retry transient connection failures
    )
)

def check_api_health():
    """Ping the API root endpoint and report whether it responded OK.

    Returns:
        True when ``GET {API_URL}/`` returns a success status; False when
        the request times out or fails at the HTTP level (logged either way).
    """
    try:
        client.get(f"{API_URL}/").raise_for_status()
    except httpx.TimeoutException as exc:
        logger.error(f"API health check timed out: {str(exc)}")
        return False
    except httpx.HTTPError as exc:
        logger.error(f"API health check failed: {str(exc)}")
        return False
    logger.info("API health check passed")
    return True

def generate_speech(text, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
    """Synthesize *text* via the FastAPI ``/tts`` endpoint.

    Args:
        text: Text to speak; blank/whitespace-only input short-circuits to None.
        temperature: Sampling temperature forwarded to the model.
        top_p: Nucleus-sampling threshold forwarded to the model.
        repetition_penalty: Penalty discouraging repetitive output.
        max_new_tokens: Upper bound on generated audio tokens.
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        Path to a temporary ``.wav`` file with the generated audio, or None
        for empty input. A file path (not raw bytes) is returned because the
        output ``gr.Audio`` component is configured with ``type="filepath"``.

    Raises:
        gr.Error: On an unhealthy API, non-200 response, timeout, or network
            failure, so the UI surfaces the message to the user.
    """
    if not text.strip():
        logger.warning("Empty text input received")
        return None

    try:
        # Check API health first
        if not check_api_health():
            logger.error("API is not healthy, aborting request")
            raise gr.Error("The API service is currently unavailable. Please try again later.")

        # Log input parameters
        logger.info(f"Generating speech for text: {text[:50]}... with params: temp={temperature}, top_p={top_p}, rep_penalty={repetition_penalty}, max_tokens={max_new_tokens}")

        # Prepare the request payload
        payload = {
            "text": text.strip(),
            "return_type": "wav",  # Request WAV format directly
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "max_new_tokens": max_new_tokens
        }

        # Update progress
        progress(0.3, "Sending request to server ...")

        # Make request to FastAPI endpoint
        response = client.post(
            f"{API_URL}/tts",
            json=payload,
            headers={"Content-Type": "application/json"}
        )

        # Log response status
        logger.debug(f"Received response with status {response.status_code} and content-type {response.headers.get('content-type')}")

        if response.status_code == 200:
            logger.info("Successfully generated speech in WAV format")
            # gr.Audio(type="filepath") expects a path, not raw bytes: persist
            # the WAV payload to a temp file Gradio can serve. delete=False so
            # the file survives past this function for Gradio to read.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                tmp.write(response.content)
                return tmp.name

        error_msg = f"API returned error status {response.status_code}"
        logger.error(error_msg)
        raise gr.Error(error_msg)

    except gr.Error:
        # Already a user-facing error raised above; re-raise as-is instead of
        # letting the generic handler below re-wrap it with a misleading prefix.
        raise
    except httpx.TimeoutException as e:
        error_msg = "Request timed out. The server took too long to respond."
        logger.error(f"{error_msg}: {str(e)}")
        raise gr.Error(error_msg)
    except httpx.HTTPError as e:
        error_msg = f"Network error while generating speech: {str(e)}"
        logger.error(error_msg)
        raise gr.Error(error_msg)
    except Exception as e:
        error_msg = f"Error generating speech: {str(e)}"
        # loguru records the active traceback via logger.exception; the
        # stdlib-style exc_info=True kwarg used before is not a loguru API.
        logger.exception(error_msg)
        raise gr.Error(error_msg)

# Close the shared httpx client at interpreter exit so pooled keep-alive
# connections are released cleanly.
atexit.register(client.close)

# Pre-filled examples for the UI. Each row maps positionally onto
# [text_input, temperature, top_p, repetition_penalty, max_new_tokens]
# (Tajik sample sentences; the first is longer, hence the larger token budget).
examples = [
    [
        "Салом, номи ман Али аст ва ман имрӯз мехоҳам ба шумо дар бораи забони тоҷикӣ ва аҳамияти он дар фарҳанги мо нақл кунам.",
        0.6, 0.95, 1.1, 1800
    ],
    [
        "Имрӯз ҳаво хеле хуб аст ва ман қарор додам, ки бо дӯстонам ба боғ равам ва якҷоя вақт гузаронем.",
        0.6, 0.95, 1.1, 1200
    ],
    [
        "Ман забони тоҷикӣ меомӯзам, зеро мехоҳам бо мардумони гуногун сӯҳбат кунам ва фарҳанги онҳоро беҳтар фаҳмам.",
        0.6, 0.95, 1.1, 1200
    ],
    [
        "Лутфан як пиёла чой диҳед, зеро ман имрӯз хеле хаста шудам ва мехоҳам каме истироҳат кунам.",
        0.6, 0.95, 1.1, 1200
    ],
    [
        "Шумо аз куҷо ҳастед ва чӣ гуна ба омӯзиши забони тоҷикӣ шурӯъ кардед?",
        0.6, 0.95, 1.1, 1200
    ],
]

# Build the Gradio UI: text + sampling controls on the left, audio output on
# the right, with examples and button wiring below.
with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
    # Page header and usage notes rendered as markdown.
    gr.Markdown("""
    # 🎵 [Tajik Orpheus Text-to-Speech](https://huggingface.co/re-skill/orpheus-tj-early)
    
    Enter your text below and hear it converted to natural-sounding speech with the Orpheus TTS model. 
    
    ## Tips for better prompts:
    - Short text prompts generally work better than very long phrases
    - Increasing `repetition_penalty` and `temperature` makes the model speak faster.

    ## Note:
    - This is demo of early checkpoint trained only on `35 Hours` of data.
    - The model was not fine-tuned on a specific voice. Hence, you will get different voices every time you run the model.
    
    """)    
    with gr.Row():
        # Left column: input text plus the sampling hyperparameters that are
        # forwarded verbatim to generate_speech.
        with gr.Column(scale=3):
            text_input = gr.Textbox(
                label="Text to speak", 
                placeholder="Enter your text here...",
                lines=5
            )
            
            # Collapsed by default; slider defaults match the examples rows.
            with gr.Accordion("Advanced Settings", open=False):
                temperature = gr.Slider(
                    minimum=0.1, maximum=1.5, value=0.6, step=0.05,
                    label="Temperature", 
                    info="Higher values (0.7-1.0) create more expressive but less stable speech"
                )
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.95, step=0.05,
                    label="Top P", 
                    info="Nucleus sampling threshold"
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.1, step=0.05,
                    label="Repetition Penalty", 
                    info="Higher values discourage repetitive patterns"
                )
                max_new_tokens = gr.Slider(
                    minimum=100, maximum=2000, value=1200, step=100,
                    label="Max Length", 
                    info="Maximum length of generated audio (in tokens)"
                )
            
            with gr.Row():
                submit_btn = gr.Button("Generate Speech", variant="primary")
                clear_btn = gr.Button("Clear")
                
        # Right column: playback of the generated speech.
        with gr.Column(scale=2):
            # NOTE(review): type="filepath" makes Gradio expect a file path
            # from the handler, but generate_speech returns raw WAV bytes —
            # verify the installed Gradio version accepts bytes here.
            audio_output = gr.Audio(
                label="Generated Speech",
                type="filepath"  # Changed from "auto" to "filepath" to handle WAV bytes
            )
            
    # Clickable example rows; cache_examples=False so each click hits the API.
    gr.Examples(
        examples=examples,
        inputs=[text_input, temperature, top_p, repetition_penalty, max_new_tokens],
        outputs=audio_output,
        fn=generate_speech,
        cache_examples=False,
    )
    
    # Wire the generate button to the TTS handler.
    submit_btn.click(
        fn=generate_speech,
        inputs=[text_input, temperature, top_p, repetition_penalty, max_new_tokens],
        outputs=audio_output
    )
    
    # Clear resets both the text box and the audio player.
    clear_btn.click(
        fn=lambda: (None, None),
        inputs=[],
        outputs=[text_input, audio_output]
    )

# Launch the app only when run as a script (not on import).
if __name__ == "__main__":
    demo.queue().launch(share=False, ssr_mode=False)