Update app.py
Browse files
app.py
CHANGED
@@ -4,6 +4,7 @@ import numpy as np
|
|
4 |
import re
|
5 |
import torch
|
6 |
import torchaudio
|
|
|
7 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
8 |
from huggingface_hub import snapshot_download, login
|
9 |
import logging
|
@@ -124,12 +125,16 @@ def text_to_speech(text, voice):
|
|
124 |
# Convert output tensor to mel spectrogram
|
125 |
mel = output[0].cpu()
|
126 |
|
|
|
|
|
|
|
|
|
|
|
127 |
# Normalize the mel spectrogram
|
128 |
-
|
129 |
|
130 |
# Convert mel spectrogram to audio using torchaudio
|
131 |
-
|
132 |
-
audio = griffin_lim(mel.unsqueeze(0))
|
133 |
|
134 |
# Convert to numpy array and ensure it's in the correct format
|
135 |
audio_np = audio.squeeze().numpy()
|
@@ -139,7 +144,6 @@ def text_to_speech(text, voice):
|
|
139 |
except Exception as e:
|
140 |
logger.error(f"Error in text_to_speech: {str(e)}")
|
141 |
raise
|
142 |
-
|
143 |
@spaces.GPU()
|
144 |
def render_podcast(api_key, script, voice1, voice2, num_hosts):
|
145 |
try:
|
|
|
4 |
import re
|
5 |
import torch
|
6 |
import torchaudio
|
7 |
+
import torchaudio.functional as F
|
8 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
9 |
from huggingface_hub import snapshot_download, login
|
10 |
import logging
|
|
|
125 |
# Convert output tensor to mel spectrogram
|
126 |
mel = output[0].cpu()
|
127 |
|
128 |
+
# Reshape mel to match expected dimensions
|
129 |
+
n_mels = 80 # Typical number of mel bands
|
130 |
+
time_dim = mel.shape[0]
|
131 |
+
mel_reshaped = mel.view(n_mels, -1)
|
132 |
+
|
133 |
# Normalize the mel spectrogram
|
134 |
+
mel_reshaped = (mel_reshaped - mel_reshaped.min()) / (mel_reshaped.max() - mel_reshaped.min())
|
135 |
|
136 |
# Convert mel spectrogram to audio using torchaudio
|
137 |
+
audio = F.griffinlim(mel_reshaped.unsqueeze(0), n_iter=10, n_fft=2048, hop_length=512, win_length=2048)
|
|
|
138 |
|
139 |
# Convert to numpy array and ensure it's in the correct format
|
140 |
audio_np = audio.squeeze().numpy()
|
|
|
144 |
except Exception as e:
|
145 |
logger.error(f"Error in text_to_speech: {str(e)}")
|
146 |
raise
|
|
|
147 |
@spaces.GPU()
|
148 |
def render_podcast(api_key, script, voice1, voice2, num_hosts):
|
149 |
try:
|