# app.py
import gradio as gr
import torch
import torchaudio
import google.generativeai as genai
from e2_tts_pytorch import E2TTS, DurationPredictor
import numpy as np
import os
import requests
from tqdm import tqdm

# Initialize Gemini AI
genai.configure(api_key='YOUR_GEMINI_API_KEY')  # Replace with a real Gemini API key
model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')

# Download the model checkpoint to disk, streaming with a progress bar
def download_model(url, filename):
    response = requests.get(url, stream=True)
    response.raise_for_status()
    total_size = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 KiB
    progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, 'wb') as file:
        for data in response.iter_content(block_size):
            size = file.write(data)
            progress_bar.update(size)
    progress_bar.close()

# Download the checkpoint on first run
model_path = "ckpts/E2TTS_Base/model_1200000.pt"
if not os.path.exists(model_path):
    print("Downloading model file...")
    model_url = "https://huggingface.co/SWivid/E2-TTS/resolve/main/E2TTS_Base/model_1200000.pt"
    download_model(model_url, model_path)
    print("Model file downloaded successfully.")

# Initialize the E2-TTS model
duration_predictor = DurationPredictor(
    transformer=dict(
        dim=512,
        depth=8,
    )
)

e2tts = E2TTS(
    duration_predictor=duration_predictor,
    transformer=dict(
        dim=512,
        depth=8,
    ),
)

# Load the pre-trained weights, handling the checkpoint layouts this file may use
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
if 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
elif 'ema_model_state_dict' in checkpoint:
    state_dict = checkpoint['ema_model_state_dict']
else:
    state_dict = checkpoint  # Assume the checkpoint is the state dict itself

# Drop keys that do not exist in the current model definition
model_dict = e2tts.state_dict()
filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_dict}
e2tts.load_state_dict(filtered_state_dict, strict=False)
e2tts.eval()

def generate_podcast_script(content, duration):
    prompt = f"""
    Create a podcast script for two people discussing the following content:
    {content}
    The podcast should last approximately {duration}. Include natural speech patterns,
    humor, and occasional off-topic chit-chat. Use speech fillers like "um", "ah",
    "yes", "I see", "Ok now". Vary the emotional tone (e.g., regular, happy, sad,
    surprised) and mark the emotion in [square brackets]. Format the script as follows:
    Host 1: [emotion] Dialog
    Host 2: [emotion] Dialog
    Ensure the conversation flows naturally and stays relevant to the topic.
    """
    response = model.generate_content(prompt)
    return response.text
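
# The prompt above asks for lines like "Host 1: [emotion] Dialog", so the dialog
# still carries a leading "[emotion]" tag when it reaches TTS. A small helper
# (hypothetical, not part of the original app) could strip it, e.g. applied to
# the dialog text in render_podcast below, so the tag is not read aloud:
import re

def strip_emotion_tag(dialog):
    # "[happy] Hello there" -> "Hello there"
    return re.sub(r'^\s*\[[^\]]*\]\s*', '', dialog)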

def text_to_speech(text, speaker_id):
    # Placeholder input: a random mel spectrogram stands in for a real speaker
    # prompt. In a real scenario you would use the mel spectrogram of the cloned
    # voice selected by speaker_id (currently unused).
    mel = torch.randn(1, 80, 100)
    # Generate speech conditioned on the first few mel frames
    with torch.no_grad():
        sampled = e2tts.sample(mel[:, :5], text=[text])
    # Note: E2TTS.sample returns a mel spectrogram, not a waveform
    return sampled.cpu().numpy()
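
# A minimal sketch of how a real speaker prompt could replace the random mel
# above, built from one of the uploaded reference recordings. The n_fft and
# hop_length values are assumptions, not confirmed by this app; n_mels=100
# matches the mel channel count used above.
def reference_mel_from_file(path):
    # Load the uploaded reference audio (e.g. from gr.Audio with type="filepath")
    waveform, sample_rate = torchaudio.load(path)
    # Mix down to mono so the mel transform sees a single channel
    waveform = waveform.mean(dim=0, keepdim=True)
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=1024,
        hop_length=256,
        n_mels=100,
    )
    mel = mel_transform(waveform)      # (1, n_mels, frames)
    mel = mel.clamp(min=1e-5).log()    # log-compress, avoiding log(0)
    return mel.transpose(1, 2)         # (1, frames, n_mels), the layout used above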

def create_podcast(content, duration, voice1, voice2):
    # One-shot helper (not wired into the UI below): generate a script and
    # immediately render it
    script = generate_podcast_script(content, duration)
    return render_podcast(script, voice1, voice2)

def gradio_interface(content, duration, voice1, voice2):
    # Script-only step; the voice files are used later, at render time
    script = generate_podcast_script(content, duration)
    return script

def render_podcast(script, voice1, voice2):
    lines = script.split('\n')
    audio_segments = []
    for line in lines:
        if line.startswith("Host 1:"):
            audio = text_to_speech(line[len("Host 1:"):].strip(), speaker_id=0)
            audio_segments.append(audio)
        elif line.startswith("Host 2:"):
            audio = text_to_speech(line[len("Host 2:"):].strip(), speaker_id=1)
            audio_segments.append(audio)
    if not audio_segments:
        raise gr.Error("No 'Host 1:' / 'Host 2:' lines found in the script")
    # Concatenate the per-line segments into one track
    podcast_audio = np.concatenate(audio_segments)
    return (22050, podcast_audio)  # Assuming a 22050 Hz sample rate
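
# The segments above are mel spectrograms, so gr.Audio cannot play them
# directly; a vocoder has to turn them into a waveform first. A minimal sketch
# using the Vocos vocoder, an assumed extra dependency (pip install vocos)
# chosen because its 24 kHz checkpoint also uses 100 mel bins:
def mel_to_waveform(mel):
    from vocos import Vocos
    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
    # Vocos expects (batch, n_mels, frames); the E2TTS output here is
    # (batch, frames, n_mels), hence the transpose
    with torch.no_grad():
        audio = vocos.decode(torch.as_tensor(mel).transpose(1, 2))
    return audio.squeeze(0).numpy()  # waveform at 24000 Hz, not 22050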

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# AI Podcast Generator")
    with gr.Row():
        content_input = gr.Textbox(label="Paste your content or upload a document")
        document_upload = gr.File(label="Upload Document")  # Not yet wired into the pipeline
    duration = gr.Radio(["1-5 min", "5-10 min", "10-15 min"], label="Estimated podcast duration")
    with gr.Row():
        voice1_upload = gr.Audio(label="Upload Voice 1", type="filepath")
        voice2_upload = gr.Audio(label="Upload Voice 2", type="filepath")
    generate_btn = gr.Button("Generate Script")
    script_output = gr.Textbox(label="Generated Script", lines=10)
    render_btn = gr.Button("Render Podcast")
    audio_output = gr.Audio(label="Generated Podcast")

    generate_btn.click(gradio_interface, inputs=[content_input, duration, voice1_upload, voice2_upload], outputs=script_output)
    render_btn.click(render_podcast, inputs=[script_output, voice1_upload, voice2_upload], outputs=audio_output)

demo.launch()