Update app.py
Browse files
app.py
CHANGED
@@ -3,22 +3,21 @@ import google.generativeai as genai
|
|
3 |
import numpy as np
|
4 |
import re
|
5 |
import torch
|
|
|
6 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
7 |
from huggingface_hub import snapshot_download, login
|
8 |
import logging
|
9 |
import os
|
10 |
import spaces
|
11 |
import warnings
|
12 |
-
import librosa
|
13 |
|
14 |
# Set up logging
|
15 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
16 |
logger = logging.getLogger(__name__)
|
17 |
|
18 |
# Suppress specific warnings
|
19 |
-
warnings.filterwarnings("ignore", category=UserWarning
|
20 |
-
warnings.filterwarnings("ignore", category=RuntimeWarning
|
21 |
-
warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in cast")
|
22 |
|
23 |
def get_device():
|
24 |
if torch.cuda.is_available():
|
@@ -115,26 +114,32 @@ def text_to_speech(text, voice):
|
|
115 |
if model is None or tokenizer is None:
|
116 |
load_model()
|
117 |
|
118 |
-
|
|
|
|
|
|
|
119 |
with torch.no_grad():
|
120 |
output = model.generate(**inputs, max_new_tokens=256)
|
121 |
-
|
122 |
-
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
except Exception as e:
|
125 |
logger.error(f"Error in text_to_speech: {str(e)}")
|
126 |
raise
|
127 |
|
128 |
-
def mel_to_audio(mel):
|
129 |
-
try:
|
130 |
-
# Convert mel spectrogram to audio using librosa
|
131 |
-
audio = librosa.feature.inverse.mel_to_audio(mel, sr=24000, n_iter=10)
|
132 |
-
return audio
|
133 |
-
except Exception as e:
|
134 |
-
logger.error(f"Error in mel_to_audio conversion: {str(e)}")
|
135 |
-
# Return silence if conversion fails
|
136 |
-
return np.zeros(24000, dtype=np.float32)
|
137 |
-
|
138 |
@spaces.GPU()
|
139 |
def render_podcast(api_key, script, voice1, voice2, num_hosts):
|
140 |
try:
|
@@ -144,7 +149,7 @@ def render_podcast(api_key, script, voice1, voice2, num_hosts):
|
|
144 |
for i, line in enumerate(lines):
|
145 |
voice = voice1 if num_hosts == 1 or i % 2 == 0 else voice2
|
146 |
try:
|
147 |
-
audio = text_to_speech(line, voice)
|
148 |
audio_segments.append(audio)
|
149 |
except Exception as e:
|
150 |
logger.error(f"Error processing audio segment: {str(e)}")
|
|
|
3 |
import numpy as np
|
4 |
import re
|
5 |
import torch
|
6 |
+
import torchaudio
|
7 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
8 |
from huggingface_hub import snapshot_download, login
|
9 |
import logging
|
10 |
import os
|
11 |
import spaces
|
12 |
import warnings
|
|
|
13 |
|
14 |
# Set up logging
|
15 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
16 |
logger = logging.getLogger(__name__)
|
17 |
|
18 |
# Suppress specific warnings
|
19 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
20 |
+
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
|
|
21 |
|
22 |
def get_device():
|
23 |
if torch.cuda.is_available():
|
|
|
114 |
if model is None or tokenizer is None:
|
115 |
load_model()
|
116 |
|
117 |
+
# Remove emotion tags for TTS processing
|
118 |
+
clean_text = re.sub(r'<[^>]+>', '', text)
|
119 |
+
|
120 |
+
inputs = tokenizer(clean_text, return_tensors="pt").to(device)
|
121 |
with torch.no_grad():
|
122 |
output = model.generate(**inputs, max_new_tokens=256)
|
123 |
+
|
124 |
+
# Convert output tensor to mel spectrogram
|
125 |
+
mel = output[0].cpu()
|
126 |
+
|
127 |
+
# Normalize the mel spectrogram
|
128 |
+
mel = (mel - mel.min()) / (mel.max() - mel.min())
|
129 |
+
|
130 |
+
# Convert mel spectrogram to audio using torchaudio
|
131 |
+
griffin_lim = torchaudio.transforms.GriffinLim(n_fft=2048, n_iter=10)
|
132 |
+
audio = griffin_lim(mel.unsqueeze(0))
|
133 |
+
|
134 |
+
# Convert to numpy array and ensure it's in the correct format
|
135 |
+
audio_np = audio.squeeze().numpy()
|
136 |
+
audio_np = np.clip(audio_np, -1, 1)
|
137 |
+
|
138 |
+
return (24000, audio_np.astype(np.float32)) # Assuming 24kHz sample rate
|
139 |
except Exception as e:
|
140 |
logger.error(f"Error in text_to_speech: {str(e)}")
|
141 |
raise
|
142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
@spaces.GPU()
|
144 |
def render_podcast(api_key, script, voice1, voice2, num_hosts):
|
145 |
try:
|
|
|
149 |
for i, line in enumerate(lines):
|
150 |
voice = voice1 if num_hosts == 1 or i % 2 == 0 else voice2
|
151 |
try:
|
152 |
+
_, audio = text_to_speech(line, voice)
|
153 |
audio_segments.append(audio)
|
154 |
except Exception as e:
|
155 |
logger.error(f"Error processing audio segment: {str(e)}")
|