import queue
import threading
import spaces
import os
import io
import soundfile as sf
import gradio as gr
import numpy as np
import torch
from huggingface_hub import InferenceClient
from kokoro import KModel, KPipeline

# -----------------------------------------------------------------------------
# Hard‑coded podcast subject
# -----------------------------------------------------------------------------
PODCAST_SUBJECT = "The future of AI and its impact on society"

# -----------------------------------------------------------------------------
# LLM that writes the script
# -----------------------------------------------------------------------------
client = InferenceClient(
    "meta-llama/Llama-3.3-70B-Instruct",
    provider="cerebras",
    token=os.getenv("HF_TOKEN"),
)
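# NOTE: requires a valid HF_TOKEN in the environment with access to the
# "cerebras" inference provider.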

# -----------------------------------------------------------------------------
# Kokoro TTS setup (replaces Dia)
# -----------------------------------------------------------------------------
CUDA_AVAILABLE = torch.cuda.is_available()

kmodel = KModel().to("cuda" if CUDA_AVAILABLE else "cpu").eval()
kpipeline = KPipeline(lang_code="a")  # English voices

MALE_VOICE = "am_michael"  # [S1]
FEMALE_VOICE = "af_heart"  # [S2]
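# Lines without a recognized speaker tag fall back to the female voice
# (see process_audio_chunks below).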

# Pre‑warm voices to avoid first‑call latency
for v in (MALE_VOICE, FEMALE_VOICE):
    kpipeline.load_voice(v)


# Cross-thread handoff (assumes one generation at a time): the worker pushes
# (sample_rate, samples) tuples and a final None sentinel; stop_signal lets
# the UI abort mid-generation.
audio_queue: queue.Queue[tuple[int, np.ndarray] | None] = queue.Queue()
stop_signal = threading.Event()


def generate_podcast_text(subject: str) -> str:
    """Ask the LLM for a ~5‑minute two‑host script."""
    prompt = f"""Generate a podcast told by 2 hosts about {subject}.
The podcast should be an insightful discussion, with some amount of playful banter.
Separate dialog as follows using [S1] for the male host and [S2] for the female host, for instance:
[S1] Hello, how are you?
[S2] I'm good, thank you. How are you?
[S1] I'm good, thank you. (laughs)
[S2] Great.
Now go on, make 5 minutes of podcast.
"""
    response = client.chat_completion(
        [{"role": "user", "content": prompt}],
        max_tokens=1000,
    )
    return response.choices[0].message.content


@spaces.GPU
def process_audio_chunks(podcast_text: str, speed: float = 1.0) -> None:
    """Read each line, pick voice via tag, send chunks to the queue."""
    lines = [l for l in podcast_text.strip().splitlines() if l.strip()]

    pipeline = kpipeline
    # load_voice returns the voice-pack tensor; it is indexed below by
    # phoneme-sequence length to select the matching reference style vector.
    pipeline_voice_female = pipeline.load_voice(FEMALE_VOICE)
    pipeline_voice_male = pipeline.load_voice(MALE_VOICE)

    for line in lines:
        if stop_signal.is_set():
            break

        # Expect "[S1] ..." or "[S2] ..."
        if line.startswith("[S1]"):
            pipeline_voice = pipeline_voice_male
            voice = MALE_VOICE
            utterance = line[len("[S1]"):].strip()
        elif line.startswith("[S2]"):
            pipeline_voice = pipeline_voice_female
            voice = FEMALE_VOICE
            utterance = line[len("[S2]"):].strip()
        else:  # fallback
            pipeline_voice = pipeline_voice_female
            voice = FEMALE_VOICE
            utterance = line

        first = True
        for _, ps, _ in pipeline(utterance, voice, speed):
            # Pick the reference style vector matching this phoneme length.
            ref_s = pipeline_voice[len(ps) - 1]
            audio = kmodel(ps, ref_s, speed)
            # Move samples off the GPU before handing them to the queue.
            audio_queue.put((24000, audio.detach().cpu().numpy()))
            if first:
                first = False
                # Follow the first chunk with a tiny silence to nudge the
                # player to start without waiting for the next line.
                audio_queue.put((24000, np.zeros(1, dtype=np.float32)))
    audio_queue.put(None)  # Signal end of stream


def stream_audio_generator(podcast_text: str):
    """Run TTS in a background thread and yield WAV chunks as they arrive."""
    stop_signal.clear()
    threading.Thread(
        target=process_audio_chunks, args=(podcast_text,), daemon=True
    ).start()

    while True:
        chunk = audio_queue.get()
        if chunk is None:  # end-of-stream sentinel from the worker
            break
        sr, data = chunk

        # Encode each chunk as a standalone WAV so gr.Audio can stream it.
        buf = io.BytesIO()
        sf.write(buf, data, sr, format="wav")
        buf.seek(0)
        yield buf.getvalue()


def stop_generation():
    # The worker checks stop_signal between dialogue lines, so audio stops
    # after the current line finishes synthesizing.
    stop_signal.set()
    return "Generation stopped"


def generate_podcast():
    return generate_podcast_text(PODCAST_SUBJECT)


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# NotebookLM Podcast Generator")

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown(f"## Current Topic: {PODCAST_SUBJECT}")
            gr.Markdown(
                "This app generates a podcast discussion between two hosts about the specified topic."
            )

            generate_btn = gr.Button("Generate Podcast Script", variant="primary")
            podcast_output = gr.Textbox(label="Generated Podcast Script", lines=15)

            gr.Markdown("## Audio Preview")
            gr.Markdown("Click below to hear the podcast with realistic voices:")

            with gr.Row():
                start_audio_btn = gr.Button("▶️ Generate Podcast", variant="secondary")
                stop_btn = gr.Button("⏹️ Stop", variant="stop")

            audio_output = gr.Audio(label="Podcast Audio", streaming=True)
            status_text = gr.Textbox(label="Status", visible=True)

    generate_btn.click(fn=generate_podcast, outputs=podcast_output)

    start_audio_btn.click(fn=stream_audio_generator, inputs=podcast_output, outputs=audio_output)
    stop_btn.click(fn=stop_generation, outputs=status_text)

if __name__ == "__main__":
    # .queue() enables Gradio's event queue, which is needed for
    # generator-based streaming outputs.
    demo.queue().launch()
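
# To run locally (assumes this file is saved as app.py and HF_TOKEN grants
# access to the meta-llama/Llama-3.3-70B-Instruct provider):
#   pip install gradio kokoro soundfile huggingface_hub spaces torch numpy
#   HF_TOKEN=hf_xxx python app.py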