File size: 6,175 Bytes
1a6d10d
 
886bd4b
992fb70
1a6d10d
 
 
d065cac
d36d36b
5da485d
 
1a6d10d
 
 
 
 
f0bd7e5
5da485d
1a6d10d
 
 
 
 
 
 
 
 
 
 
 
 
 
df2821e
1a6d10d
 
 
 
 
86423a4
1a6d10d
5da485d
1a6d10d
5da485d
886bd4b
 
 
1a6d10d
886bd4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a6d10d
 
0d89a98
1a6d10d
5da485d
1a6d10d
 
d36d36b
886bd4b
 
d0b77d7
886bd4b
d0b77d7
 
886bd4b
 
 
1a6d10d
 
 
 
 
 
 
5da485d
1a6d10d
5da485d
 
 
 
 
 
 
 
 
 
 
 
 
1a6d10d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import queue
import threading
import spaces
import os
import gradio as gr
from dia.model import Dia
from huggingface_hub import InferenceClient
import numpy as np
from transformers import set_seed
import io, soundfile as sf


# Hardcoded podcast subject; it is interpolated into the LLM prompt by
# generate_podcast() and shown in the UI header.
PODCAST_SUBJECT = "The future of AI and its impact on society"

# Chat-completion client for Llama-3.3-70B-Instruct through the "cerebras"
# provider. The access token is read from the HF_TOKEN environment variable
# (os.getenv returns None when it is unset).
client = InferenceClient("meta-llama/Llama-3.3-70B-Instruct", provider="cerebras", token=os.getenv("HF_TOKEN"))
# Dia model used by process_audio_chunks() to turn script text into audio.
# NOTE: this runs at import time, so importing the module loads the weights.
model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="bfloat16")

# Hand-off between the audio worker thread and the streaming generator:
# the worker puts (sample_rate, samples) tuples and finally a None sentinel.
audio_queue = queue.Queue()
# Set by stop_generation(); the worker checks it before each chunk.
stop_signal = threading.Event()


def generate_podcast_text(subject):
    """Ask the chat model for a two-host podcast script about ``subject``.

    The prompt instructs the model to tag every line with ``[S1]`` (male
    host) or ``[S2]`` (female host). The request goes through the
    module-level ``client`` and is capped at 1000 generated tokens.

    Args:
        subject: Topic text interpolated into the prompt.

    Returns:
        The message content of the first completion choice.
    """
    instructions = f"""Generate a podcast told by 2 hosts about {subject}.
The podcast should be an insightful discussion, with some amount of playful banter.
Separate dialog as follows using [S1] for the male host and [S2] for the female host, for instance:
[S1] Hello, how are you?
[S2] I'm good, thank you. How are you?
[S1] I'm good, thank you. (laughs)
[S2] Great.
Now go on, make 5 minutes of podcast.
"""
    # Single-turn conversation: the whole request is one user message.
    messages = [{"role": "user", "content": instructions}]
    completion = client.chat_completion(messages, max_tokens=1000)
    first_choice = completion.choices[0]
    return first_choice.message.content


