Update app.py
Browse files
app.py
CHANGED
@@ -10,58 +10,7 @@ import os
|
|
10 |
import requests
|
11 |
from tqdm import tqdm
|
12 |
|
13 |
-
#
|
14 |
-
def download_model(url, filename):
    """Stream a file from *url* to *filename* with a tqdm progress bar.

    Parameters
    ----------
    url : str
        Direct download URL of the checkpoint.
    filename : str
        Destination path; parent directories are created as needed.

    Raises
    ------
    requests.HTTPError
        If the server answers with a 4xx/5xx status. (Previously the error
        body — often an HTML page — was silently written to disk as a
        corrupt checkpoint file.)
    """
    response = requests.get(url, stream=True)
    response.raise_for_status()  # fail fast instead of saving an error page
    total_size = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 KB

    # os.makedirs('') raises FileNotFoundError, so only create directories
    # when the target path actually has a directory component.
    dirname = os.path.dirname(filename)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    # Context managers guarantee the bar and file are closed even if the
    # download aborts mid-stream (the original leaked the bar on error).
    with tqdm(total=total_size, unit='iB', unit_scale=True) as progress_bar:
        with open(filename, 'wb') as file:
            for data in response.iter_content(block_size):
                progress_bar.update(file.write(data))
|
26 |
-
|
27 |
-
# Fetch the pretrained E2-TTS checkpoint from the HuggingFace Hub on first
# run; subsequent runs reuse the cached file on disk.
model_path = "ckpts/E2TTS_Base/model_1200000.pt"
if not os.path.exists(model_path):
    print("Downloading model file...")
    download_model(
        "https://huggingface.co/SWivid/E2-TTS/resolve/main/E2TTS_Base/model_1200000.pt",
        model_path,
    )
    print("Model file downloaded successfully.")
|
34 |
-
|
35 |
-
# Initialize E2-TTS model
# NOTE(review): DurationPredictor and E2TTS are presumably imported from the
# e2-tts package earlier in the file — confirm against the import block.
duration_predictor = DurationPredictor(
    transformer=dict(
        dim=512,   # transformer hidden size — must match the checkpoint
        depth=8,   # number of transformer layers
    )
)

e2tts = E2TTS(
    duration_predictor=duration_predictor,
    transformer=dict(
        dim=512,
        depth=8
    ),
)
|
50 |
-
|
51 |
-
# Restore pretrained weights into the freshly constructed E2-TTS model.
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

# Checkpoints come in several layouts: weights nested under
# 'model_state_dict', under 'ema_model_state_dict', or a bare state dict.
# The first matching container wins, mirroring the original lookup order.
state_dict = checkpoint
for container_key in ('model_state_dict', 'ema_model_state_dict'):
    if container_key in checkpoint:
        state_dict = checkpoint[container_key]
        break

# Keep only the parameters the current architecture actually declares, so
# stale or renamed checkpoint keys cannot abort loading.
known_params = e2tts.state_dict()
filtered_state_dict = {
    name: tensor for name, tensor in state_dict.items() if name in known_params
}
e2tts.load_state_dict(filtered_state_dict, strict=False)
e2tts.eval()
|
65 |
|
66 |
def generate_podcast_script(api_key, content, duration):
|
67 |
genai.configure(api_key=api_key)
|
@@ -93,24 +42,11 @@ def text_to_speech(text, speaker_id):
|
|
93 |
with torch.no_grad():
|
94 |
sampled = e2tts.sample(mel[:, :5], text=[text])
|
95 |
|
96 |
-
return sampled.cpu().numpy()
|
97 |
|
98 |
def create_podcast(api_key, content, duration, voice1, voice2):
    """Generate a podcast end-to-end: script via the LLM, then per-line TTS.

    Parameters
    ----------
    api_key : str
        Google Generative AI API key, forwarded to script generation.
    content : str
        Source material the podcast script is written from.
    duration : int | str
        Target duration hint passed to the script generator.
    voice1, voice2
        Voice selections (currently unused by text_to_speech; kept for
        interface compatibility with the Gradio callbacks).

    Returns
    -------
    tuple[int, numpy.ndarray]
        (sample_rate, audio) pair in the shape Gradio's Audio output expects.
    """
    script = generate_podcast_script(api_key, content, duration)

    audio_segments = []

    # BUG FIX: `lines` was referenced but never defined (NameError at
    # runtime) — split the generated script into lines first.
    lines = script.split("\n")
    for line in lines:
        # "Host N:" prefixes are 7 characters long, hence line[7:].
        if line.startswith("Host 1:"):
            audio = text_to_speech(line[7:], speaker_id=0)
            audio_segments.append(audio)
        elif line.startswith("Host 2:"):
            audio = text_to_speech(line[7:], speaker_id=1)
            audio_segments.append(audio)

    # np.concatenate raises ValueError on an empty list; return one second
    # of silence when the script produced no "Host N:" lines.
    if not audio_segments:
        return (22050, np.zeros(22050))

    # Concatenate audio segments
    podcast_audio = np.concatenate(audio_segments)
    return (22050, podcast_audio)  # Assuming 22050 Hz sample rate
|
114 |
|
115 |
def gradio_interface(api_key, content, duration, voice1, voice2):
|
116 |
script = generate_podcast_script(api_key, content, duration)
|
@@ -128,8 +64,12 @@ def render_podcast(api_key, script, voice1, voice2):
|
|
128 |
audio = text_to_speech(line[7:], speaker_id=1)
|
129 |
audio_segments.append(audio)
|
130 |
|
|
|
|
|
|
|
|
|
131 |
podcast_audio = np.concatenate(audio_segments)
|
132 |
-
return (22050, podcast_audio)
|
133 |
|
134 |
# Gradio Interface
|
135 |
with gr.Blocks() as demo:
|
|
|
10 |
import requests
|
11 |
from tqdm import tqdm
|
12 |
|
13 |
+
# (Keep the model loading and initialization code as before)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
def generate_podcast_script(api_key, content, duration):
|
16 |
genai.configure(api_key=api_key)
|
|
|
42 |
with torch.no_grad():
|
43 |
sampled = e2tts.sample(mel[:, :5], text=[text])
|
44 |
|
45 |
+
return sampled.cpu().numpy().squeeze()
|
46 |
|
47 |
def create_podcast(api_key, content, duration, voice1, voice2):
    """Produce a podcast in one call: write the script, then render it.

    Thin composition of generate_podcast_script and render_podcast; the
    returned value is whatever render_podcast yields for the generated
    script (a Gradio-compatible (sample_rate, audio) pair).
    """
    return render_podcast(
        api_key,
        generate_podcast_script(api_key, content, duration),
        voice1,
        voice2,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
def gradio_interface(api_key, content, duration, voice1, voice2):
|
52 |
script = generate_podcast_script(api_key, content, duration)
|
|
|
64 |
audio = text_to_speech(line[7:], speaker_id=1)
|
65 |
audio_segments.append(audio)
|
66 |
|
67 |
+
if not audio_segments:
|
68 |
+
return (22050, np.zeros(22050)) # Return silence if no audio was generated
|
69 |
+
|
70 |
+
# Concatenate audio segments
|
71 |
podcast_audio = np.concatenate(audio_segments)
|
72 |
+
return (22050, podcast_audio) # Assuming 22050 Hz sample rate
|
73 |
|
74 |
# Gradio Interface
|
75 |
with gr.Blocks() as demo:
|