Update app.py
Browse files
app.py
CHANGED
@@ -1,10 +1,189 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from dia.model import Dia
|
5 |
+
import warnings
|
6 |
|
7 |
+
# Suppress warnings for cleaner output in the Space logs
# (FutureWarning/UserWarning noise comes from torch and transformers internals).
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Global model variable — populated lazily by load_model_once() on first
# generation request so the Space starts fast.
model = None
|
13 |
+
|
14 |
+
def load_model_once():
    """Load the Dia model once and cache it in the module-level ``model``.

    Returns:
        The loaded Dia model instance (moved to GPU when CUDA is available).

    Raises:
        Exception: whatever ``Dia.from_pretrained`` (or the GPU move) raises;
            the global cache is left unset so a later call can retry cleanly.
    """
    global model
    if model is None:
        print("Loading Dia model... This may take a few minutes.")
        try:
            # Load model with correct parameters for Dia
            loaded = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32")

            # Move model to GPU if available
            if torch.cuda.is_available():
                loaded = loaded.cuda()
                print(f"Model loaded on GPU: {torch.cuda.get_device_name()}")
            else:
                print("Model loaded on CPU")

            # Publish to the global cache only after every step succeeded so a
            # half-initialized model is never returned by a subsequent call.
            # (The original assigned `model` before .cuda(), leaving a broken
            # cached model if the GPU move failed.)
            model = loaded
            print("Model loaded successfully!")

        except Exception as e:
            print(f"Error loading model: {e}")
            raise  # bare raise preserves the original traceback

    return model
|
37 |
+
|
38 |
+
def generate_audio(text, seed=42):
    """Generate audio from text input with error handling.

    Args:
        text: Dialogue text. ``[S1]``/``[S2]`` speaker tags are prepended
            automatically when missing.
        seed: Random seed for reproducible voices; ``None`` leaves the RNG
            state untouched.

    Returns:
        ``((sample_rate, audio_array), status_message)`` on success, or
        ``(None, error_message)`` on failure — matching Gradio's
        ``gr.Audio(type="numpy")`` output contract.
    """
    try:
        # Validate input FIRST — the original cleared the GPU cache and loaded
        # the full 1.6B model before noticing the prompt was empty.
        if not text or not text.strip():
            return None, "β Please enter some text"

        # Clean and format text
        text = text.strip()
        if not text.startswith('[S1]') and not text.startswith('[S2]'):
            text = '[S1] ' + text

        # Clear GPU cache before generation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        current_model = load_model_once()

        # Set seed for reproducibility. "is not None" (not truthiness) so that
        # seed=0 — a perfectly valid seed — is not silently skipped.
        if seed is not None:
            torch.manual_seed(int(seed))
            if torch.cuda.is_available():
                torch.cuda.manual_seed(int(seed))

        print(f"Generating speech for: {text[:100]}...")

        # Generate audio - disable torch compile for stability
        with torch.no_grad():
            audio_output = current_model.generate(
                text,
                use_torch_compile=False,  # Disabled for T4 compatibility
                verbose=False
            )

        # Ensure audio_output is numpy array
        if isinstance(audio_output, torch.Tensor):
            audio_output = audio_output.cpu().numpy()

        # Normalize audio to prevent clipping
        if len(audio_output) > 0:
            max_val = np.max(np.abs(audio_output))
            if max_val > 1.0:
                audio_output = audio_output / max_val * 0.95

        print("β Audio generated successfully!")
        # NOTE(review): 44100 Hz is assumed to be Dia's output sample rate —
        # confirm against the Dia model documentation.
        return (44100, audio_output), "β Audio generated successfully!"

    except torch.cuda.OutOfMemoryError:
        # Free cached GPU blocks so the next, shorter attempt can succeed.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        error_msg = "β GPU memory error. Try shorter text or restart the space."
        print(error_msg)
        return None, error_msg

    except Exception as e:
        error_msg = f"β Error: {str(e)}"
        print(error_msg)
        return None, error_msg
|
97 |
+
|
98 |
+
# Create the Gradio interface - simplified to avoid OAuth triggers
demo = gr.Blocks(title="Dia TTS Demo")

with demo:
    # Page header
    gr.HTML("""
    <div style="text-align: center; padding: 20px;">
        <h1>ποΈ Dia TTS - Ultra-Realistic Text-to-Speech</h1>
        <p style="font-size: 18px; color: #666;">
            Generate multi-speaker, emotion-aware dialogue using the Dia 1.6B model
        </p>
    </div>
    """)

    # Two-column layout: inputs on the left, generated audio + status on the right.
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="π Text to Speech",
                placeholder="[S1] Hello there! How are you today? [S2] I'm doing great, thanks for asking! (laughs)",
                lines=6,
                value="[S1] Welcome to the Dia TTS demo! [S2] This is amazing technology!",
                info="Use [S1] and [S2] for different speakers. Add emotions like (laughs), (sighs), etc."
            )

            seed_input = gr.Number(
                label="π² Random Seed",
                value=42,
                precision=0,  # integer-only input
                info="Same seed = consistent voices"
            )

            generate_btn = gr.Button("π΅ Generate Speech", variant="primary")

        with gr.Column():
            # type="numpy" matches generate_audio's (sample_rate, array) return value.
            audio_output = gr.Audio(
                label="π Generated Audio",
                type="numpy"
            )

            status_text = gr.Textbox(
                label="π Status",
                interactive=False,
                lines=2
            )

    # Connect the button to the function
    generate_btn.click(
        fn=generate_audio,
        inputs=[text_input, seed_input],
        outputs=[audio_output, status_text]
    )

    # Add example buttons
    with gr.Row():
        example_btn1 = gr.Button("π» Podcast", size="sm")
        example_btn2 = gr.Button("π Chat", size="sm")
        example_btn3 = gr.Button("π Drama", size="sm")

    # Example button functions — each fills the textbox with a sample prompt.
    example_btn1.click(
        lambda: "[S1] Welcome to our podcast! [S2] Thanks for having me on the show!",
        outputs=text_input
    )

    example_btn2.click(
        lambda: "[S1] Did you see the game? [S2] Yes! (laughs) It was incredible!",
        outputs=text_input
    )

    example_btn3.click(
        lambda: "[S1] I can't believe you're leaving. (sighs) [S2] I know, it's hard. (sad)",
        outputs=text_input
    )

    # Usage instructions
    gr.HTML("""
    <div style="margin-top: 20px; padding: 15px; background: #f0f8ff; border-radius: 8px;">
        <h3>π‘ Usage Tips:</h3>
        <ul>
            <li><strong>Speaker Tags:</strong> Use [S1] and [S2] to switch between speakers</li>
            <li><strong>Emotions:</strong> Add (laughs), (sighs), (excited), (whispers), (sad), etc.</li>
            <li><strong>Length:</strong> Keep text moderate length (5-20 seconds of speech works best)</li>
            <li><strong>Seeds:</strong> Use the same seed number for consistent voice characteristics</li>
        </ul>

        <p><strong>Supported Emotions:</strong> (laughs), (sighs), (gasps), (excited), (sad), (angry),
        (surprised), (whispers), (shouts), (coughs), (clears throat), (sniffs), (chuckles), (groans)</p>
    </div>
    """)

# Launch with basic configuration
if __name__ == "__main__":
    demo.launch()
|