def split_podcast_into_chunks(podcast_text, chunk_size=3):
    """Group the dialog lines of a podcast script into fixed-size chunks.

    Blank and whitespace-only lines are dropped before grouping, so every
    chunk holds up to ``chunk_size`` real dialog lines and no empty chunk is
    ever produced. Surrounding whitespace (including a stray ``\\r`` from
    Windows line endings) is stripped from each kept line.

    Args:
        podcast_text: Script text with one speaker turn per line.
        chunk_size: Maximum number of dialog lines per chunk; must be >= 1.

    Returns:
        A list of chunks, each a string of dialog lines joined by ``"\\n"``.
        The last chunk may be shorter. Empty or blank input returns ``[]``.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    # LLM output frequently separates speaker turns with blank lines; counting
    # them as dialog would yield short or whitespace-only chunks.
    stripped = (line.strip() for line in podcast_text.splitlines())
    lines = [line for line in stripped if line]
    return ["\n".join(lines[i : i + chunk_size]) for i in range(0, len(lines), chunk_size)]

def postprocess_audio(output_audio_np, speed_factor: float=0.8):
    """Stretch generated audio in time and convert float samples to int16.

    Adapted from https://huggingface.co/spaces/nari-labs/Dia-1.6B/blob/main/app.py

    Args:
        output_audio_np: Array of audio samples. It is passed to
            ``np.interp`` as the sample values, so it is expected to be 1-D.
        speed_factor: Speed multiplier, clamped to ``[0.1, 5.0]``. The output
            length is ``int(len(input) / speed_factor)``, so values below 1
            lengthen the audio.

    Returns:
        A ``(sample_rate, samples)`` tuple. The sample rate is always 44100.
        float32/float64 samples are clipped to ``[-1.0, 1.0]`` and scaled to
        int16; any other dtype is returned as-is.
    """
    output_sr = 44100  # fixed output sample rate in Hz

    n_samples = len(output_audio_np)

    # Clamp the factor to [0.1, 5.0] so the target length stays sane.
    upper_bounded = min(speed_factor, 5.0)
    speed_factor = max(0.1, upper_bounded)
    target_len = int(n_samples / speed_factor)

    if target_len > 0 and target_len != n_samples:
        # Linear interpolation onto an evenly spaced grid of target_len points
        # spanning the original sample positions.
        source_positions = np.arange(n_samples)
        sample_points = np.linspace(0, n_samples - 1, target_len)
        stretched = np.interp(sample_points, source_positions, output_audio_np)
        samples = stretched.astype(np.float32)
        print(
            f"Resampled audio from {n_samples} to {target_len} samples for {speed_factor:.2f}x speed."
        )
    else:
        # Nothing to stretch (same length, or empty input): keep the samples.
        samples = output_audio_np
        print(f"Skipping audio speed adjustment (factor: {speed_factor:.2f}).")

    print(
        f"Audio conversion successful. Final shape: {samples.shape}, Sample Rate: {output_sr}"
    )

    # Float audio is handed to Gradio as int16 to avoid its conversion warning.
    if samples.dtype in (np.float32, np.float64):
        clipped = np.clip(samples, -1.0, 1.0)
        samples = (clipped * 32767).astype(np.int16)
        print("Converted audio to int16 for Gradio output.")

    return output_sr, samples


@spaces.GPU
def process_audio_chunks(podcast_text):
    """Synthesize the script chunk by chunk and push the audio onto the queue.

    Run on a worker thread started by ``stream_audio_generator``. For each
    chunk from ``split_podcast_into_chunks`` the Dia model generates audio,
    which is post-processed and put on ``audio_queue`` as a
    ``(sample_rate, int16_samples)`` tuple. Generation stops early once
    ``stop_signal`` is set.

    A ``None`` sentinel is always enqueued at the end, even when generation
    raises, because the consumer blocks on ``audio_queue.get()`` until it
    sees that sentinel.

    Args:
        podcast_text: Full podcast script, one speaker turn per line.
    """
    try:
        for chunk in split_podcast_into_chunks(podcast_text):
            if stop_signal.is_set():
                break
            if not chunk.strip():
                # Nothing to speak (e.g. blank lines in the script).
                continue
            print(f"Processing chunk: {chunk}")
            # Re-seed before every chunk; presumably this keeps the speaker
            # voices consistent from one chunk to the next (TODO confirm).
            set_seed(42)
            raw_audio = model.generate(
                chunk,
                use_torch_compile=False, # To avoid gradio instability
                verbose=False,
                temperature=1.3,
                top_p=0.95,
            )
            audio_chunk_np = np.array(raw_audio, dtype=np.float32)
            audio_queue.put(postprocess_audio(audio_chunk_np))
    finally:
        # Without this the streaming generator would wait forever after a
        # failure in model.generate() or post-processing.
        audio_queue.put(None)


def stream_audio_generator(podcast_text):
    """Yield WAV-encoded audio chunks for the podcast as they are generated.

    Starts ``process_audio_chunks`` on a background thread and relays every
    ``(sample_rate, int16_samples)`` tuple it puts on ``audio_queue`` as the
    bytes of a standalone WAV file, until the ``None`` sentinel arrives.

    Args:
        podcast_text: Full podcast script, one speaker turn per line.

    Yields:
        ``bytes`` holding one WAV file per generated audio chunk.
    """
    # Discard leftovers from an earlier run whose consumer went away before
    # reaching the sentinel; otherwise stale audio would be replayed and a
    # stale None would end this stream immediately.
    while True:
        try:
            audio_queue.get_nowait()
        except queue.Empty:
            break

    stop_signal.clear()
    # Daemon thread: an in-flight generation must not keep the process alive
    # at shutdown.
    worker = threading.Thread(target=process_audio_chunks, args=(podcast_text,), daemon=True)
    worker.start()

    try:
        while True:
            chunk = audio_queue.get()
            if chunk is None:
                break
            sr, data = chunk  # (sample_rate, int16 samples) from postprocess_audio

            # Encode the samples into an in-memory WAV file; the int16 values
            # are scaled back to floats in [-1, 1) for soundfile.
            buf = io.BytesIO()
            sf.write(buf, data.astype(np.float32) / 32768.0, sr, format="wav")
            wav_bytes = buf.getvalue()
            # Log only the size: dumping the raw payload floods stdout.
            print(f"Streaming WAV chunk: {len(wav_bytes)} bytes")
            yield wav_bytes  # bytes, so the browser can play it
    finally:
        # Runs on normal completion and when the consumer closes the generator
        # early; in the latter case it tells the worker to stop producing.
        stop_signal.set()


def stop_generation():
    """Request that audio generation stop and return a status message.

    Sets the module-level ``stop_signal`` event, which the audio worker
    checks before starting each chunk.

    Returns:
        The text shown in the UI status box.
    """
    status_message = "Generation stopped"
    stop_signal.set()
    return status_message


def generate_podcast():
    """Generate the podcast script for the hardcoded ``PODCAST_SUBJECT``.

    Returns:
        The script text produced by ``generate_podcast_text``.
    """
    return generate_podcast_text(PODCAST_SUBJECT)


# Gradio UI: one column with the script generator on top and the streamed
# audio preview below it. Components are laid out in creation order.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# NotebookLM Podcast Generator")

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown(f"## Current Topic: {PODCAST_SUBJECT}")
            gr.Markdown("This app generates a podcast discussion between two hosts about the specified topic.")

            # Step 1: produce the script text with the LLM.
            generate_btn = gr.Button("Generate Podcast Script", variant="primary")
            podcast_output = gr.Textbox(label="Generated Podcast Script", lines=15)

            gr.Markdown("## Audio Preview")
            gr.Markdown("Click below to hear the podcast with realistic voices:")

            # Step 2: synthesize and stream audio for the script shown above.
            with gr.Row():
                start_audio_btn = gr.Button("▶️ Generate Podcast", variant="secondary")
                stop_btn = gr.Button("⏹️ Stop", variant="stop")

            # Streaming audio player fed by the stream_audio_generator generator.
            audio_output = gr.Audio(label="Podcast Audio", streaming=True)
            status_text = gr.Textbox(label="Status", visible=True)

    # Event wiring. generate_podcast takes no inputs: the topic is hardcoded.
    generate_btn.click(fn=generate_podcast, outputs=podcast_output)

    # The current contents of the script textbox are what gets synthesized.
    start_audio_btn.click(fn=stream_audio_generator, inputs=podcast_output, outputs=audio_output)
    # Stop only sets the stop event and reports a status message.
    stop_btn.click(fn=stop_generation, outputs=status_text)

# Script entry point: call demo.queue() and then launch the app.
if __name__ == "__main__":
    demo.queue().launch()