ceymox committed on
Commit e1ff5d6 · verified · 1 Parent(s): 97ce168

Create app.py

Files changed (1)
  1. app.py +483 -0
app.py ADDED
@@ -0,0 +1,483 @@
import os
import time
import torch
import numpy as np
import soundfile as sf
import tempfile
import uuid
import logging
from typing import Optional, Dict, Any
from pathlib import Path

import gradio as gr
import spaces
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Import ChatterboxTTS
from chatterbox.src.chatterbox.tts import ChatterboxTTS

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Device configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"🚀 Running on device: {DEVICE}")

# Global model variable
MODEL = None

# Storage for generated audio
AUDIO_DIR = "generated_audio"
os.makedirs(AUDIO_DIR, exist_ok=True)
audio_cache = {}

def get_or_load_model():
    """Load ChatterboxTTS model if not already loaded"""
    global MODEL
    if MODEL is None:
        logger.info("Loading ChatterboxTTS model...")
        try:
            MODEL = ChatterboxTTS.from_pretrained(DEVICE)
            if hasattr(MODEL, 'to'):
                MODEL.to(DEVICE)
            logger.info("✅ ChatterboxTTS model loaded successfully")
        except Exception as e:
            logger.error(f"❌ Error loading model: {e}")
            raise
    return MODEL

def set_seed(seed: int):
    """Set random seed for reproducibility"""
    torch.manual_seed(seed)
    if DEVICE == "cuda":
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

def generate_id():
    """Generate unique ID"""
    return str(uuid.uuid4())

# Pydantic models for API
class TTSRequest(BaseModel):
    text: str
    audio_prompt_url: Optional[str] = "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
    exaggeration: Optional[float] = 0.5
    temperature: Optional[float] = 0.8
    cfg_weight: Optional[float] = 0.5
    seed: Optional[int] = 0

class TTSResponse(BaseModel):
    success: bool
    audio_id: Optional[str] = None
    message: str
    sample_rate: Optional[int] = None
    duration: Optional[float] = None

# Load model at startup
try:
    get_or_load_model()
except Exception as e:
    logger.error(f"Failed to load model on startup: {e}")

@spaces.GPU
def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: str,
    exaggeration_input: float,
    temperature_input: float,
    seed_num_input: int,
    cfgw_input: float
) -> tuple[int, np.ndarray]:
    """
    Generate TTS audio using the ChatterboxTTS model
    """
    current_model = get_or_load_model()

    if current_model is None:
        raise RuntimeError("TTS model is not loaded")

    if seed_num_input != 0:
        set_seed(int(seed_num_input))

    logger.info(f"🎵 Generating audio for: '{text_input[:50]}...'")

    try:
        wav = current_model.generate(
            text_input[:300],  # Limit text length
            audio_prompt_path=audio_prompt_path_input,
            exaggeration=exaggeration_input,
            temperature=temperature_input,
            cfg_weight=cfgw_input,
        )

        logger.info("✅ Audio generation complete")
        # Move the waveform to CPU before converting to NumPy (it may be a CUDA tensor)
        return (current_model.sr, wav.squeeze(0).cpu().numpy())

    except Exception as e:
        logger.error(f"❌ Audio generation failed: {e}")
        raise

# FastAPI app for API endpoints
app = FastAPI(
    title="ChatterboxTTS API",
    description="High-quality text-to-speech synthesis using ChatterboxTTS",
    version="1.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
async def root():
    """API status endpoint"""
    return {
        "service": "ChatterboxTTS API",
        "version": "1.0.0",
        "status": "operational" if MODEL else "model_loading",
        "model_loaded": MODEL is not None,
        "device": DEVICE,
        "endpoints": {
            "synthesize": "/api/tts/synthesize",
            "audio": "/api/audio/{audio_id}",
            "health": "/health"
        }
    }

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy" if MODEL else "unhealthy",
        "model_loaded": MODEL is not None,
        "device": DEVICE,
        "timestamp": time.time()
    }

@app.post("/api/tts/synthesize", response_model=TTSResponse)
async def synthesize_speech(request: TTSRequest):
    """
    Synthesize speech from text
    """
    try:
        if MODEL is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        if not request.text.strip():
            raise HTTPException(status_code=400, detail="Text cannot be empty")

        if len(request.text) > 500:
            raise HTTPException(status_code=400, detail="Text too long (max 500 characters)")

        start_time = time.time()

        # Generate audio
        sample_rate, audio_data = generate_tts_audio(
            request.text,
            request.audio_prompt_url,
            request.exaggeration,
            request.temperature,
            request.seed,
            request.cfg_weight
        )

        generation_time = time.time() - start_time

        # Save audio file
        audio_id = generate_id()
        audio_path = os.path.join(AUDIO_DIR, f"{audio_id}.wav")
        sf.write(audio_path, audio_data, sample_rate)

        # Cache audio info
        audio_cache[audio_id] = {
            "path": audio_path,
            "text": request.text,
            "sample_rate": sample_rate,
            "duration": len(audio_data) / sample_rate,
            "generated_at": time.time(),
            "generation_time": generation_time
        }

        logger.info(f"✅ Audio saved: {audio_id} ({generation_time:.2f}s)")

        return TTSResponse(
            success=True,
            audio_id=audio_id,
            message="Speech synthesized successfully",
            sample_rate=sample_rate,
            duration=len(audio_data) / sample_rate
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Synthesis failed: {e}")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")

@app.get("/api/audio/{audio_id}")
async def get_audio(audio_id: str):
    """
    Download generated audio file
    """
    if audio_id not in audio_cache:
        raise HTTPException(status_code=404, detail="Audio not found")

    audio_info = audio_cache[audio_id]
    audio_path = audio_info["path"]

    if not os.path.exists(audio_path):
        raise HTTPException(status_code=404, detail="Audio file not found on disk")

    def iterfile():
        with open(audio_path, "rb") as f:
            yield from f

    return StreamingResponse(
        iterfile(),
        media_type="audio/wav",
        headers={
            "Content-Disposition": f"attachment; filename=tts_{audio_id}.wav"
        }
    )

@app.get("/api/audio/{audio_id}/info")
async def get_audio_info(audio_id: str):
    """
    Get audio file information
    """
    if audio_id not in audio_cache:
        raise HTTPException(status_code=404, detail="Audio not found")

    return audio_cache[audio_id]

@app.get("/api/audio")
async def list_audio():
    """
    List all generated audio files
    """
    return {
        "audio_files": [
            {
                "audio_id": audio_id,
                "text": info["text"][:50] + "..." if len(info["text"]) > 50 else info["text"],
                "duration": info["duration"],
                "generated_at": info["generated_at"]
            }
            for audio_id, info in audio_cache.items()
        ],
        "total": len(audio_cache)
    }

# Gradio interface
def create_gradio_interface():
    """Create simple Gradio interface"""

    with gr.Blocks(title="ChatterboxTTS", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🎵 ChatterboxTTS

        High-quality text-to-speech synthesis with voice cloning capabilities.
        """)

        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    value="Hello, this is ChatterboxTTS. I can generate natural-sounding speech from any text you provide.",
                    label="Text to synthesize (max 300 characters)",
                    max_lines=5,
                    placeholder="Enter your text here..."
                )

                audio_prompt = gr.Textbox(
                    value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
                    label="Reference Audio URL",
                    placeholder="URL to reference audio file"
                )

                with gr.Row():
                    exaggeration = gr.Slider(
                        0.25, 2,
                        step=0.05,
                        label="Exaggeration",
                        value=0.5,
                        info="Controls expressiveness (0.5 = neutral)"
                    )

                    cfg_weight = gr.Slider(
                        0.2, 1,
                        step=0.05,
                        label="CFG Weight",
                        value=0.5,
                        info="Controls pace and clarity"
                    )

                with gr.Accordion("Advanced Settings", open=False):
                    temperature = gr.Slider(
                        0.05, 5,
                        step=0.05,
                        label="Temperature",
                        value=0.8,
                        info="Controls randomness"
                    )

                    seed = gr.Number(
                        value=0,
                        label="Seed (0 = random)",
                        info="Set to non-zero for reproducible results"
                    )

                generate_btn = gr.Button("🎵 Generate Speech", variant="primary")

            with gr.Column():
                audio_output = gr.Audio(label="Generated Speech")

                status_text = gr.Textbox(
                    label="Status",
                    interactive=False,
                    placeholder="Click 'Generate Speech' to start..."
                )

        # Examples
        gr.Examples(
            examples=[
                [
                    "Welcome to our podcast! Today we're discussing the latest developments in artificial intelligence.",
                    "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
                    0.6, 0.8, 0, 0.5
                ],
                [
                    "Good morning! I hope you're having a wonderful day. Let me tell you about our exciting new features.",
                    "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
                    0.7, 0.9, 0, 0.6
                ],
                [
                    "In today's tutorial, we'll learn how to build a machine learning model from scratch using Python.",
                    "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
                    0.4, 0.7, 0, 0.4
                ]
            ],
            inputs=[text_input, audio_prompt, exaggeration, temperature, seed, cfg_weight]
        )

        def generate_speech_ui(text, prompt_url, exag, temp, seed_val, cfg):
            """Generate speech from UI"""
            try:
                if not text.strip():
                    return None, "❌ Please enter some text"

                if len(text) > 300:
                    return None, "❌ Text too long (max 300 characters)"

                start_time = time.time()

                # Generate audio
                sample_rate, audio_data = generate_tts_audio(
                    text, prompt_url, exag, temp, int(seed_val), cfg
                )

                generation_time = time.time() - start_time
                duration = len(audio_data) / sample_rate

                status = f"""✅ Speech generated successfully!

⏱️ Generation time: {generation_time:.2f}s
🎵 Audio duration: {duration:.2f}s
📊 Sample rate: {sample_rate} Hz
🔊 Audio samples: {len(audio_data):,}
"""

                return (sample_rate, audio_data), status

            except Exception as e:
                logger.error(f"UI generation failed: {e}")
                return None, f"❌ Generation failed: {str(e)}"

        generate_btn.click(
            fn=generate_speech_ui,
            inputs=[text_input, audio_prompt, exaggeration, temperature, seed, cfg_weight],
            outputs=[audio_output, status_text]
        )

        # API Documentation
        gr.Markdown("""
        ## 🔌 API Endpoints

        ### POST `/api/tts/synthesize`
        Generate speech from text
        ```json
        {
            "text": "Your text here",
            "audio_prompt_url": "URL to reference audio",
            "exaggeration": 0.5,
            "temperature": 0.8,
            "cfg_weight": 0.5,
            "seed": 0
        }
        ```

        ### GET `/api/audio/{audio_id}`
        Download generated audio file

        ### GET `/api/audio`
        List all generated audio files

        ### GET `/health`
        Check service health
        """)

        # System info
        model_status = "✅ Loaded" if MODEL else "❌ Not Loaded"
        gr.Markdown(f"""
        ### 📊 System Status
        - **Model**: {model_status}
        - **Device**: {DEVICE}
        - **Generated Files**: {len(audio_cache)}
        """)

    return demo

# Main execution
if __name__ == "__main__":
    logger.info("🎉 Starting ChatterboxTTS Service...")

    # Model status
    model_status = "✅ Loaded" if MODEL else "❌ Not Loaded"
    logger.info(f"Model Status: {model_status}")
    logger.info(f"Device: {DEVICE}")

    if os.getenv("SPACE_ID"):
        # Running in Hugging Face Spaces
        logger.info("🏠 Running in Hugging Face Spaces")
        demo = create_gradio_interface()
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True
        )
    else:
        # Local development - run both FastAPI and Gradio
        import uvicorn
        import threading

        def run_fastapi():
            uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")

        # Start FastAPI in background
        api_thread = threading.Thread(target=run_fastapi, daemon=True)
        api_thread.start()

        logger.info("🌐 FastAPI: http://localhost:8000")
        logger.info("📚 API Docs: http://localhost:8000/docs")

        # Start Gradio
        demo = create_gradio_interface()
        demo.launch(share=True)
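
For quick testing of the FastAPI endpoints defined above, here is a minimal client sketch, assuming the server is running locally on port 8000 (the local-development branch of the main block); the base URL, input text, and output filename are illustrative:

```python
# Minimal client sketch for the API above; assumes the server is reachable at
# http://localhost:8000 (illustrative - adjust to your deployment).
import requests

BASE_URL = "http://localhost:8000"

# Request synthesis; the JSON fields mirror the TTSRequest model.
resp = requests.post(
    f"{BASE_URL}/api/tts/synthesize",
    json={
        "text": "Hello from the ChatterboxTTS API.",
        "exaggeration": 0.5,
        "temperature": 0.8,
        "cfg_weight": 0.5,
        "seed": 0,
    },
    timeout=120,
)
resp.raise_for_status()
result = resp.json()  # TTSResponse fields: success, audio_id, message, sample_rate, duration

# Download the generated WAV via the audio endpoint.
audio = requests.get(f"{BASE_URL}/api/audio/{result['audio_id']}", timeout=60)
audio.raise_for_status()
with open("output.wav", "wb") as f:  # illustrative output path
    f.write(audio.content)
```

Note that in the Spaces branch only the Gradio app is launched, so these REST endpoints are served only when the FastAPI server is started (local development).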