Update app.py
Browse files
app.py
CHANGED
@@ -7,11 +7,35 @@ import google.generativeai as genai
|
|
7 |
from e2_tts_pytorch import E2TTS, DurationPredictor
|
8 |
import numpy as np
|
9 |
import os
|
|
|
|
|
10 |
|
11 |
# Initialize Gemini AI
|
12 |
genai.configure(api_key='YOUR_GEMINI_API_KEY')
|
13 |
model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
# Initialize E2-TTS model
|
16 |
duration_predictor = DurationPredictor(
|
17 |
transformer=dict(
|
@@ -29,7 +53,6 @@ e2tts = E2TTS(
|
|
29 |
)
|
30 |
|
31 |
# Load the pre-trained model
|
32 |
-
model_path = "ckpts/E2TTS_Base/model_1200000.safetensors"
|
33 |
e2tts.load_state_dict(torch.load(model_path))
|
34 |
e2tts.eval()
|
35 |
|
|
|
7 |
from e2_tts_pytorch import E2TTS, DurationPredictor
|
8 |
import numpy as np
|
9 |
import os
|
10 |
+
import requests
|
11 |
+
from tqdm import tqdm
|
12 |
|
13 |
# Initialize Gemini AI
|
14 |
genai.configure(api_key='YOUR_GEMINI_API_KEY')
|
15 |
model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
|
16 |
|
17 |
+
# Function to download the model file
|
18 |
+
def download_model(url, filename):
|
19 |
+
response = requests.get(url, stream=True)
|
20 |
+
total_size = int(response.headers.get('content-length', 0))
|
21 |
+
block_size = 1024 # 1 KB
|
22 |
+
progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
|
23 |
+
|
24 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
25 |
+
with open(filename, 'wb') as file:
|
26 |
+
for data in response.iter_content(block_size):
|
27 |
+
size = file.write(data)
|
28 |
+
progress_bar.update(size)
|
29 |
+
progress_bar.close()
|
30 |
+
|
31 |
+
# Check if model file exists, if not, download it
|
32 |
+
model_path = "ckpts/E2TTS_Base/model_1200000.safetensors"
|
33 |
+
if not os.path.exists(model_path):
|
34 |
+
print("Downloading model file...")
|
35 |
+
model_url = "https://huggingface.co/SWivid/E2-TTS/resolve/main/E2TTS_Base/model_1200000.safetensors"
|
36 |
+
download_model(model_url, model_path)
|
37 |
+
print("Model file downloaded successfully.")
|
38 |
+
|
39 |
# Initialize E2-TTS model
|
40 |
duration_predictor = DurationPredictor(
|
41 |
transformer=dict(
|
|
|
53 |
)
|
54 |
|
55 |
# Load the pre-trained model
|
|
|
56 |
e2tts.load_state_dict(torch.load(model_path))
|
57 |
e2tts.eval()
|
58 |
|