open-notebooklm / app.py
m-ric's picture
m-ric HF Staff
Working Nari labs code
5da485d
raw
history blame
6.18 kB
import queue
import threading
import spaces
import os
import gradio as gr
from dia.model import Dia
from huggingface_hub import InferenceClient
import numpy as np
from transformers import set_seed
import io, soundfile as sf
# Topic every generated podcast discusses (fixed for this demo).
PODCAST_SUBJECT = "The future of AI and its impact on society"

# LLM that writes the script, reached through the "cerebras" inference provider.
# The access token is read from the HF_TOKEN environment variable.
client = InferenceClient(
    "meta-llama/Llama-3.3-70B-Instruct",
    provider="cerebras",
    token=os.getenv("HF_TOKEN"),
)

# Dia text-to-speech model, loaded once when this module is imported.
model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="bfloat16")

# State shared by the background synthesis worker and the streaming generator:
# the worker puts (sample_rate, samples) tuples followed by a final None, and
# stop_signal asks the worker to stop before its next chunk.
audio_queue = queue.Queue()
stop_signal = threading.Event()
def generate_podcast_text(subject):
    """Ask the hosted LLM for a two-host podcast script about *subject*.

    The prompt requests dialogue tagged ``[S1]`` (male host) and ``[S2]``
    (female host); its example shows one speaker turn per line, which is the
    layout ``split_podcast_into_chunks`` later groups by line. The completion
    is capped at 1000 tokens.

    Returns:
        The text content of the first completion choice.
    """
    prompt = f"""Generate a podcast told by 2 hosts about {subject}.
The podcast should be an insightful discussion, with some amount of playful banter.
Separate dialog as follows using [S1] for the male host and [S2] for the female host, for instance:
[S1] Hello, how are you?
[S2] I'm good, thank you. How are you?
[S1] I'm good, thank you. (laughs)
[S2] Great.
Now go on, make 5 minutes of podcast.
"""
    conversation = [{"role": "user", "content": prompt}]
    completion = client.chat_completion(conversation, max_tokens=1000)
    first_choice = completion.choices[0]
    return first_choice.message.content
