File size: 5,447 Bytes
74245b5
 
 
 
 
 
2a527b6
74245b5
 
e07f35b
 
74245b5
 
 
04896a2
74245b5
e07f35b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425f9fe
e07f35b
 
425f9fe
e07f35b
 
 
2a527b6
 
 
 
 
 
 
74245b5
2a527b6
 
 
 
 
 
 
 
425f9fe
c06bbb8
 
 
 
 
 
 
 
 
 
 
 
2a527b6
74245b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a527b6
 
 
 
 
 
 
 
 
74245b5
 
 
 
 
 
 
 
2a527b6
74245b5
 
2a527b6
74245b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# app.py

import gradio as gr
import torch
import torchaudio
import google.generativeai as genai
from e2_tts_pytorch import E2TTS, DurationPredictor
import numpy as np
import os
import requests
from tqdm import tqdm

# Initialize Gemini AI.
# SECURITY FIX: read the API key from the environment instead of committing a
# literal to source control; the old placeholder string is kept as a fallback
# so behavior is unchanged when the variable is unset.
genai.configure(api_key=os.environ.get('GEMINI_API_KEY', 'YOUR_GEMINI_API_KEY'))
model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')

# Function to download the model file
def download_model(url, filename):
    """Stream-download *url* to *filename*, showing a tqdm progress bar.

    Args:
        url: HTTP(S) URL of the file to fetch.
        filename: Destination path; parent directories are created as needed.

    Raises:
        requests.HTTPError: if the server returns an error status (previously
            an error page would be silently written to *filename*).
    """
    response = requests.get(url, stream=True, timeout=30)
    # Fail fast on 4xx/5xx instead of saving an HTML error body as the model.
    response.raise_for_status()
    total_size = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 KB chunks
    progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)

    # os.path.dirname() is '' for a bare filename, and makedirs('') raises —
    # only create directories when there is a directory component.
    parent_dir = os.path.dirname(filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(filename, 'wb') as file:
        for data in response.iter_content(block_size):
            progress_bar.update(file.write(data))
    progress_bar.close()

# Check if model file exists, if not, download it
# NOTE: this runs at import time, so the very first launch blocks here until
# the (large) checkpoint finishes downloading from Hugging Face.
model_path = "ckpts/E2TTS_Base/model_1200000.pt"
if not os.path.exists(model_path):
    print("Downloading model file...")
    model_url = "https://huggingface.co/SWivid/E2-TTS/resolve/main/E2TTS_Base/model_1200000.pt"
    download_model(model_url, model_path)
    print("Model file downloaded successfully.")

# Initialize E2-TTS model
# The duration predictor and main model use matching transformer configs.
# NOTE(review): dim=512/depth=8 must match the architecture the downloaded
# E2TTS_Base checkpoint was trained with — confirm against the released
# config, since mismatched weights are silently dropped by the filtered
# strict=False load below.
duration_predictor = DurationPredictor(
    transformer=dict(
        dim=512,
        depth=8,
    )
)

e2tts = E2TTS(
    duration_predictor=duration_predictor,
    transformer=dict(
        dim=512,
        depth=8
    ),
)

# Load the pre-trained model
# SECURITY NOTE(review): torch.load uses pickle by default and can execute
# arbitrary code from a malicious checkpoint; consider weights_only=True if
# the installed torch version supports it.
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
# Checkpoints may nest the weights under different keys depending on how they
# were saved (plain state dict, full training state, or EMA weights).
if 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
elif 'ema_model_state_dict' in checkpoint:
    state_dict = checkpoint['ema_model_state_dict']
else:
    state_dict = checkpoint  # Assume the checkpoint is the state dict itself

# Filter out unexpected keys
# NOTE(review): combined with strict=False this silently drops BOTH keys the
# model doesn't expect AND model weights absent from the checkpoint — a
# name mismatch would go completely unreported.
model_dict = e2tts.state_dict()
filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_dict}
e2tts.load_state_dict(filtered_state_dict, strict=False)
e2tts.eval()

def generate_podcast_script(content, duration):
    """Ask the Gemini model for a two-host podcast script about *content*.

    The returned text is expected to contain lines of the form
    ``Host 1: [emotion] Dialog`` / ``Host 2: [emotion] Dialog``.
    """
    prompt = f"""
    Create a podcast script for two people discussing the following content:
    {content}
    
    The podcast should last approximately {duration}. Include natural speech patterns,
    humor, and occasional off-topic chit-chat. Use speech fillers like "um", "ah",
    "yes", "I see", "Ok now". Vary the emotional tone (e.g., regular, happy, sad, surprised)
    and indicate these in [square brackets]. Format the script as follows:

    Host 1: [emotion] Dialog
    Host 2: [emotion] Dialog
    
    Ensure the conversation flows naturally and stays relevant to the topic.
    """
    # Single round-trip to the LLM; .text flattens the response parts.
    return model.generate_content(prompt).text

