File size: 4,044 Bytes
1a6d10d
 
992fb70
1a6d10d
 
 
d065cac
d36d36b
1a6d10d
 
 
 
 
f0bd7e5
d36d36b
1a6d10d
 
 
 
 
 
 
 
 
 
 
 
 
 
df2821e
1a6d10d
 
 
 
 
86423a4
1a6d10d
 
 
 
 
 
 
 
 
 
 
 
0d89a98
1a6d10d
 
 
d36d36b
5f17aa3
d065cac
 
1a6d10d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d065cac
1a6d10d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import queue
import threading
import os
import gradio as gr
from dia.model import Dia
from huggingface_hub import InferenceClient
import numpy as np
from transformers import set_seed

# Hardcoded podcast subject
PODCAST_SUBJECT = "The future of AI and its impact on society"

# Initialize the inference client
# Remote LLM (Llama-3.3-70B-Instruct through the "cerebras" provider) used to write
# the script. The token comes from the HF_TOKEN environment variable (None if unset).
client = InferenceClient("meta-llama/Llama-3.3-70B-Instruct", provider="cerebras", token=os.getenv("HF_TOKEN"))
# Dia-1.6B text-to-speech model, loaded once at import time with float16 compute.
model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")

# Queue for audio streaming
# Carries (sample_rate, audio) tuples produced by process_audio_chunks, terminated
# by a None end-of-stream marker. Module-level, so it is shared by every session.
audio_queue = queue.Queue()
# Stop flag: set by stop_generation, checked by process_audio_chunks before each
# chunk, and cleared by stream_audio_generator when a new stream starts.
stop_signal = threading.Event()


def generate_podcast_text(subject):
    """Ask the hosted LLM to write a two-host podcast script about *subject*.

    The prompt requests ``[S1]`` lines for the male host and ``[S2]`` lines for
    the female host, with a short example of that format.

    Returns:
        The message content of the first completion choice (the raw script).
    """
    instructions = (
        f"Generate a podcast told by 2 hosts about {subject}.\n"
        "The podcast should be an insightful discussion, with some amount of playful banter.\n"
        "Separate dialog as follows using [S1] for the male host and [S2] for the female host, for instance:\n"
        "[S1] Hello, how are you?\n"
        "[S2] I'm good, thank you. How are you?\n"
        "[S1] I'm good, thank you. (laughs)\n"
        "[S2] Great.\n"
        "Now go on, make 5 minutes of podcast.\n"
    )
    messages = [{"role": "user", "content": instructions}]
    completion = client.chat_completion(messages, max_tokens=1000)
    first_choice = completion.choices[0]
    return first_choice.message.content


