Update app.py
Browse files
app.py
CHANGED
@@ -4,24 +4,34 @@ import gradio as gr
|
|
4 |
import torch
|
5 |
import torchaudio
|
6 |
import google.generativeai as genai
|
7 |
-
from
|
8 |
import numpy as np
|
9 |
import os
|
10 |
-
import json
|
11 |
|
12 |
# Initialize Gemini AI
|
13 |
genai.configure(api_key='YOUR_GEMINI_API_KEY')
|
14 |
model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
|
15 |
|
16 |
-
# Initialize
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
def generate_podcast_script(content, duration):
|
27 |
prompt = f"""
|
@@ -42,9 +52,15 @@ def generate_podcast_script(content, duration):
|
|
42 |
return response.text
|
43 |
|
44 |
def text_to_speech(text, speaker_id):
|
45 |
-
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
def create_podcast(content, duration, voice1, voice2):
|
50 |
script = generate_podcast_script(content, duration)
|
@@ -53,10 +69,10 @@ def create_podcast(content, duration, voice1, voice2):
|
|
53 |
|
54 |
for line in lines:
|
55 |
if line.startswith("Host 1:"):
|
56 |
-
audio = text_to_speech(line[7:], speaker_id=0)
|
57 |
audio_segments.append(audio)
|
58 |
elif line.startswith("Host 2:"):
|
59 |
-
audio = text_to_speech(line[7:], speaker_id=1)
|
60 |
audio_segments.append(audio)
|
61 |
|
62 |
# Concatenate audio segments
|
|
|
4 |
import torch
|
5 |
import torchaudio
|
6 |
import google.generativeai as genai
|
7 |
+
from e2_tts_pytorch import E2TTS, DurationPredictor
|
8 |
import numpy as np
|
9 |
import os
|
|
|
10 |
|
11 |
# Initialize Gemini AI
|
12 |
genai.configure(api_key='YOUR_GEMINI_API_KEY')
|
13 |
model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
|
14 |
|
15 |
+
# Initialize E2-TTS model
|
16 |
+
duration_predictor = DurationPredictor(
|
17 |
+
transformer=dict(
|
18 |
+
dim=512,
|
19 |
+
depth=8,
|
20 |
+
)
|
21 |
+
)
|
22 |
|
23 |
+
e2tts = E2TTS(
|
24 |
+
duration_predictor=duration_predictor,
|
25 |
+
transformer=dict(
|
26 |
+
dim=512,
|
27 |
+
depth=8
|
28 |
+
),
|
29 |
+
)
|
30 |
+
|
31 |
+
# Load the pre-trained model
|
32 |
+
model_path = "ckpts/E2TTS_Base/model_1200000.safetensors"
|
33 |
+
e2tts.load_state_dict(torch.load(model_path))
|
34 |
+
e2tts.eval()
|
35 |
|
36 |
def generate_podcast_script(content, duration):
|
37 |
prompt = f"""
|
|
|
52 |
return response.text
|
53 |
|
54 |
def text_to_speech(text, speaker_id):
|
55 |
+
# For simplicity, we'll use a random mel spectrogram as input
|
56 |
+
# In a real scenario, you'd use the actual mel spectrogram from the cloned voice
|
57 |
+
mel = torch.randn(1, 80, 100)
|
58 |
+
|
59 |
+
# Generate speech
|
60 |
+
with torch.no_grad():
|
61 |
+
sampled = e2tts.sample(mel[:, :5], text=[text])
|
62 |
+
|
63 |
+
return sampled.cpu().numpy()
|
64 |
|
65 |
def create_podcast(content, duration, voice1, voice2):
|
66 |
script = generate_podcast_script(content, duration)
|
|
|
69 |
|
70 |
for line in lines:
|
71 |
if line.startswith("Host 1:"):
|
72 |
+
audio = text_to_speech(line[7:], speaker_id=0)
|
73 |
audio_segments.append(audio)
|
74 |
elif line.startswith("Host 2:"):
|
75 |
+
audio = text_to_speech(line[7:], speaker_id=1)
|
76 |
audio_segments.append(audio)
|
77 |
|
78 |
# Concatenate audio segments
|