ethiotech4848 committed on
Commit
eebf3c4
·
verified ·
1 Parent(s): 79f6575

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +405 -0
app.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ import time
3
+ from pathlib import Path
4
+ from typing import Optional, Tuple
5
+ import spaces
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import soundfile as sf
10
+ import torch
11
+
12
+ from dia.model import Dia
13
+
14
+
15
# Load the Dia ("Nari") checkpoint once at import time; run_inference reads
# this module-level `model` on every request.
print("Loading Nari model...")
try:
    model = Dia.from_pretrained("nari-labs/Dia-1.6B-0626", compute_dtype="float16")
except Exception as load_error:
    # Log the reason, then propagate: the app cannot start without a model.
    print(f"Error loading Nari model: {load_error}")
    raise
23
+
24
+
25
def _prepare_audio_prompt(audio_data: np.ndarray) -> np.ndarray:
    """Return the audio prompt as contiguous mono floating-point samples.

    Integer PCM is scaled by the dtype's maximum value into roughly [-1, 1].
    Two-channel input is averaged; any other multi-dimensional layout is
    reduced to its first channel/axis with a UI warning.

    Raises:
        gr.Error: If a non-numeric dtype cannot be converted to float32.
    """
    if np.issubdtype(audio_data.dtype, np.integer):
        max_val = np.iinfo(audio_data.dtype).max
        audio_data = audio_data.astype(np.float32) / max_val
    elif not np.issubdtype(audio_data.dtype, np.floating):
        gr.Warning(
            f"Unsupported audio prompt dtype {audio_data.dtype}, attempting conversion."
        )
        # Attempt conversion, might fail for complex types
        try:
            audio_data = audio_data.astype(np.float32)
        except Exception as conv_e:
            raise gr.Error(
                f"Failed to convert audio prompt to float32: {conv_e}"
            ) from conv_e

    # Ensure mono (average channels if stereo)
    if audio_data.ndim > 1:
        if audio_data.shape[0] == 2:  # Assume (2, N)
            audio_data = np.mean(audio_data, axis=0)
        elif audio_data.shape[1] == 2:  # Assume (N, 2)
            audio_data = np.mean(audio_data, axis=1)
        else:
            gr.Warning(
                f"Audio prompt has unexpected shape {audio_data.shape}, taking first channel/axis."
            )
            audio_data = (
                audio_data[0]
                if audio_data.shape[0] < audio_data.shape[1]
                else audio_data[:, 0]
            )

    # Slicing/mean may leave a non-contiguous view; soundfile gets a plain buffer.
    return np.ascontiguousarray(audio_data)


def _apply_speed_factor(audio: np.ndarray, speed_factor: float) -> np.ndarray:
    """Time-stretch `audio` by linear interpolation.

    `speed_factor` is clamped to [0.1, 5.0]. The input is returned unchanged
    when the target length equals the original length or would be zero;
    otherwise a float32 array of ``int(len(audio) / speed_factor)`` samples
    is returned.
    """
    original_len = len(audio)
    speed_factor = max(0.1, min(speed_factor, 5.0))
    target_len = int(original_len / speed_factor)

    if target_len == original_len or target_len <= 0:
        print(f"Skipping audio speed adjustment (factor: {speed_factor:.2f}).")
        return audio

    x_original = np.arange(original_len)
    x_resampled = np.linspace(0, original_len - 1, target_len)
    resampled_audio_np = np.interp(x_resampled, x_original, audio)
    print(
        f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed."
    )
    return resampled_audio_np.astype(np.float32)


