bluenevus commited on
Commit
aa10e55
·
verified ·
1 Parent(s): eb6d374

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -19
app.py CHANGED
@@ -3,22 +3,21 @@ import google.generativeai as genai
3
  import numpy as np
4
  import re
5
  import torch
 
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  from huggingface_hub import snapshot_download, login
8
  import logging
9
  import os
10
  import spaces
11
  import warnings
12
- import librosa
13
 
14
  # Set up logging
15
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
16
  logger = logging.getLogger(__name__)
17
 
18
  # Suppress specific warnings
19
- warnings.filterwarnings("ignore", category=UserWarning, message="Trying to convert audio automatically from float32 to 16-bit int format.")
20
- warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in divide")
21
- warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in cast")
22
 
23
  def get_device():
24
  if torch.cuda.is_available():
@@ -115,26 +114,32 @@ def text_to_speech(text, voice):
115
  if model is None or tokenizer is None:
116
  load_model()
117
 
118
- inputs = tokenizer(text, return_tensors="pt").to(device)
 
 
 
119
  with torch.no_grad():
120
  output = model.generate(**inputs, max_new_tokens=256)
121
- mel = output[0].cpu().numpy()
122
- audio = mel_to_audio(mel)
123
- return audio
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  except Exception as e:
125
  logger.error(f"Error in text_to_speech: {str(e)}")
126
  raise
127
 
128
- def mel_to_audio(mel):
129
- try:
130
- # Convert mel spectrogram to audio using librosa
131
- audio = librosa.feature.inverse.mel_to_audio(mel, sr=24000, n_iter=10)
132
- return audio
133
- except Exception as e:
134
- logger.error(f"Error in mel_to_audio conversion: {str(e)}")
135
- # Return silence if conversion fails
136
- return np.zeros(24000, dtype=np.float32)
137
-
138
  @spaces.GPU()
139
  def render_podcast(api_key, script, voice1, voice2, num_hosts):
140
  try:
@@ -144,7 +149,7 @@ def render_podcast(api_key, script, voice1, voice2, num_hosts):
144
  for i, line in enumerate(lines):
145
  voice = voice1 if num_hosts == 1 or i % 2 == 0 else voice2
146
  try:
147
- audio = text_to_speech(line, voice)
148
  audio_segments.append(audio)
149
  except Exception as e:
150
  logger.error(f"Error processing audio segment: {str(e)}")
 
3
  import numpy as np
4
  import re
5
  import torch
6
+ import torchaudio
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from huggingface_hub import snapshot_download, login
9
  import logging
10
  import os
11
  import spaces
12
  import warnings
 
13
 
14
  # Set up logging
15
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
16
  logger = logging.getLogger(__name__)
17
 
18
  # Suppress specific warnings
19
+ warnings.filterwarnings("ignore", category=UserWarning)
20
+ warnings.filterwarnings("ignore", category=RuntimeWarning)
 
21
 
22
  def get_device():
23
  if torch.cuda.is_available():
 
114
  if model is None or tokenizer is None:
115
  load_model()
116
 
117
+ # Remove emotion tags for TTS processing
118
+ clean_text = re.sub(r'<[^>]+>', '', text)
119
+
120
+ inputs = tokenizer(clean_text, return_tensors="pt").to(device)
121
  with torch.no_grad():
122
  output = model.generate(**inputs, max_new_tokens=256)
123
+
124
+ # Convert output tensor to mel spectrogram
125
+ mel = output[0].cpu()
126
+
127
+ # Normalize the mel spectrogram
128
+ mel = (mel - mel.min()) / (mel.max() - mel.min())
129
+
130
+ # Convert mel spectrogram to audio using torchaudio
131
+ griffin_lim = torchaudio.transforms.GriffinLim(n_fft=2048, n_iter=10)
132
+ audio = griffin_lim(mel.unsqueeze(0))
133
+
134
+ # Convert to numpy array and ensure it's in the correct format
135
+ audio_np = audio.squeeze().numpy()
136
+ audio_np = np.clip(audio_np, -1, 1)
137
+
138
+ return (24000, audio_np.astype(np.float32)) # Assuming 24kHz sample rate
139
  except Exception as e:
140
  logger.error(f"Error in text_to_speech: {str(e)}")
141
  raise
142
 
 
 
 
 
 
 
 
 
 
 
143
  @spaces.GPU()
144
  def render_podcast(api_key, script, voice1, voice2, num_hosts):
145
  try:
 
149
  for i, line in enumerate(lines):
150
  voice = voice1 if num_hosts == 1 or i % 2 == 0 else voice2
151
  try:
152
+ _, audio = text_to_speech(line, voice)
153
  audio_segments.append(audio)
154
  except Exception as e:
155
  logger.error(f"Error processing audio segment: {str(e)}")