musdfakoc committed on
Commit
dd34179
·
verified ·
1 Parent(s): af04a5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -16
app.py CHANGED
@@ -129,41 +129,58 @@ def load_gan_model(generator, model_path, device):
129
  return generator
130
 
131
 
132
- # Generator model class definitions remain the same as in your original code.
133
-
134
- # Convert magnitude-only spectrogram to complex format by assuming zero phase
135
  def magnitude_to_complex_spectrogram(magnitude_spectrogram):
 
 
 
136
  zero_phase = torch.zeros_like(magnitude_spectrogram)
137
  complex_spectrogram = torch.stack([magnitude_spectrogram, zero_phase], dim=-1)
 
 
 
 
 
138
  return complex_spectrogram
139
 
 
140
  def spectrogram_to_audio(magnitude_spectrogram):
141
  # Perform inverse log scaling
142
- magnitude_spectrogram = torch.expm1(magnitude_spectrogram)
143
-
 
 
 
 
 
144
  # Convert magnitude-only spectrogram to complex format
145
- complex_spectrogram = magnitude_to_complex_spectrogram(magnitude_spectrogram)
 
 
 
 
 
146
 
147
  # Inverse STFT to convert the spectrogram back to audio
148
- audio = torch.istft(complex_spectrogram, n_fft=n_fft, hop_length=hop_length)
 
 
 
 
 
149
 
150
  # Normalize audio to the range [-1, 1] (standard audio range)
151
- if torch.max(torch.abs(audio)) != 0:
152
- audio = audio / torch.max(torch.abs(audio))
 
 
 
153
 
154
  # Clip the audio to ensure it fits in the range [-1, 1]
155
  audio = torch.clamp(audio, min=-1, max=1)
156
 
157
- # Check for NaNs in the audio tensor
158
- if torch.isnan(audio).any():
159
- raise ValueError("Generated audio contains NaN values.")
160
-
161
  # Convert to 16-bit PCM format by scaling and casting to int16
162
  audio = (audio * 32767).short()
163
 
164
- # Convert audio tensor to numpy array for Gradio
165
- audio = audio.cpu().numpy().astype(np.int16)
166
-
167
  return audio
168
 
169
 
 
129
  return generator
130
 
131
 
 
 
 
132
  def magnitude_to_complex_spectrogram(magnitude_spectrogram):
133
+ # Clip values to avoid extreme values or potential invalid inputs
134
+ magnitude_spectrogram = torch.clamp(magnitude_spectrogram, min=1e-10, max=1e5)
135
+
136
  zero_phase = torch.zeros_like(magnitude_spectrogram)
137
  complex_spectrogram = torch.stack([magnitude_spectrogram, zero_phase], dim=-1)
138
+
139
+ # Check for NaNs in the complex spectrogram
140
+ if torch.isnan(complex_spectrogram).any():
141
+ raise ValueError("Complex spectrogram contains NaN values.")
142
+
143
  return complex_spectrogram
144
 
145
+
146
  def spectrogram_to_audio(magnitude_spectrogram):
147
  # Perform inverse log scaling
148
+ try:
149
+ magnitude_spectrogram = torch.expm1(magnitude_spectrogram)
150
+ if torch.isnan(magnitude_spectrogram).any():
151
+ raise ValueError("NaN values found in magnitude_spectrogram after expm1.")
152
+ except Exception as e:
153
+ raise ValueError(f"Error in expm1 operation: {e}")
154
+
155
  # Convert magnitude-only spectrogram to complex format
156
+ try:
157
+ complex_spectrogram = magnitude_to_complex_spectrogram(magnitude_spectrogram)
158
+ if torch.isnan(complex_spectrogram).any():
159
+ raise ValueError("Complex spectrogram contains NaN values after conversion.")
160
+ except Exception as e:
161
+ raise ValueError(f"Error in complex spectrogram creation: {e}")
162
 
163
  # Inverse STFT to convert the spectrogram back to audio
164
+ try:
165
+ audio = torch.istft(complex_spectrogram, n_fft=n_fft, hop_length=hop_length)
166
+ if torch.isnan(audio).any():
167
+ raise ValueError("Generated audio contains NaN values after istft.")
168
+ except Exception as e:
169
+ raise ValueError(f"Error in istft operation: {e}")
170
 
171
  # Normalize audio to the range [-1, 1] (standard audio range)
172
+ try:
173
+ if torch.max(torch.abs(audio)) != 0:
174
+ audio = audio / torch.max(torch.abs(audio))
175
+ except Exception as e:
176
+ raise ValueError(f"Error in audio normalization: {e}")
177
 
178
  # Clip the audio to ensure it fits in the range [-1, 1]
179
  audio = torch.clamp(audio, min=-1, max=1)
180
 
 
 
 
 
181
  # Convert to 16-bit PCM format by scaling and casting to int16
182
  audio = (audio * 32767).short()
183
 
 
 
 
184
  return audio
185
 
186