@spaces.GPU(duration=120)
def run_inference(
    text_input: str,
    audio_prompt_input: Optional[Tuple[int, np.ndarray]],
    transcription_input: Optional[str],
    max_new_tokens: int,
    cfg_scale: float,
    temperature: float,
    top_p: float,
    cfg_filter_top_k: int,
    speed_factor: float,
):
    """
    Runs Nari inference using the globally loaded model and provided inputs.

    The optional audio prompt is normalised and written to a temporary WAV
    file whose path is handed to ``model.generate``; the file is removed
    before returning.

    Args:
        text_input: Script to synthesise; must contain non-whitespace text.
        audio_prompt_input: Optional ``(sample_rate, samples)`` pair. Empty or
            all-zero prompts are ignored with a warning.
        transcription_input: Optional transcript of the audio prompt; when
            non-blank it is joined to the script with a newline.
        max_new_tokens: Passed to ``model.generate`` as ``max_tokens``.
        cfg_scale: Passed through to ``model.generate``.
        temperature: Passed through to ``model.generate``.
        top_p: Passed through to ``model.generate``.
        cfg_filter_top_k: Passed through to ``model.generate``.
        speed_factor: Playback speed multiplier, clamped to [0.1, 5.0].

    Returns:
        ``(44100, samples)``. Float32/float64 output is clipped to [-1, 1] and
        converted to int16. If the model returns ``None`` a single float32
        zero sample is returned.

    Raises:
        gr.Error: For blank text, a prompt longer than 10 seconds, a prompt
            that cannot be converted or saved, or any failure during
            generation (reported as "Inference failed: ...").
    """
    if not text_input or text_input.isspace():
        raise gr.Error("Text input cannot be empty.")

    output_sr = 44100
    temp_audio_prompt_path = None
    output_audio = (output_sr, np.zeros(1, dtype=np.float32))

    try:
        prompt_path_for_generate = None
        if audio_prompt_input is not None:
            sr, audio_data = audio_prompt_input

            # Validate the payload before measuring it: len(None) would raise
            # TypeError and hide the intended "ignore empty prompt" path.
            prompt_is_usable = audio_data is not None and audio_data.size > 0
            if prompt_is_usable:
                # Enforce maximum duration of 10 seconds for the audio prompt
                duration_sec = len(audio_data) / float(sr) if sr else 0
                if duration_sec > 10.0:
                    raise gr.Error("Audio prompt must be 10 seconds or shorter.")
                # Any non-zero sample counts as signal. Testing max() == 0 would
                # also reject prompts whose samples are all <= 0.
                prompt_is_usable = bool(np.any(audio_data))

            if not prompt_is_usable:
                gr.Warning("Audio prompt seems empty or silent, ignoring prompt.")
            else:
                audio_data = _prepare_audio_prompt(audio_data)

                # Reserve a path, then close the handle before soundfile
                # opens the same file for writing.
                with tempfile.NamedTemporaryFile(
                    mode="wb", suffix=".wav", delete=False
                ) as f_audio:
                    temp_audio_prompt_path = f_audio.name  # Store path for cleanup

                try:
                    sf.write(
                        temp_audio_prompt_path, audio_data, sr, subtype="FLOAT"
                    )  # Explicitly use FLOAT subtype
                except Exception as write_e:
                    print(f"Error writing temporary audio file: {write_e}")
                    raise gr.Error(
                        f"Failed to save audio prompt: {write_e}"
                    ) from write_e

                prompt_path_for_generate = temp_audio_prompt_path
                print(
                    f"Created temporary audio prompt file: {temp_audio_prompt_path} (orig sr: {sr})"
                )

        # Run generation
        start_time = time.time()

        with torch.inference_mode():
            # NOTE(review): the transcript is appended AFTER the script here.
            # Dia's voice-cloning guidance appears to expect the prompt
            # transcript BEFORE the script -- confirm before changing order.
            combined_text = (
                text_input.strip() + "\n" + transcription_input.strip()
                if transcription_input and not transcription_input.isspace()
                else text_input
            )

            output_audio_np = model.generate(
                combined_text,
                max_tokens=max_new_tokens,
                cfg_scale=cfg_scale,
                temperature=temperature,
                top_p=top_p,
                cfg_filter_top_k=cfg_filter_top_k,
                use_torch_compile=False,  # Keep False for Gradio stability
                audio_prompt=prompt_path_for_generate,
            )

        end_time = time.time()
        print(f"Generation finished in {end_time - start_time:.2f} seconds.")

        if output_audio_np is not None:
            output_audio = (
                output_sr,
                _apply_speed_factor(output_audio_np, speed_factor),
            )

            print(
                f"Audio conversion successful. Final shape: {output_audio[1].shape}, Sample Rate: {output_sr}"
            )

            # Explicitly convert to int16 to prevent Gradio warning
            if output_audio[1].dtype in (np.float32, np.float64):
                audio_for_gradio = np.clip(output_audio[1], -1.0, 1.0)
                audio_for_gradio = (audio_for_gradio * 32767).astype(np.int16)
                output_audio = (output_sr, audio_for_gradio)
                print("Converted audio to int16 for Gradio output.")
        else:
            print("\nGeneration finished, but no valid tokens were produced.")
            # The default one-sample silence set above is returned.
            gr.Warning("Generation produced no output.")

    except gr.Error:
        # Already a user-facing message; re-wrapping it would turn e.g. the
        # 10-second limit into a confusing "Inference failed: ..." error.
        raise
    except Exception as e:
        print(f"Error during inference: {e}")
        import traceback

        traceback.print_exc()
        # Re-raise as Gradio error to display nicely in the UI
        raise gr.Error(f"Inference failed: {e}") from e

    finally:
        # Best-effort removal of the temporary prompt file; never mask the
        # primary error with a cleanup failure.
        if temp_audio_prompt_path and Path(temp_audio_prompt_path).exists():
            try:
                Path(temp_audio_prompt_path).unlink()
                print(f"Deleted temporary audio prompt file: {temp_audio_prompt_path}")
            except OSError as e:
                print(
                    f"Warning: Error deleting temporary audio prompt file {temp_audio_prompt_path}: {e}"
                )

    return output_audio