def split_podcast_into_chunks(podcast_text, chunk_size=3):
    """Group the non-blank lines of a podcast script into chunks for TTS.

    Args:
        podcast_text: Script text, one speaker turn per line (e.g. "[S1] ...").
        chunk_size: Maximum number of lines per chunk; must be >= 1.

    Returns:
        A list of strings, each holding up to ``chunk_size`` stripped lines
        joined by newlines. Blank or whitespace-only lines (common in LLM
        output) are dropped, so they neither waste a chunk slot nor produce a
        chunk with no dialogue. An empty or all-blank script yields ``[]``.

    Raises:
        ValueError: if ``chunk_size`` is less than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    # splitlines() also copes with "\r\n" endings, which split("\n") would
    # leave as a trailing "\r" on every line.
    lines = [line.strip() for line in podcast_text.splitlines() if line.strip()]
    return ["\n".join(lines[i : i + chunk_size]) for i in range(0, len(lines), chunk_size)]
def postprocess_audio(output_audio_np, speed_factor: float=0.8):
    """Time-stretch a waveform by linear resampling and convert it to int16.

    Adapted from https://huggingface.co/spaces/nari-labs/Dia-1.6B/blob/main/app.py

    The sample rate is always reported as 44100 Hz. Resampling to more samples
    therefore plays back slower and, because this is plain resampling, also
    lower in pitch.

    Args:
        output_audio_np: 1-D sample array (np.interp requires 1-D input).
            Float samples are presumably in [-1, 1]; TODO confirm against Dia.
        speed_factor: Desired speed, clamped to [0.1, 5.0]. Values below 1
            slow the audio down.

    Returns:
        ``(sample_rate, samples)``. ``samples`` is int16 whenever the audio
        reaching the final step is float32/float64; other dtypes pass through
        unchanged.
    """
    sample_rate = 44100  # fixed output rate; not queried from the codec model

    # Clamp the factor to avoid degenerate stretches.
    speed = max(0.1, min(speed_factor, 5.0))
    n_in = len(output_audio_np)
    n_out = int(n_in / speed)

    if n_out > 0 and n_out != n_in:
        positions = np.linspace(0, n_in - 1, n_out)
        samples = np.interp(positions, np.arange(n_in), output_audio_np).astype(np.float32)
        print(f"Resampled audio from {n_in} to {n_out} samples for {speed:.2f}x speed.")
    else:
        # No length change (or nothing to resample): keep the input as-is.
        samples = output_audio_np
        print(f"Skipping audio speed adjustment (factor: {speed:.2f}).")

    print(f"Audio conversion successful. Final shape: {samples.shape}, Sample Rate: {sample_rate}")

    # Hand Gradio int16 PCM rather than float audio to avoid its warning.
    if samples.dtype in (np.float32, np.float64):
        samples = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
        print("Converted audio to int16 for Gradio output.")

    return sample_rate, samples
@spaces.GPU
def process_audio_chunks(podcast_text):
    """Synthesize *podcast_text* chunk by chunk and push the audio onto ``audio_queue``.

    Intended to run on the background thread started by
    ``stream_audio_generator``. Each chunk of script lines is rendered with
    Dia, passed through ``postprocess_audio`` and queued as a
    ``(sample_rate, samples)`` tuple.

    A final ``None`` sentinel is always queued, even when synthesis raises.
    Without it the consumer loop in ``stream_audio_generator`` would block
    forever. Any exception still propagates after the sentinel is queued.

    ``stop_signal`` is checked before each chunk, so a stop request takes
    effect at the next chunk boundary; the chunk in progress is finished first.
    """
    try:
        for chunk in split_podcast_into_chunks(podcast_text):
            print(f"Processing chunk: {chunk}")
            if stop_signal.is_set():
                break
            # Re-seed before every chunk so each is sampled from the same RNG
            # state (presumably to keep the voices consistent across chunks;
            # confirm intent).
            set_seed(42)
            raw_audio = model.generate(
                chunk,
                use_torch_compile=False,  # To avoid gradio instability
                verbose=False,
                temperature=1.3,
                top_p=0.95,
            )
            audio_chunk_np = np.array(raw_audio, dtype=np.float32)
            audio_queue.put(postprocess_audio(audio_chunk_np))
    finally:
        # End-of-stream sentinel for stream_audio_generator.
        audio_queue.put(None)
def stream_audio_generator(podcast_text):
    """Yield the podcast audio as a sequence of WAV-encoded byte blobs.

    Starts ``process_audio_chunks`` on a background thread. Each
    ``(sample_rate, samples)`` tuple it queues is re-encoded as an in-memory
    WAV file and yielded, until the ``None`` end-of-stream sentinel arrives.

    NOTE(review): ``audio_queue`` and ``stop_signal`` are module-level, so
    concurrent sessions share them. This function only guards against items
    left over from a previous run whose consumer went away.
    """
    stop_signal.clear()
    # Discard leftovers from an earlier run whose generator was abandoned
    # (e.g. the browser disconnected while the worker kept producing).
    # Otherwise stale audio would be replayed first, and a stale None
    # sentinel would end this stream before any new audio arrives.
    while True:
        try:
            audio_queue.get_nowait()
        except queue.Empty:
            break
    threading.Thread(target=process_audio_chunks, args=(podcast_text,)).start()
    while True:
        chunk = audio_queue.get()
        if chunk is None:
            break
        sr, data = chunk
        # The worker queues postprocess_audio output, which converts float
        # audio to int16 PCM; rescale to float before writing the WAV blob.
        buf = io.BytesIO()
        sf.write(buf, data.astype(np.float32) / 32768.0, sr, format="wav")
        buffer = buf.getvalue()  # whole buffer, regardless of stream position
        # Log the size only; printing the raw bytes floods the console.
        print(f"Yielding WAV chunk: {len(buffer)} bytes")
        yield buffer  # bytes, so the browser can play it
def stop_generation():
    """Request that the background synthesis worker stop, and report it.

    This only sets ``stop_signal``. The worker checks it before starting each
    chunk, and audio already queued is still delivered by the streaming
    generator.
    """
    message = "Generation stopped"
    stop_signal.set()
    return message
def generate_podcast():
    """Gradio callback: return a freshly generated script about ``PODCAST_SUBJECT``."""
    return generate_podcast_text(PODCAST_SUBJECT)
# Gradio UI. Components render in creation order, so statement order here
# defines the page layout.
# NOTE(review): nesting reconstructed; confirm the "Audio Preview" section
# belongs inside the same column as the script controls.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# NotebookLM Podcast Generator")
    with gr.Row():
        with gr.Column(scale=2):
            # Script generation section.
            gr.Markdown(f"## Current Topic: {PODCAST_SUBJECT}")
            gr.Markdown("This app generates a podcast discussion between two hosts about the specified topic.")
            generate_btn = gr.Button("Generate Podcast Script", variant="primary")
            # Holds the generated script; it is also the input to audio synthesis.
            podcast_output = gr.Textbox(label="Generated Podcast Script", lines=15)
            # Audio synthesis / playback section.
            gr.Markdown("## Audio Preview")
            gr.Markdown("Click below to hear the podcast with realistic voices:")
            with gr.Row():
                start_audio_btn = gr.Button("▶️ Generate Podcast", variant="secondary")
                stop_btn = gr.Button("⏹️ Stop", variant="stop")
            # streaming=True lets the player consume the bytes yielded by
            # stream_audio_generator chunk by chunk.
            audio_output = gr.Audio(label="Podcast Audio", streaming=True)
            status_text = gr.Textbox(label="Status", visible=True)
    # Event wiring.
    generate_btn.click(fn=generate_podcast, outputs=podcast_output)
    # Synthesizes whatever script is currently in the textbox.
    start_audio_btn.click(fn=stream_audio_generator, inputs=podcast_output, outputs=audio_output)
    # Only sets stop_signal; the running stream event is not cancelled here and
    # ends when the worker queues its None sentinel.
    stop_btn.click(fn=stop_generation, outputs=status_text)
if __name__ == "__main__":
    # Enable Gradio's request queue on the Blocks app, then start the server.
    queued_demo = demo.queue()
    queued_demo.launch()