bluenevus commited on
Commit
425f9fe
·
verified ·
1 Parent(s): ef2dd7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -7
app.py CHANGED
@@ -9,7 +9,6 @@ import numpy as np
9
  import os
10
  import requests
11
  from tqdm import tqdm
12
- from safetensors import safe_open
13
 
14
  # Initialize Gemini AI
15
  genai.configure(api_key='YOUR_GEMINI_API_KEY')
@@ -30,10 +29,10 @@ def download_model(url, filename):
30
  progress_bar.close()
31
 
32
  # Check if model file exists, if not, download it
33
- model_path = "ckpts/E2TTS_Base/model_1200000.safetensors"
34
  if not os.path.exists(model_path):
35
  print("Downloading model file...")
36
- model_url = "https://huggingface.co/SWivid/E2-TTS/resolve/main/E2TTS_Base/model_1200000.safetensors"
37
  download_model(model_url, model_path)
38
  print("Model file downloaded successfully.")
39
 
@@ -53,10 +52,9 @@ e2tts = E2TTS(
53
  ),
54
  )
55
 
56
- # Load the pre-trained model using safetensors
57
- with safe_open(model_path, framework="pt", device="cpu") as f:
58
- state_dict = {key: f.get_tensor(key) for key in f.keys()}
59
- e2tts.load_state_dict(state_dict)
60
  e2tts.eval()
61
 
62
  def generate_podcast_script(content, duration):
 
9
  import os
10
  import requests
11
  from tqdm import tqdm
 
12
 
13
  # Initialize Gemini AI
14
  genai.configure(api_key='YOUR_GEMINI_API_KEY')
 
29
  progress_bar.close()
30
 
31
  # Check if model file exists, if not, download it
32
+ model_path = "ckpts/E2TTS_Base/model_1200000.pt"
33
  if not os.path.exists(model_path):
34
  print("Downloading model file...")
35
+ model_url = "https://huggingface.co/SWivid/E2-TTS/resolve/main/E2TTS_Base/model_1200000.pt"
36
  download_model(model_url, model_path)
37
  print("Model file downloaded successfully.")
38
 
 
52
  ),
53
  )
54
 
55
+ # Load the pre-trained model
56
+ state_dict = torch.load(model_path, map_location=torch.device('cpu'))
57
+ e2tts.load_state_dict(state_dict)
 
58
  e2tts.eval()
59
 
60
  def generate_podcast_script(content, duration):