def text_to_speech(text, speaker_id):
    # Synthesize `text` with the E2-TTS model and return a numpy audio array.
    # NOTE(review): `speaker_id` is currently UNUSED — both hosts are rendered
    # with the same (random) conditioning, so the uploaded voices have no
    # effect on the output.
    # For simplicity, we'll use a random mel spectrogram as input
    # In a real scenario, you'd use the actual mel spectrogram from the cloned voice
    mel = torch.randn(1, 80, 100)
    
    # Generate speech
    # Only the first 5 mel frames are passed as the conditioning prefix.
    with torch.no_grad():
        sampled = e2tts.sample(mel[:, :5], text=[text])
    
    return sampled.cpu().numpy()

def create_podcast(content, duration, voice1, voice2):
    """Generate a script with Gemini and synthesize it into one audio track.

    Args:
        content: Source material the two hosts discuss.
        duration: Target duration label (e.g. "1-5 min").
        voice1, voice2: Accepted for interface parity with the UI; currently
            unused because text_to_speech ignores speaker conditioning.

    Returns:
        Tuple of (sample_rate_hz, audio_array) in Gradio's audio format.

    Raises:
        ValueError: if the generated script contains no "Host N:" lines.
    """
    script = generate_podcast_script(content, duration)
    audio_segments = []

    for line in script.split('\n'):
        # len("Host 1:") == 7 — slice off the speaker tag before synthesis.
        if line.startswith("Host 1:"):
            audio_segments.append(text_to_speech(line[7:], speaker_id=0))
        elif line.startswith("Host 2:"):
            audio_segments.append(text_to_speech(line[7:], speaker_id=1))

    # np.concatenate([]) raises an opaque ValueError; surface a clear message
    # when the LLM response didn't follow the requested "Host N:" format.
    if not audio_segments:
        raise ValueError("Generated script contained no 'Host 1:'/'Host 2:' lines")

    podcast_audio = np.concatenate(audio_segments)
    return (22050, podcast_audio)  # Assuming 22050 Hz sample rate

def gradio_interface(content, duration, voice1, voice2):
    """Gradio callback: return the generated podcast script for display.

    voice1/voice2 are accepted to match the button's input wiring but are not
    needed for script generation.
    """
    return generate_podcast_script(content, duration)

def render_podcast(script, voice1, voice2):
    """Synthesize an edited/approved script into a single audio track.

    Args:
        script: Script text with "Host 1:" / "Host 2:" prefixed lines.
        voice1, voice2: Accepted for interface parity with the UI; currently
            unused because text_to_speech ignores speaker conditioning.

    Returns:
        Tuple of (sample_rate_hz, audio_array) in Gradio's audio format.

    Raises:
        ValueError: if the script contains no "Host N:" lines.
    """
    audio_segments = []

    for line in script.split('\n'):
        # len("Host 1:") == 7 — slice off the speaker tag before synthesis.
        if line.startswith("Host 1:"):
            audio_segments.append(text_to_speech(line[7:], speaker_id=0))
        elif line.startswith("Host 2:"):
            audio_segments.append(text_to_speech(line[7:], speaker_id=1))

    # np.concatenate([]) raises an opaque ValueError; give the user a clear
    # message when the script has no correctly-formatted host lines.
    if not audio_segments:
        raise ValueError("Script contained no 'Host 1:'/'Host 2:' lines")

    podcast_audio = np.concatenate(audio_segments)
    return (22050, podcast_audio)

# Gradio Interface
# Two-step workflow: generate an editable script first, then render it to audio.
with gr.Blocks() as demo:
    gr.Markdown("# AI Podcast Generator")
    
    with gr.Row():
        content_input = gr.Textbox(label="Paste your content or upload a document")
        # NOTE(review): document_upload is displayed but never read by any
        # callback — uploaded documents are currently ignored.
        document_upload = gr.File(label="Upload Document")
    
    duration = gr.Radio(["1-5 min", "5-10 min", "10-15 min"], label="Estimated podcast duration")
    
    with gr.Row():
        # NOTE(review): the voice files are passed into the callbacks but
        # text_to_speech does not use them, so they don't affect the output.
        voice1_upload = gr.Audio(label="Upload Voice 1", type="filepath")
        voice2_upload = gr.Audio(label="Upload Voice 2", type="filepath")
    
    generate_btn = gr.Button("Generate Script")
    # The generated script is editable before rendering.
    script_output = gr.Textbox(label="Generated Script", lines=10)
    
    render_btn = gr.Button("Render Podcast")
    audio_output = gr.Audio(label="Generated Podcast")
    
    # Step 1: LLM script generation; Step 2: TTS rendering of the (possibly
    # user-edited) script text.
    generate_btn.click(gradio_interface, inputs=[content_input, duration, voice1_upload, voice2_upload], outputs=script_output)
    render_btn.click(render_podcast, inputs=[script_output, voice1_upload, voice2_upload], outputs=audio_output)

demo.launch()