226
+
227
+
228
# --- Create Gradio Interface ---
# Page-level stylesheet handed to gr.Blocks.
css = """
#col-container {max-width: 90%; margin-left: auto; margin-right: auto;}
"""

# Built-in sample script; replaced by the contents of ./example.txt when that
# file exists and can be read.
default_text = "[S1] Dia is an open weights text to dialogue model. \n[S2] You get full control over scripts and voices. \n[S1] Wow. Amazing. (laughs) \n[S2] Try it now on Git hub or Hugging Face."
example_txt_path = Path("./example.txt")
if example_txt_path.exists():
    try:
        loaded_text = example_txt_path.read_text(encoding="utf-8").strip()
    except Exception as read_error:
        # Keep the built-in sample when the file cannot be read or decoded.
        print(f"Warning: Could not read example.txt: {read_error}")
    else:
        # A file that is empty after stripping gets a visible placeholder.
        default_text = loaded_text or "Example text file was empty."
242
+
243
+
244
# Build Gradio UI: inputs and generation settings on the left, the generated
# audio on the right, then the click handler and the example rows.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Nari Text-to-Speech Synthesis")

    with gr.Row(equal_height=False):
        # Left column: script, optional voice prompt, sampling parameters.
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter text here...",
                value=default_text,
                lines=5,
            )
            audio_prompt_input = gr.Audio(
                label="Audio Prompt (≤ 10 s, Optional)",
                show_label=True,
                sources=["upload", "microphone"],
                type="numpy",
            )
            transcription_input = gr.Textbox(
                label="Audio Prompt Transcription (Optional)",
                placeholder="Enter transcription of your audio prompt here...",
                lines=3,
            )
            with gr.Accordion("Generation Parameters", open=False):
                max_new_tokens = gr.Slider(
                    label="Max New Tokens (Audio Length)",
                    minimum=860,
                    maximum=3072,
                    # Initial value is read from the loaded model's decoder
                    # config; no fallback is applied here.
                    value=model.config.decoder_config.max_position_embeddings,
                    step=50,
                    info="Controls the maximum length of the generated audio (more tokens = longer audio).",
                )
                cfg_scale = gr.Slider(
                    label="CFG Scale (Guidance Strength)",
                    minimum=1.0,
                    maximum=5.0,
                    value=3.0,
                    step=0.1,
                    info="Higher values increase adherence to the text prompt.",
                )
                temperature = gr.Slider(
                    label="Temperature (Randomness)",
                    minimum=1.0,
                    maximum=2.5,
                    value=1.8,
                    step=0.05,
                    info="Lower values make the output more deterministic, higher values increase randomness.",
                )
                top_p = gr.Slider(
                    label="Top P (Nucleus Sampling)",
                    minimum=0.70,
                    maximum=1.0,
                    value=0.95,
                    step=0.01,
                    info="Filters vocabulary to the most likely tokens cumulatively reaching probability P.",
                )
                cfg_filter_top_k = gr.Slider(
                    label="CFG Filter Top K",
                    minimum=15,
                    maximum=100,
                    value=45,
                    step=1,
                    info="Top k filter for CFG guidance.",
                )
                speed_factor_slider = gr.Slider(
                    label="Speed Factor",
                    minimum=0.8,
                    maximum=1.0,
                    value=1.0,
                    step=0.02,
                    info="Adjusts the speed of the generated audio (1.0 = original speed).",
                )

            run_button = gr.Button("Generate Audio", variant="primary")

        # Right column: playback of the result.
        with gr.Column(scale=1):
            audio_output = gr.Audio(
                label="Generated Audio",
                type="numpy",
                autoplay=False,
            )

    # Components in run_inference's positional parameter order; shared by the
    # button handler and the example rows so the two cannot drift apart.
    inference_inputs = [
        text_input,
        audio_prompt_input,
        transcription_input,
        max_new_tokens,
        cfg_scale,
        temperature,
        top_p,
        cfg_filter_top_k,
        speed_factor_slider,
    ]

    # Link button click to function
    run_button.click(
        fn=run_inference,
        inputs=inference_inputs,
        outputs=[audio_output],
        api_name="generate_audio",
    )

    # Example rows: text, audio prompt path (or None), then max tokens,
    # CFG scale, temperature, top-p, top-k and speed factor.
    example_prompt_path = "./example_prompt.mp3"  # Adjust if needed
    examples_list = [
        [
            "[S1] Oh fire! Oh my goodness! What's the procedure? What to we do people? The smoke could be coming through an air duct! \n[S2] Oh my god! Okay.. it's happening. Everybody stay calm! \n[S1] What's the procedure... \n[S2] Everybody stay fucking calm!!!... Everybody fucking calm down!!!!! \n[S1] No! No! If you touch the handle, if its hot there might be a fire down the hallway! ",
            None,
            3072,
            3.0,
            1.8,
            0.95,
            45,
            1.0,
        ],
        [
            "[S1] Open weights text to dialogue model. \n[S2] You get full control over scripts and voices. \n[S1] I'm biased, but I think we clearly won. \n[S2] Hard to disagree. (laughs) \n[S1] Thanks for listening to this demo. \n[S2] Try it now on Git hub and Hugging Face. \n[S1] If you liked our model, please give us a star and share to your friends. \n[S2] This was Nari Labs.",
            example_prompt_path if Path(example_prompt_path).exists() else None,
            3072,
            3.0,
            1.8,
            0.95,
            45,
            1.0,
        ],
    ]

    if examples_list:
        # Insert an empty transcription between the prompt and the numeric
        # settings so each row lines up with inference_inputs.
        example_rows = [
            [script, prompt, "", *settings]
            for script, prompt, *settings in examples_list
        ]
        gr.Examples(
            examples=example_rows,
            inputs=inference_inputs,
            outputs=[audio_output],
            fn=run_inference,
            cache_examples=False,
            label="Examples (Click to Run)",
        )
    else:
        gr.Markdown("_(No examples configured or example prompt file missing)_")
398
+
399
# --- Launch the App ---
# Start the Gradio server only when this file is run as a script, so that
# importing the module does not launch the interface.
if __name__ == "__main__":
    print("Launching Gradio interface...")

    # set `GRADIO_SERVER_NAME`, `GRADIO_SERVER_PORT` env vars to override default values
    # use `GRADIO_SERVER_NAME=0.0.0.0` for Docker
    demo.launch()