m-ric HF Staff committed
Commit 886bd4b · verified · 1 Parent(s): d316fea

Update app.py

Files changed (1)
  1. app.py +61 -5
app.py CHANGED
@@ -1,5 +1,6 @@
 import queue
 import threading
+import spaces
 import os
 import gradio as gr
 from dia.model import Dia
@@ -12,7 +13,7 @@ PODCAST_SUBJECT = "The future of AI and its impact on society"
 
 # Initialize the inference client
 client = InferenceClient("meta-llama/Llama-3.3-70B-Instruct", provider="cerebras", token=os.getenv("HF_TOKEN"))
-model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
+model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32")
 
 # Queue for audio streaming
 audio_queue = queue.Queue()
@@ -43,7 +44,56 @@ def split_podcast_into_chunks(podcast_text, chunk_size=3):
 
     return chunks
 
-
+def postprocess_audio(output_audio_np, speed_factor: float = 0.94):
+    """Taken from https://huggingface.co/spaces/nari-labs/Dia-1.6B/blob/main/app.py"""
+    # Get sample rate from the loaded DAC model
+    output_sr = 44100
+
+    # --- Slow down audio ---
+    original_len = len(output_audio_np)
+    # Ensure speed_factor is positive and not excessively small/large to avoid issues
+    speed_factor = max(0.1, min(speed_factor, 5.0))
+    target_len = int(
+        original_len / speed_factor
+    )  # Target length based on speed_factor
+    if (
+        target_len != original_len and target_len > 0
+    ):  # Only interpolate if length changes and is valid
+        x_original = np.arange(original_len)
+        x_resampled = np.linspace(0, original_len - 1, target_len)
+        resampled_audio_np = np.interp(x_resampled, x_original, output_audio_np)
+        output_audio = (
+            output_sr,
+            resampled_audio_np.astype(np.float32),
+        )  # Use resampled audio
+        print(
+            f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed."
+        )
+    else:
+        output_audio = (
+            output_sr,
+            output_audio_np,
+        )  # Keep original if calculation fails or no change
+        print(f"Skipping audio speed adjustment (factor: {speed_factor:.2f}).")
+    # --- End slowdown ---
+
+    print(
+        f"Audio conversion successful. Final shape: {output_audio[1].shape}, Sample Rate: {output_sr}"
+    )
+
+    # Explicitly convert to int16 to prevent Gradio warning
+    if (
+        output_audio[1].dtype == np.float32
+        or output_audio[1].dtype == np.float64
+    ):
+        audio_for_gradio = np.clip(output_audio[1], -1.0, 1.0)
+        audio_for_gradio = (audio_for_gradio * 32767).astype(np.int16)
+        output_audio = (output_sr, audio_for_gradio)
+        print("Converted audio to int16 for Gradio output.")
+    return output_audio
+
+
+@spaces.GPU
 def process_audio_chunks(podcast_text):
     chunks = split_podcast_into_chunks(podcast_text)
     sample_rate = 44100  # Same sample rate as https://huggingface.co/spaces/nari-labs/Dia-1.6B/blob/main/app.py
@@ -51,9 +101,15 @@ def process_audio_chunks(podcast_text):
         if stop_signal.is_set():
             break
         set_seed(42)
-        raw_audio = model.generate(chunk, use_torch_compile=True, verbose=False)
-        audio_chunk = np.array(raw_audio, dtype=np.float32)
-        audio_queue.put((sample_rate, audio_chunk))
+        raw_audio = model.generate(
+            chunk,
+            use_torch_compile=False,
+            verbose=False,
+            temperature=1.3,
+            top_p=0.95,
+        )
+        audio_chunk_np = np.array(raw_audio, dtype=np.float32)
+        audio_queue.put(postprocess_audio(audio_chunk_np))
 
     audio_queue.put(None)
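
For anyone reviewing the new postprocess_audio step, here is a minimal standalone sketch of the same mechanism: linear-interpolation time stretching followed by float-to-int16 conversion. The sine-wave input and one-second duration are illustrative only, not part of the commit.

import numpy as np

# Illustrative input: one second of a 440 Hz tone at the 44.1 kHz rate used above.
sr = 44100
t = np.linspace(0, 1.0, sr, endpoint=False)
audio = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)

# Same mechanism as postprocess_audio: stretch by 1/speed_factor via np.interp.
speed_factor = 0.94
target_len = int(len(audio) / speed_factor)
x_original = np.arange(len(audio))
x_resampled = np.linspace(0, len(audio) - 1, target_len)
slowed = np.interp(x_resampled, x_original, audio)

# Float [-1, 1] -> int16, as the commit does before handing audio to Gradio.
pcm16 = (np.clip(slowed, -1.0, 1.0) * 32767).astype(np.int16)
print(len(audio), "->", len(pcm16))  # 44100 -> 46914 samples at 0.94x speed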
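The diff only shows the producer side of the streaming pipeline: process_audio_chunks enqueues (sample_rate, audio) tuples and a final None sentinel. A consumer would typically drain the queue until that sentinel; the sketch below assumes such a consumer (stream_audio is a hypothetical name, not part of this commit).

import queue
import threading

audio_queue = queue.Queue()
stop_signal = threading.Event()

def stream_audio():
    # Hypothetical consumer: yields the (sample_rate, int16 array) tuples
    # produced by process_audio_chunks until it enqueues the None sentinel.
    while not stop_signal.is_set():
        item = audio_queue.get()
        if item is None:  # end-of-podcast sentinel
            break
        yield item  # e.g. streamed to a gr.Audio output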