Michael Hu commited on
Commit
4b0381b
·
1 Parent(s): 875a169

add sample dia app file

Browse files
Files changed (1) hide show
  1. dia_app_gradio.py +378 -0
dia_app_gradio.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ import time
3
+ from pathlib import Path
4
+ from typing import Optional, Tuple
5
+ import spaces
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import soundfile as sf
10
+ import torch
11
+
12
+ from dia.model import Dia
13
+
14
+
15
+ # Load Nari model and config
16
+ print("Loading Nari model...")
17
+ try:
18
+ # Use the function from inference.py
19
+ model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32")
20
+ except Exception as e:
21
+ print(f"Error loading Nari model: {e}")
22
+ raise
23
+
24
+
25
+ @spaces.GPU
26
+ def run_inference(
27
+ text_input: str,
28
+ audio_prompt_input: Optional[Tuple[int, np.ndarray]],
29
+ max_new_tokens: int,
30
+ cfg_scale: float,
31
+ temperature: float,
32
+ top_p: float,
33
+ cfg_filter_top_k: int,
34
+ speed_factor: float,
35
+ ):
36
+ """
37
+ Runs Nari inference using the globally loaded model and provided inputs.
38
+ Uses temporary files for text and audio prompt compatibility with inference.generate.
39
+ """
40
+ # global model, device # Access global model, config, device
41
+
42
+ if not text_input or text_input.isspace():
43
+ raise gr.Error("Text input cannot be empty.")
44
+
45
+ temp_txt_file_path = None
46
+ temp_audio_prompt_path = None
47
+ output_audio = (44100, np.zeros(1, dtype=np.float32))
48
+
49
+ try:
50
+ prompt_path_for_generate = None
51
+ if audio_prompt_input is not None:
52
+ sr, audio_data = audio_prompt_input
53
+ # Check if audio_data is valid
54
+ if (
55
+ audio_data is None or audio_data.size == 0 or audio_data.max() == 0
56
+ ): # Check for silence/empty
57
+ gr.Warning("Audio prompt seems empty or silent, ignoring prompt.")
58
+ else:
59
+ # Save prompt audio to a temporary WAV file
60
+ with tempfile.NamedTemporaryFile(
61
+ mode="wb", suffix=".wav", delete=False
62
+ ) as f_audio:
63
+ temp_audio_prompt_path = f_audio.name # Store path for cleanup
64
+
65
+ # Basic audio preprocessing for consistency
66
+ # Convert to float32 in [-1, 1] range if integer type
67
+ if np.issubdtype(audio_data.dtype, np.integer):
68
+ max_val = np.iinfo(audio_data.dtype).max
69
+ audio_data = audio_data.astype(np.float32) / max_val
70
+ elif not np.issubdtype(audio_data.dtype, np.floating):
71
+ gr.Warning(
72
+ f"Unsupported audio prompt dtype {audio_data.dtype}, attempting conversion."
73
+ )
74
+ # Attempt conversion, might fail for complex types
75
+ try:
76
+ audio_data = audio_data.astype(np.float32)
77
+ except Exception as conv_e:
78
+ raise gr.Error(
79
+ f"Failed to convert audio prompt to float32: {conv_e}"
80
+ )
81
+
82
+ # Ensure mono (average channels if stereo)
83
+ if audio_data.ndim > 1:
84
+ if audio_data.shape[0] == 2: # Assume (2, N)
85
+ audio_data = np.mean(audio_data, axis=0)
86
+ elif audio_data.shape[1] == 2: # Assume (N, 2)
87
+ audio_data = np.mean(audio_data, axis=1)
88
+ else:
89
+ gr.Warning(
90
+ f"Audio prompt has unexpected shape {audio_data.shape}, taking first channel/axis."
91
+ )
92
+ audio_data = (
93
+ audio_data[0]
94
+ if audio_data.shape[0] < audio_data.shape[1]
95
+ else audio_data[:, 0]
96
+ )
97
+ audio_data = np.ascontiguousarray(
98
+ audio_data
99
+ ) # Ensure contiguous after slicing/mean
100
+
101
+ # Write using soundfile
102
+ try:
103
+ sf.write(
104
+ temp_audio_prompt_path, audio_data, sr, subtype="FLOAT"
105
+ ) # Explicitly use FLOAT subtype
106
+ prompt_path_for_generate = temp_audio_prompt_path
107
+ print(
108
+ f"Created temporary audio prompt file: {temp_audio_prompt_path} (orig sr: {sr})"
109
+ )
110
+ except Exception as write_e:
111
+ print(f"Error writing temporary audio file: {write_e}")
112
+ raise gr.Error(f"Failed to save audio prompt: {write_e}")
113
+
114
+ # 3. Run Generation
115
+
116
+ start_time = time.time()
117
+
118
+ # Use torch.inference_mode() context manager for the generation call
119
+ with torch.inference_mode():
120
+ output_audio_np = model.generate(
121
+ text_input,
122
+ max_tokens=max_new_tokens,
123
+ cfg_scale=cfg_scale,
124
+ temperature=temperature,
125
+ top_p=top_p,
126
+ cfg_filter_top_k=cfg_filter_top_k, # Pass the value here
127
+ use_torch_compile=False, # Keep False for Gradio stability
128
+ audio_prompt=prompt_path_for_generate,
129
+ )
130
+
131
+ end_time = time.time()
132
+ print(f"Generation finished in {end_time - start_time:.2f} seconds.")
133
+
134
+ # 4. Convert Codes to Audio
135
+ if output_audio_np is not None:
136
+ # Get sample rate from the loaded DAC model
137
+ output_sr = 44100
138
+
139
+ # --- Slow down audio ---
140
+ original_len = len(output_audio_np)
141
+ # Ensure speed_factor is positive and not excessively small/large to avoid issues
142
+ speed_factor = max(0.1, min(speed_factor, 5.0))
143
+ target_len = int(
144
+ original_len / speed_factor
145
+ ) # Target length based on speed_factor
146
+ if (
147
+ target_len != original_len and target_len > 0
148
+ ): # Only interpolate if length changes and is valid
149
+ x_original = np.arange(original_len)
150
+ x_resampled = np.linspace(0, original_len - 1, target_len)
151
+ resampled_audio_np = np.interp(x_resampled, x_original, output_audio_np)
152
+ output_audio = (
153
+ output_sr,
154
+ resampled_audio_np.astype(np.float32),
155
+ ) # Use resampled audio
156
+ print(
157
+ f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed."
158
+ )
159
+ else:
160
+ output_audio = (
161
+ output_sr,
162
+ output_audio_np,
163
+ ) # Keep original if calculation fails or no change
164
+ print(f"Skipping audio speed adjustment (factor: {speed_factor:.2f}).")
165
+ # --- End slowdown ---
166
+
167
+ print(
168
+ f"Audio conversion successful. Final shape: {output_audio[1].shape}, Sample Rate: {output_sr}"
169
+ )
170
+
171
+ # Explicitly convert to int16 to prevent Gradio warning
172
+ if (
173
+ output_audio[1].dtype == np.float32
174
+ or output_audio[1].dtype == np.float64
175
+ ):
176
+ audio_for_gradio = np.clip(output_audio[1], -1.0, 1.0)
177
+ audio_for_gradio = (audio_for_gradio * 32767).astype(np.int16)
178
+ output_audio = (output_sr, audio_for_gradio)
179
+ print("Converted audio to int16 for Gradio output.")
180
+
181
+ else:
182
+ print("\nGeneration finished, but no valid tokens were produced.")
183
+ # Return default silence
184
+ gr.Warning("Generation produced no output.")
185
+
186
+ except Exception as e:
187
+ print(f"Error during inference: {e}")
188
+ import traceback
189
+
190
+ traceback.print_exc()
191
+ # Re-raise as Gradio error to display nicely in the UI
192
+ raise gr.Error(f"Inference failed: {e}")
193
+
194
+ finally:
195
+ # 5. Cleanup Temporary Files defensively
196
+ if temp_txt_file_path and Path(temp_txt_file_path).exists():
197
+ try:
198
+ Path(temp_txt_file_path).unlink()
199
+ print(f"Deleted temporary text file: {temp_txt_file_path}")
200
+ except OSError as e:
201
+ print(
202
+ f"Warning: Error deleting temporary text file {temp_txt_file_path}: {e}"
203
+ )
204
+ if temp_audio_prompt_path and Path(temp_audio_prompt_path).exists():
205
+ try:
206
+ Path(temp_audio_prompt_path).unlink()
207
+ print(f"Deleted temporary audio prompt file: {temp_audio_prompt_path}")
208
+ except OSError as e:
209
+ print(
210
+ f"Warning: Error deleting temporary audio prompt file {temp_audio_prompt_path}: {e}"
211
+ )
212
+
213
+ return output_audio
214
+
215
+
216
+ # --- Create Gradio Interface ---
217
+ css = """
218
+ #col-container {max-width: 90%; margin-left: auto; margin-right: auto;}
219
+ """
220
+ # Attempt to load default text from example.txt
221
+ default_text = "[S1] Dia is an open weights text to dialogue model. \n[S2] You get full control over scripts and voices. \n[S1] Wow. Amazing. (laughs) \n[S2] Try it now on Git hub or Hugging Face."
222
+ example_txt_path = Path("./example.txt")
223
+ if example_txt_path.exists():
224
+ try:
225
+ default_text = example_txt_path.read_text(encoding="utf-8").strip()
226
+ if not default_text: # Handle empty example file
227
+ default_text = "Example text file was empty."
228
+ except Exception as e:
229
+ print(f"Warning: Could not read example.txt: {e}")
230
+
231
+
232
+ # Build Gradio UI
233
+ with gr.Blocks(css=css) as demo:
234
+ gr.Markdown("# Nari Text-to-Speech Synthesis")
235
+
236
+ with gr.Row(equal_height=False):
237
+ with gr.Column(scale=1):
238
+ text_input = gr.Textbox(
239
+ label="Input Text",
240
+ placeholder="Enter text here...",
241
+ value=default_text,
242
+ lines=5, # Increased lines
243
+ )
244
+ audio_prompt_input = gr.Audio(
245
+ label="Audio Prompt (Optional)",
246
+ show_label=True,
247
+ sources=["upload", "microphone"],
248
+ type="numpy",
249
+ )
250
+ with gr.Accordion("Generation Parameters", open=False):
251
+ max_new_tokens = gr.Slider(
252
+ label="Max New Tokens (Audio Length)",
253
+ minimum=860,
254
+ maximum=3072,
255
+ value=model.config.data.audio_length, # Use config default if available, else fallback
256
+ step=50,
257
+ info="Controls the maximum length of the generated audio (more tokens = longer audio).",
258
+ )
259
+ cfg_scale = gr.Slider(
260
+ label="CFG Scale (Guidance Strength)",
261
+ minimum=1.0,
262
+ maximum=5.0,
263
+ value=3.0, # Default from inference.py
264
+ step=0.1,
265
+ info="Higher values increase adherence to the text prompt.",
266
+ )
267
+ temperature = gr.Slider(
268
+ label="Temperature (Randomness)",
269
+ minimum=1.0,
270
+ maximum=1.5,
271
+ value=1.3, # Default from inference.py
272
+ step=0.05,
273
+ info="Lower values make the output more deterministic, higher values increase randomness.",
274
+ )
275
+ top_p = gr.Slider(
276
+ label="Top P (Nucleus Sampling)",
277
+ minimum=0.80,
278
+ maximum=1.0,
279
+ value=0.95, # Default from inference.py
280
+ step=0.01,
281
+ info="Filters vocabulary to the most likely tokens cumulatively reaching probability P.",
282
+ )
283
+ cfg_filter_top_k = gr.Slider(
284
+ label="CFG Filter Top K",
285
+ minimum=15,
286
+ maximum=50,
287
+ value=30,
288
+ step=1,
289
+ info="Top k filter for CFG guidance.",
290
+ )
291
+ speed_factor_slider = gr.Slider(
292
+ label="Speed Factor",
293
+ minimum=0.8,
294
+ maximum=1.0,
295
+ value=0.94,
296
+ step=0.02,
297
+ info="Adjusts the speed of the generated audio (1.0 = original speed).",
298
+ )
299
+
300
+ run_button = gr.Button("Generate Audio", variant="primary")
301
+
302
+ with gr.Column(scale=1):
303
+ audio_output = gr.Audio(
304
+ label="Generated Audio",
305
+ type="numpy",
306
+ autoplay=False,
307
+ )
308
+
309
+ # Link button click to function
310
+ run_button.click(
311
+ fn=run_inference,
312
+ inputs=[
313
+ text_input,
314
+ audio_prompt_input,
315
+ max_new_tokens,
316
+ cfg_scale,
317
+ temperature,
318
+ top_p,
319
+ cfg_filter_top_k,
320
+ speed_factor_slider,
321
+ ],
322
+ outputs=[audio_output], # Add status_output here if using it
323
+ api_name="generate_audio",
324
+ )
325
+
326
+ # Add examples (ensure the prompt path is correct or remove it if example file doesn't exist)
327
+ example_prompt_path = "./example_prompt.mp3" # Adjust if needed
328
+ examples_list = [
329
+ [
330
+ "[S1] Oh fire! Oh my goodness! What's the procedure? What to we do people? The smoke could be coming through an air duct! \n[S2] Oh my god! Okay.. it's happening. Everybody stay calm! \n[S1] What's the procedure... \n[S2] Everybody stay fucking calm!!!... Everybody fucking calm down!!!!! \n[S1] No! No! If you touch the handle, if its hot there might be a fire down the hallway! ",
331
+ None,
332
+ 3072,
333
+ 3.0,
334
+ 1.3,
335
+ 0.95,
336
+ 35,
337
+ 0.94,
338
+ ],
339
+ [
340
+ "[S1] Open weights text to dialogue model. \n[S2] You get full control over scripts and voices. \n[S1] I'm biased, but I think we clearly won. \n[S2] Hard to disagree. (laughs) \n[S1] Thanks for listening to this demo. \n[S2] Try it now on Git hub and Hugging Face. \n[S1] If you liked our model, please give us a star and share to your friends. \n[S2] This was Nari Labs.",
341
+ example_prompt_path if Path(example_prompt_path).exists() else None,
342
+ 3072,
343
+ 3.0,
344
+ 1.3,
345
+ 0.95,
346
+ 35,
347
+ 0.94,
348
+ ],
349
+ ]
350
+
351
+ if examples_list:
352
+ gr.Examples(
353
+ examples=examples_list,
354
+ inputs=[
355
+ text_input,
356
+ audio_prompt_input,
357
+ max_new_tokens,
358
+ cfg_scale,
359
+ temperature,
360
+ top_p,
361
+ cfg_filter_top_k,
362
+ speed_factor_slider,
363
+ ],
364
+ outputs=[audio_output],
365
+ fn=run_inference,
366
+ cache_examples=False,
367
+ label="Examples (Click to Run)",
368
+ )
369
+ else:
370
+ gr.Markdown("_(No examples configured or example prompt file missing)_")
371
+
372
+ # --- Launch the App ---
373
+ if __name__ == "__main__":
374
+ print("Launching Gradio interface...")
375
+
376
+ # set `GRADIO_SERVER_NAME`, `GRADIO_SERVER_PORT` env vars to override default values
377
+ # use `GRADIO_SERVER_NAME=0.0.0.0` for Docker
378
+ demo.launch()