def split_podcast_into_chunks(podcast_text, chunk_size=3):
    """Group the non-blank lines of a podcast script into TTS-sized chunks.

    Args:
        podcast_text: Script text, one speaker turn per line. ``None`` or an
            empty/whitespace-only string yields no chunks.
        chunk_size: Maximum number of non-blank lines per chunk; must be >= 1.

    Returns:
        A list of strings, each holding up to ``chunk_size`` consecutive
        non-blank, stripped lines joined by newlines.

    Raises:
        ValueError: If ``chunk_size`` is less than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    if not podcast_text:
        return []

    # Drop blank lines: otherwise they use up chunk slots, and an empty script
    # would produce a single "" chunk that gets sent to the TTS model.
    # splitlines() also handles "\r\n" endings, unlike split("\n").
    lines = [line.strip() for line in podcast_text.splitlines() if line.strip()]
    return [
        "\n".join(lines[start : start + chunk_size])
        for start in range(0, len(lines), chunk_size)
    ]


def process_audio_chunks(podcast_text, out_queue=None, cancel_event=None):
    """Synthesize *podcast_text* chunk by chunk and push the audio onto a queue.

    Intended to run on a worker thread. Every item put on the queue is a
    ``(sample_rate, float32 ndarray)`` tuple. A final ``None`` is ALWAYS put,
    even if synthesis raises, so the consumer never blocks forever.

    Args:
        podcast_text: Script text; split into chunks by
            ``split_podcast_into_chunks``.
        out_queue: Queue to fill. Defaults to the module-level ``audio_queue``.
        cancel_event: Optional per-stream ``threading.Event``. When it (or the
            global ``stop_signal``) is set, synthesis stops before the next chunk.
    """
    if out_queue is None:
        out_queue = audio_queue
    # 44100 Hz, as in https://huggingface.co/spaces/nari-labs/Dia-1.6B/blob/main/app.py
    sample_rate = 44100
    try:
        for chunk in split_podcast_into_chunks(podcast_text):
            stop_requested = stop_signal.is_set() or (
                cancel_event is not None and cancel_event.is_set()
            )
            if stop_requested:
                break
            # Reseeded before every chunk; presumably so the voices stay
            # consistent from chunk to chunk -- TODO confirm.
            set_seed(42)
            raw_audio = model.generate(chunk, use_torch_compile=True, verbose=False)
            out_queue.put((sample_rate, np.array(raw_audio, dtype=np.float32)))
    finally:
        # Without this, an exception in model.generate would leave the consumer
        # waiting on queue.get() forever. The exception itself still propagates
        # and is reported by threading.excepthook.
        out_queue.put(None)


def stream_audio_generator(podcast_text):
    """Creates a generator that yields audio chunks for streaming.

    Starts ``process_audio_chunks`` on a background thread and yields each
    ``(sample_rate, audio)`` tuple it produces until the ``None`` end marker.
    Each call uses its own queue, so audio left over from an abandoned stream
    can never leak into (or prematurely end) a later stream.
    """
    stop_signal.clear()

    chunk_queue = queue.Queue()
    cancel_event = threading.Event()
    worker = threading.Thread(
        target=process_audio_chunks,
        args=(podcast_text, chunk_queue, cancel_event),
    )
    worker.start()

    try:
        while True:
            item = chunk_queue.get()
            if item is None:  # worker finished, was stopped, or failed
                return
            yield item
    finally:
        # If the consumer goes away early (generator closed), tell this
        # stream's worker to stop after its current chunk instead of
        # synthesizing the rest of the script for nobody.
        cancel_event.set()


def stop_generation():
    """Request that the running audio synthesis halt.

    Sets the shared ``stop_signal``. The synthesis loop checks that flag before
    each chunk, so the chunk currently being generated still completes.

    Returns:
        The status message displayed in the UI's status box.
    """
    status = "Generation stopped"
    stop_signal.set()
    return status


def generate_podcast():
    """Return a freshly generated podcast script for ``PODCAST_SUBJECT``."""
    return generate_podcast_text(subject=PODCAST_SUBJECT)


# Gradio UI: one button writes the script via the LLM, another streams it as audio.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# NotebookLM Podcast Generator")

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown(f"## Current Topic: {PODCAST_SUBJECT}")
            gr.Markdown("This app generates a podcast discussion between two hosts about the specified topic.")

            # The generated script lands in podcast_output, which is also the
            # input to audio synthesis below.
            generate_btn = gr.Button("Generate Podcast Script", variant="primary")
            podcast_output = gr.Textbox(label="Generated Podcast Script", lines=15)

            gr.Markdown("## Audio Preview")
            gr.Markdown("Click below to hear the podcast with realistic voices:")

            with gr.Row():
                start_audio_btn = gr.Button("▶️ Generate Podcast", variant="secondary")
                stop_btn = gr.Button("⏹️ Stop", variant="stop")

            # streaming=True so the generator handler can deliver audio chunk by chunk.
            audio_output = gr.Audio(label="Podcast Audio", streaming=True)
            status_text = gr.Textbox(label="Status", visible=True)

    generate_btn.click(fn=generate_podcast, outputs=podcast_output)

    # Each (sample_rate, audio) tuple yielded by stream_audio_generator is sent
    # to the streaming audio player.
    start_audio_btn.click(fn=stream_audio_generator, inputs=podcast_output, outputs=audio_output)
    # Stop only sets stop_signal; it does not cancel the streaming event (no
    # `cancels=`). Synthesis halts before the next chunk, but audio that is
    # already queued still streams.
    stop_btn.click(fn=stop_generation, outputs=status_text)

if __name__ == "__main__":
    # Enable Gradio's request queue, presumably for the generator (streaming)
    # handler -- NOTE(review): may be the default on newer Gradio; confirm.
    demo.queue().launch()