bluenevus committed on
Commit
c10cafd
·
verified ·
1 Parent(s): aa10e55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  import re
5
  import torch
6
  import torchaudio
 
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from huggingface_hub import snapshot_download, login
9
  import logging
@@ -124,12 +125,16 @@ def text_to_speech(text, voice):
124
  # Convert output tensor to mel spectrogram
125
  mel = output[0].cpu()
126
 
 
 
 
 
 
127
  # Normalize the mel spectrogram
128
- mel = (mel - mel.min()) / (mel.max() - mel.min())
129
 
130
  # Convert mel spectrogram to audio using torchaudio
131
- griffin_lim = torchaudio.transforms.GriffinLim(n_fft=2048, n_iter=10)
132
- audio = griffin_lim(mel.unsqueeze(0))
133
 
134
  # Convert to numpy array and ensure it's in the correct format
135
  audio_np = audio.squeeze().numpy()
@@ -139,7 +144,6 @@ def text_to_speech(text, voice):
139
  except Exception as e:
140
  logger.error(f"Error in text_to_speech: {str(e)}")
141
  raise
142
-
143
  @spaces.GPU()
144
  def render_podcast(api_key, script, voice1, voice2, num_hosts):
145
  try:
 
4
  import re
5
  import torch
6
  import torchaudio
7
+ import torchaudio.functional as F
8
  from transformers import AutoModelForCausalLM, AutoTokenizer
9
  from huggingface_hub import snapshot_download, login
10
  import logging
 
125
  # Convert output tensor to mel spectrogram
126
  mel = output[0].cpu()
127
 
128
+ # Reshape mel to match expected dimensions
129
+ n_mels = 80 # Typical number of mel bands
130
+ time_dim = mel.shape[0]
131
+ mel_reshaped = mel.view(n_mels, -1)
132
+
133
  # Normalize the mel spectrogram
134
+ mel_reshaped = (mel_reshaped - mel_reshaped.min()) / (mel_reshaped.max() - mel_reshaped.min())
135
 
136
  # Convert mel spectrogram to audio using torchaudio
137
+ audio = F.griffinlim(mel_reshaped.unsqueeze(0), n_iter=10, n_fft=2048, hop_length=512, win_length=2048)
 
138
 
139
  # Convert to numpy array and ensure it's in the correct format
140
  audio_np = audio.squeeze().numpy()
 
144
  except Exception as e:
145
  logger.error(f"Error in text_to_speech: {str(e)}")
146
  raise
 
147
  @spaces.GPU()
148
  def render_podcast(api_key, script, voice1, voice2, num_hosts):
149
  try: