Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,175 Bytes
1a6d10d 886bd4b 992fb70 1a6d10d d065cac d36d36b 5da485d 1a6d10d f0bd7e5 5da485d 1a6d10d df2821e 1a6d10d 86423a4 1a6d10d 5da485d 1a6d10d 5da485d 886bd4b 1a6d10d 886bd4b 1a6d10d 0d89a98 1a6d10d 5da485d 1a6d10d d36d36b 886bd4b d0b77d7 886bd4b d0b77d7 886bd4b 1a6d10d 5da485d 1a6d10d 5da485d 1a6d10d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import queue
import threading
import spaces
import os
import gradio as gr
from dia.model import Dia
from huggingface_hub import InferenceClient
import numpy as np
from transformers import set_seed
import io, soundfile as sf
# Hardcoded podcast subject
PODCAST_SUBJECT = "The future of AI and its impact on society"
# LLM that writes the podcast script, served through the Cerebras provider.
# Requires an HF_TOKEN environment variable for authentication.
client = InferenceClient("meta-llama/Llama-3.3-70B-Instruct", provider="cerebras", token=os.getenv("HF_TOKEN"))
# Text-to-speech model that voices the script (loaded once at startup).
model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="bfloat16")
# Queue for audio streaming: the worker thread pushes (sample_rate, samples)
# tuples; a trailing None marks end-of-stream (see process_audio_chunks).
audio_queue = queue.Queue()
# Set to ask the worker thread to stop between chunks; cleared on each new run.
stop_signal = threading.Event()
def generate_podcast_text(subject):
    """Have the LLM write a short two-host podcast script about *subject*.

    Returns the raw completion text; speaker turns are tagged [S1]/[S2]
    as requested in the prompt.
    """
    prompt = f"""Generate a podcast told by 2 hosts about {subject}.
The podcast should be an insightful discussion, with some amount of playful banter.
Separate dialog as follows using [S1] for the male host and [S2] for the female host, for instance:
[S1] Hello, how are you?
[S2] I'm good, thank you. How are you?
[S1] I'm good, thank you. (laughs)
[S2] Great.
Now go on, make 5 minutes of podcast.
"""
    messages = [{"role": "user", "content": prompt}]
    completion = client.chat_completion(messages, max_tokens=1000)
    return completion.choices[0].message.content
def split_podcast_into_chunks(podcast_text, chunk_size=3):
    """Split the script into chunks of at most *chunk_size* lines each.

    The text is stripped of surrounding whitespace first; each chunk is the
    joined group of consecutive lines, so chunk boundaries fall on newlines.
    """
    lines = podcast_text.strip().split("\n")
    chunks = []
    for start in range(0, len(lines), chunk_size):
        chunks.append("\n".join(lines[start:start + chunk_size]))
    return chunks
def postprocess_audio(output_audio_np, speed_factor: float = 0.8):
    """Slow down (or speed up) raw model audio and convert it for Gradio.

    Adapted from https://huggingface.co/spaces/nari-labs/Dia-1.6B/blob/main/app.py.
    Returns a (sample_rate, samples) tuple; float samples are clipped to
    [-1, 1] and converted to int16 so Gradio does not warn about dtype.
    """
    sr = 44100  # sample rate of the loaded DAC model
    n_samples = len(output_audio_np)

    # Clamp the factor to a sane range before computing the target length.
    factor = max(0.1, min(speed_factor, 5.0))
    n_target = int(n_samples / factor)

    if n_target > 0 and n_target != n_samples:
        # Linear interpolation onto a resampled time grid.
        grid_in = np.arange(n_samples)
        grid_out = np.linspace(0, n_samples - 1, n_target)
        audio = np.interp(grid_out, grid_in, output_audio_np).astype(np.float32)
        print(
            f"Resampled audio from {n_samples} to {n_target} samples for {factor:.2f}x speed."
        )
    else:
        # No length change needed (or invalid target) — keep the input as-is.
        audio = output_audio_np
        print(f"Skipping audio speed adjustment (factor: {factor:.2f}).")

    print(
        f"Audio conversion successful. Final shape: {audio.shape}, Sample Rate: {sr}"
    )

    # Explicitly convert float audio to int16 to prevent a Gradio warning.
    if audio.dtype in (np.float32, np.float64):
        audio = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
        print("Converted audio to int16 for Gradio output.")

    return (sr, audio)
@spaces.GPU
def process_audio_chunks(podcast_text):
    """Synthesize speech for the script chunk-by-chunk onto audio_queue.

    Intended to run in a background thread (see stream_audio_generator).
    Each chunk is voiced by the Dia model, post-processed, and pushed as a
    (sample_rate, int16_samples) tuple. A trailing None sentinel signals
    end-of-stream. Checks stop_signal between chunks so the UI can abort.

    Fix: removed the unused local `sample_rate = 44100` whose comment
    contradicted itself ("Modified from ... has 44100"); the actual sample
    rate is set inside postprocess_audio.
    """
    chunks = split_podcast_into_chunks(podcast_text)
    for chunk in chunks:
        print(f"Processing chunk: {chunk}")
        if stop_signal.is_set():
            break
        # Fixed seed keeps the synthesized voices consistent across chunks.
        set_seed(42)
        raw_audio = model.generate(
            chunk,
            use_torch_compile=False,  # To avoid gradio instability
            verbose=False,
            temperature=1.3,
            top_p=0.95,
        )
        audio_chunk_np = np.array(raw_audio, dtype=np.float32)
        audio_queue.put(postprocess_audio(audio_chunk_np))
    # Sentinel: tells the consumer there are no more chunks coming.
    audio_queue.put(None)
def stream_audio_generator(podcast_text):
    """Yield WAV-encoded byte blobs as the background worker produces audio.

    Spawns process_audio_chunks in a thread and drains audio_queue until the
    None sentinel arrives, encoding each (sample_rate, samples) tuple into an
    in-memory WAV file so the browser can play the bytes directly.
    """
    stop_signal.clear()
    worker = threading.Thread(target=process_audio_chunks, args=(podcast_text,))
    worker.start()

    while True:
        item = audio_queue.get()
        if item is None:
            # Worker finished (or was stopped) — end the stream.
            break

        sr, samples = item
        # Scale int16 samples back to [-1, 1) floats and encode as WAV.
        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, samples.astype(np.float32) / 32768.0, sr, format="wav")
        wav_buffer.seek(0)
        payload = wav_buffer.getvalue()
        print("PRINTING BUFFER:", payload)
        yield payload
def stop_generation():
    """Ask the audio worker thread to halt and return a status message for the UI."""
    stop_signal.set()
    return "Generation stopped"
def generate_podcast():
    """Produce a fresh podcast script for the hard-coded subject."""
    return generate_podcast_text(PODCAST_SUBJECT)
# --- Gradio UI ---
# NOTE(review): indentation was lost in this copy; nesting below is
# reconstructed from the widget order — confirm against the original layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# NotebookLM Podcast Generator")
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown(f"## Current Topic: {PODCAST_SUBJECT}")
            gr.Markdown("This app generates a podcast discussion between two hosts about the specified topic.")
            # Script generation controls.
            generate_btn = gr.Button("Generate Podcast Script", variant="primary")
            podcast_output = gr.Textbox(label="Generated Podcast Script", lines=15)
            gr.Markdown("## Audio Preview")
            gr.Markdown("Click below to hear the podcast with realistic voices:")
            with gr.Row():
                start_audio_btn = gr.Button("▶️ Generate Podcast", variant="secondary")
                stop_btn = gr.Button("⏹️ Stop", variant="stop")
            # streaming=True lets stream_audio_generator yield WAV bytes incrementally.
            audio_output = gr.Audio(label="Podcast Audio", streaming=True)
            status_text = gr.Textbox(label="Status", visible=True)
    # Wire the buttons: script generation, streaming audio playback, stop control.
    generate_btn.click(fn=generate_podcast, outputs=podcast_output)
    start_audio_btn.click(fn=stream_audio_generator, inputs=podcast_output, outputs=audio_output)
    stop_btn.click(fn=stop_generation, outputs=status_text)
if __name__ == "__main__":
    # queue() is required so the generator-based streaming endpoint works.
    demo.queue().launch()
|