m-ric (HF Staff) committed
Commit 4af8987 · Parent(s): 5da485d

Working Kokoro

Files changed (1):
app.py +86 -95
app.py CHANGED
@@ -2,27 +2,52 @@ import queue
 import threading
 import spaces
 import os
+import io
+import soundfile as sf
 import gradio as gr
-from dia.model import Dia
-from huggingface_hub import InferenceClient
 import numpy as np
+import torch
 from transformers import set_seed
-import io, soundfile as sf
-
+from huggingface_hub import InferenceClient
+from kokoro import KModel, KPipeline
 
-# Hardcoded podcast subject
+# -----------------------------------------------------------------------------
+# Hard-coded podcast subject
+# -----------------------------------------------------------------------------
 PODCAST_SUBJECT = "The future of AI and its impact on society"
 
-# Initialize the inference client
-client = InferenceClient("meta-llama/Llama-3.3-70B-Instruct", provider="cerebras", token=os.getenv("HF_TOKEN"))
-model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="bfloat16")
+# -----------------------------------------------------------------------------
+# LLM that writes the script (unchanged)
+# -----------------------------------------------------------------------------
+client = InferenceClient(
+    "meta-llama/Llama-3.3-70B-Instruct",
+    provider="cerebras",
+    token=os.getenv("HF_TOKEN"),
+)
+
+# -----------------------------------------------------------------------------
+# Kokoro TTS setup (replaces Dia)
+# -----------------------------------------------------------------------------
+CUDA_AVAILABLE = torch.cuda.is_available()
+
+kmodel = KModel().to("cuda" if CUDA_AVAILABLE else "cpu").eval()
+kpipeline = KPipeline(lang_code="a")  # English voices
 
-# Queue for audio streaming
-audio_queue = queue.Queue()
+MALE_VOICE = "am_michael"  # [S1]
+FEMALE_VOICE = "af_heart"  # [S2]
+
+# Pre-warm voices to avoid first-call latency
+for v in (MALE_VOICE, FEMALE_VOICE):
+    kpipeline.load_voice(v)
+
+
+audio_queue: queue.Queue[tuple[int, np.ndarray] | None] = queue.Queue()
 stop_signal = threading.Event()
 
 
-def generate_podcast_text(subject):
+
+def generate_podcast_text(subject: str) -> str:
+    """Ask the LLM for a ~5-minute two-host script."""
     prompt = f"""Generate a podcast told by 2 hosts about {subject}.
 The podcast should be an insightful discussion, with some amount of playful banter.
 Separate dialog as follows using [S1] for the male host and [S2] for the female host, for instance:
@@ -32,87 +57,53 @@ Separate dialog as follows using [S1] for the male host and [S2] for the female
 [S2] Great.
 Now go on, make 5 minutes of podcast.
 """
-    response = client.chat_completion([{"role": "user", "content": prompt}], max_tokens=1000)
-    return response.choices[0].message.content
-
-
-def split_podcast_into_chunks(podcast_text, chunk_size=3):
-    lines = podcast_text.strip().split("\n")
-    return ["\n".join(lines[i : i + chunk_size]) for i in range(0, len(lines), chunk_size)]
-
-def postprocess_audio(output_audio_np, speed_factor: float=0.8):
-    """Taken from https://huggingface.co/spaces/nari-labs/Dia-1.6B/blob/main/app.py"""
-    # Get sample rate from the loaded DAC model
-    output_sr = 44100
-
-    # --- Slow down audio ---
-    original_len = len(output_audio_np)
-    # Ensure speed_factor is positive and not excessively small/large to avoid issues
-    speed_factor = max(0.1, min(speed_factor, 5.0))
-    target_len = int(
-        original_len / speed_factor
-    )  # Target length based on speed_factor
-    if (
-        target_len != original_len and target_len > 0
-    ):  # Only interpolate if length changes and is valid
-        x_original = np.arange(original_len)
-        x_resampled = np.linspace(0, original_len - 1, target_len)
-        resampled_audio_np = np.interp(x_resampled, x_original, output_audio_np)
-        output_audio = (
-            output_sr,
-            resampled_audio_np.astype(np.float32),
-        )  # Use resampled audio
-        print(
-            f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed."
-        )
-    else:
-        output_audio = (
-            output_sr,
-            output_audio_np,
-        )  # Keep original if calculation fails or no change
-        print(f"Skipping audio speed adjustment (factor: {speed_factor:.2f}).")
-    # --- End slowdown ---
-
-    print(
-        f"Audio conversion successful. Final shape: {output_audio[1].shape}, Sample Rate: {output_sr}"
+    response = client.chat_completion(
+        [{"role": "user", "content": prompt}],
+        max_tokens=1000,
     )
+    return response.choices[0].message.content
 
-    # Explicitly convert to int16 to prevent Gradio warning
-    if (
-        output_audio[1].dtype == np.float32
-        or output_audio[1].dtype == np.float64
-    ):
-        audio_for_gradio = np.clip(output_audio[1], -1.0, 1.0)
-        audio_for_gradio = (audio_for_gradio * 32767).astype(np.int16)
-        output_audio = (output_sr, audio_for_gradio)
-        print("Converted audio to int16 for Gradio output.")
-    return output_audio
-
+@spaces.GPU
+def process_audio_chunks(podcast_text: str, speed: float = 1.0) -> None:
+    """Read each line, pick voice via tag, send chunks to the queue."""
+    lines = [l for l in podcast_text.strip().splitlines() if l.strip()]
 
-@spaces.GPU
-def process_audio_chunks(podcast_text):
-    chunks = split_podcast_into_chunks(podcast_text)
-    sample_rate = 44100  # Modified from https://huggingface.co/spaces/nari-labs/Dia-1.6B/blob/main/app.py has 44100
-    for chunk in chunks:
-        print(f"Processing chunk: {chunk}")
+    pipeline = kpipeline
+    pipeline_voice_female = pipeline.load_voice(FEMALE_VOICE)
+    pipeline_voice_male = pipeline.load_voice(MALE_VOICE)
+
+    for line in lines:
         if stop_signal.is_set():
             break
-        set_seed(42)
-        raw_audio = model.generate(
-            chunk,
-            use_torch_compile=False,  # To avoid gradio instability
-            verbose=False,
-            temperature=1.3,
-            top_p=0.95,
-        )
-        audio_chunk_np = np.array(raw_audio, dtype=np.float32)
-        audio_queue.put(postprocess_audio(audio_chunk_np))
-
-    audio_queue.put(None)
-
-
-def stream_audio_generator(podcast_text):
-    """Creates a generator that yields audio chunks for streaming"""
+
+        # Expect "[S1] ..." or "[S2] ..."
+        if line.startswith("[S1]"):
+            pipeline_voice = pipeline_voice_male
+            voice = MALE_VOICE
+            utterance = line[len("[S1]"):].strip()
+        elif line.startswith("[S2]"):
+            pipeline_voice = pipeline_voice_female
+            voice = FEMALE_VOICE
+            utterance = line[len("[S2]"):].strip()
+        else:  # fallback
+            pipeline_voice = pipeline_voice_female
+            voice = FEMALE_VOICE
+            utterance = line
+
+        first = True
+        for _, ps, _ in pipeline(utterance, voice, speed):
+            ref_s = pipeline_voice[len(ps) - 1]
+            audio = kmodel(ps, ref_s, speed)
+            audio_queue.put((24000, audio.numpy()))
+            audio_numpy = audio.numpy()
+            print("GENERATED AUDIO", audio_numpy[-100:], audio_numpy.max())
+            if first:
+                first = False
+                audio_queue.put((24000, torch.zeros(1).numpy()))
+    audio_queue.put(None)  # Signal end of stream
+
+
+def stream_audio_generator(podcast_text: str):
     stop_signal.clear()
     threading.Thread(target=process_audio_chunks, args=(podcast_text,)).start()
 
@@ -120,15 +111,14 @@ def stream_audio_generator(podcast_text):
         chunk = audio_queue.get()
         if chunk is None:
            break
-        sr, data = chunk  # the tuple you produced earlier
+        print("CHUNK", chunk, type(chunk))
+        sr, data = chunk
 
-        # Encode the numpy array into a WAV blob
         buf = io.BytesIO()
-        sf.write(buf, data.astype(np.float32) / 32768.0, sr, format="wav")
+        sf.write(buf, data, sr, format="wav")
         buf.seek(0)
-        buffer = buf.getvalue()
-        print("PRINTING BUFFER:", buffer)
-        yield buffer  # <-- bytes, so the browser can play it
+        yield buf.getvalue()
+
 
 
 def stop_generation():
@@ -137,8 +127,7 @@ def stop_generation():
 
 
 def generate_podcast():
-    podcast_text = generate_podcast_text(PODCAST_SUBJECT)
-    return podcast_text
+    return generate_podcast_text(PODCAST_SUBJECT)
 
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
@@ -147,7 +136,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     with gr.Row():
         with gr.Column(scale=2):
             gr.Markdown(f"## Current Topic: {PODCAST_SUBJECT}")
-            gr.Markdown("This app generates a podcast discussion between two hosts about the specified topic.")
+            gr.Markdown(
+                "This app generates a podcast discussion between two hosts about the specified topic."
+            )
 
             generate_btn = gr.Button("Generate Podcast Script", variant="primary")
             podcast_output = gr.Textbox(label="Generated Podcast Script", lines=15)
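For reference, the Kokoro call pattern that the new process_audio_chunks relies on can be smoke-tested standalone. Below is a minimal sketch that reuses only the calls visible in this diff (KModel, KPipeline, load_voice, indexing the voice pack by phoneme length, 24 kHz output); the exact signatures are inferred from the code above, not from documented API:

import soundfile as sf
from kokoro import KModel, KPipeline

model = KModel().eval()                       # CPU is fine for a smoke test
pipeline = KPipeline(lang_code="a")           # "a" = English voices, as above
voice_pack = pipeline.load_voice("af_heart")  # the app's female voice

# The pipeline yields phoneme chunks `ps`; the reference style vector is
# looked up by phoneme length, exactly as in process_audio_chunks.
for _, ps, _ in pipeline("Hello and welcome to the show.", "af_heart", 1.0):
    ref_s = voice_pack[len(ps) - 1]
    audio = model(ps, ref_s, 1.0)             # 1-D float tensor at 24 kHz
    sf.write("chunk.wav", audio.numpy(), 24000)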
 
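The diff is cut off before the Gradio event wiring, so it does not show how stream_audio_generator is consumed. One plausible hookup, assuming the yielded WAV bytes feed a streaming gr.Audio component (the component, event chain, and parameters below are assumptions, not part of this commit):

audio_output = gr.Audio(label="Podcast Audio", streaming=True, autoplay=True)

# Generate the script first, then stream synthesized audio from it.
generate_btn.click(generate_podcast, outputs=podcast_output).then(
    stream_audio_generator, inputs=podcast_output, outputs=audio_output
)

Because each iteration of stream_audio_generator yields a complete WAV blob (sf.write emits a header plus samples per chunk), every yielded value is independently decodable, which is what makes chunk-by-chunk playback in the browser possible.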