bluenevus commited on
Commit
e07f35b
·
verified ·
1 Parent(s): 2a527b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -1
app.py CHANGED
@@ -7,11 +7,35 @@ import google.generativeai as genai
7
  from e2_tts_pytorch import E2TTS, DurationPredictor
8
  import numpy as np
9
  import os
 
 
10
 
11
  # Initialize Gemini AI
12
  genai.configure(api_key='YOUR_GEMINI_API_KEY')
13
  model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # Initialize E2-TTS model
16
  duration_predictor = DurationPredictor(
17
  transformer=dict(
@@ -29,7 +53,6 @@ e2tts = E2TTS(
29
  )
30
 
31
  # Load the pre-trained model
32
- model_path = "ckpts/E2TTS_Base/model_1200000.safetensors"
33
  e2tts.load_state_dict(torch.load(model_path))
34
  e2tts.eval()
35
 
 
7
  from e2_tts_pytorch import E2TTS, DurationPredictor
8
  import numpy as np
9
  import os
10
+ import requests
11
+ from tqdm import tqdm
12
 
13
  # Initialize Gemini AI
14
  genai.configure(api_key='YOUR_GEMINI_API_KEY')
15
  model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
16
 
17
+ # Function to download the model file
18
+ def download_model(url, filename):
19
+ response = requests.get(url, stream=True)
20
+ total_size = int(response.headers.get('content-length', 0))
21
+ block_size = 1024 # 1 KB
22
+ progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
23
+
24
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
25
+ with open(filename, 'wb') as file:
26
+ for data in response.iter_content(block_size):
27
+ size = file.write(data)
28
+ progress_bar.update(size)
29
+ progress_bar.close()
30
+
31
+ # Check if model file exists, if not, download it
32
+ model_path = "ckpts/E2TTS_Base/model_1200000.safetensors"
33
+ if not os.path.exists(model_path):
34
+ print("Downloading model file...")
35
+ model_url = "https://huggingface.co/SWivid/E2-TTS/resolve/main/E2TTS_Base/model_1200000.safetensors"
36
+ download_model(model_url, model_path)
37
+ print("Model file downloaded successfully.")
38
+
39
  # Initialize E2-TTS model
40
  duration_predictor = DurationPredictor(
41
  transformer=dict(
 
53
  )
54
 
55
  # Load the pre-trained model
 
56
  e2tts.load_state_dict(torch.load(model_path))
57
  e2tts.eval()
58