bluenevus committed on
Commit
ac81409
·
verified ·
1 Parent(s): d0f551e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -68
app.py CHANGED
@@ -10,58 +10,7 @@ import os
10
  import requests
11
  from tqdm import tqdm
12
 
13
- # Function to download the model file
14
- def download_model(url, filename):
15
- response = requests.get(url, stream=True)
16
- total_size = int(response.headers.get('content-length', 0))
17
- block_size = 1024 # 1 KB
18
- progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
19
-
20
- os.makedirs(os.path.dirname(filename), exist_ok=True)
21
- with open(filename, 'wb') as file:
22
- for data in response.iter_content(block_size):
23
- size = file.write(data)
24
- progress_bar.update(size)
25
- progress_bar.close()
26
-
27
- # Check if model file exists, if not, download it
28
- model_path = "ckpts/E2TTS_Base/model_1200000.pt"
29
- if not os.path.exists(model_path):
30
- print("Downloading model file...")
31
- model_url = "https://huggingface.co/SWivid/E2-TTS/resolve/main/E2TTS_Base/model_1200000.pt"
32
- download_model(model_url, model_path)
33
- print("Model file downloaded successfully.")
34
-
35
- # Initialize E2-TTS model
36
- duration_predictor = DurationPredictor(
37
- transformer=dict(
38
- dim=512,
39
- depth=8,
40
- )
41
- )
42
-
43
- e2tts = E2TTS(
44
- duration_predictor=duration_predictor,
45
- transformer=dict(
46
- dim=512,
47
- depth=8
48
- ),
49
- )
50
-
51
- # Load the pre-trained model
52
- checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
53
- if 'model_state_dict' in checkpoint:
54
- state_dict = checkpoint['model_state_dict']
55
- elif 'ema_model_state_dict' in checkpoint:
56
- state_dict = checkpoint['ema_model_state_dict']
57
- else:
58
- state_dict = checkpoint # Assume the checkpoint is the state dict itself
59
-
60
- # Filter out unexpected keys
61
- model_dict = e2tts.state_dict()
62
- filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_dict}
63
- e2tts.load_state_dict(filtered_state_dict, strict=False)
64
- e2tts.eval()
65
 
66
  def generate_podcast_script(api_key, content, duration):
67
  genai.configure(api_key=api_key)
@@ -93,24 +42,11 @@ def text_to_speech(text, speaker_id):
93
  with torch.no_grad():
94
  sampled = e2tts.sample(mel[:, :5], text=[text])
95
 
96
- return sampled.cpu().numpy()
97
 
98
  def create_podcast(api_key, content, duration, voice1, voice2):
99
  script = generate_podcast_script(api_key, content, duration)
100
- lines = script.split('\n')
101
- audio_segments = []
102
-
103
- for line in lines:
104
- if line.startswith("Host 1:"):
105
- audio = text_to_speech(line[7:], speaker_id=0)
106
- audio_segments.append(audio)
107
- elif line.startswith("Host 2:"):
108
- audio = text_to_speech(line[7:], speaker_id=1)
109
- audio_segments.append(audio)
110
-
111
- # Concatenate audio segments
112
- podcast_audio = np.concatenate(audio_segments)
113
- return (22050, podcast_audio) # Assuming 22050 Hz sample rate
114
 
115
  def gradio_interface(api_key, content, duration, voice1, voice2):
116
  script = generate_podcast_script(api_key, content, duration)
@@ -128,8 +64,12 @@ def render_podcast(api_key, script, voice1, voice2):
128
  audio = text_to_speech(line[7:], speaker_id=1)
129
  audio_segments.append(audio)
130
 
 
 
 
 
131
  podcast_audio = np.concatenate(audio_segments)
132
- return (22050, podcast_audio)
133
 
134
  # Gradio Interface
135
  with gr.Blocks() as demo:
 
10
  import requests
11
  from tqdm import tqdm
12
 
13
+ # (Keep the model loading and initialization code as before)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def generate_podcast_script(api_key, content, duration):
16
  genai.configure(api_key=api_key)
 
42
  with torch.no_grad():
43
  sampled = e2tts.sample(mel[:, :5], text=[text])
44
 
45
+ return sampled.cpu().numpy().squeeze()
46
 
47
  def create_podcast(api_key, content, duration, voice1, voice2):
48
  script = generate_podcast_script(api_key, content, duration)
49
+ return render_podcast(api_key, script, voice1, voice2)
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def gradio_interface(api_key, content, duration, voice1, voice2):
52
  script = generate_podcast_script(api_key, content, duration)
 
64
  audio = text_to_speech(line[7:], speaker_id=1)
65
  audio_segments.append(audio)
66
 
67
+ if not audio_segments:
68
+ return (22050, np.zeros(22050)) # Return silence if no audio was generated
69
+
70
+ # Concatenate audio segments
71
  podcast_audio = np.concatenate(audio_segments)
72
+ return (22050, podcast_audio) # Assuming 22050 Hz sample rate
73
 
74
  # Gradio Interface
75
  with gr.Blocks() as demo: