Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -129,41 +129,58 @@ def load_gan_model(generator, model_path, device):
|
|
129 |
return generator
|
130 |
|
131 |
|
132 |
-
# Generator model class definitions remain the same as in your original code.
|
133 |
-
|
134 |
-
# Convert magnitude-only spectrogram to complex format by assuming zero phase
|
135 |
def magnitude_to_complex_spectrogram(magnitude_spectrogram):
    """Pair a magnitude spectrogram with an all-zero second channel.

    This is the pre-change version of the helper shown in the diff: it does
    no clamping and no NaN validation.

    Args:
        magnitude_spectrogram: Tensor of magnitudes (any shape).

    Returns:
        A tensor with one extra trailing dimension of size 2. Index 0 of
        that dimension holds the input values and index 1 holds zeros
        (i.e. real part = magnitude, imaginary part = 0, a zero-phase
        assumption).
    """
    zero_phase = torch.zeros_like(magnitude_spectrogram)
    complex_spectrogram = torch.stack([magnitude_spectrogram, zero_phase], dim=-1)
    return complex_spectrogram
|
139 |
|
|
|
140 |
def spectrogram_to_audio(magnitude_spectrogram):
|
141 |
# Perform inverse log scaling
|
142 |
-
|
143 |
-
|
|
|
|
|
|
|
|
|
|
|
144 |
# Convert magnitude-only spectrogram to complex format
|
145 |
-
|
|
|
|
|
|
|
|
|
|
|
146 |
|
147 |
# Inverse STFT to convert the spectrogram back to audio
|
148 |
-
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
# Normalize audio to the range [-1, 1] (standard audio range)
|
151 |
-
|
152 |
-
|
|
|
|
|
|
|
153 |
|
154 |
# Clip the audio to ensure it fits in the range [-1, 1]
|
155 |
audio = torch.clamp(audio, min=-1, max=1)
|
156 |
|
157 |
-
# Check for NaNs in the audio tensor
|
158 |
-
if torch.isnan(audio).any():
|
159 |
-
raise ValueError("Generated audio contains NaN values.")
|
160 |
-
|
161 |
# Convert to 16-bit PCM format by scaling and casting to int16
|
162 |
audio = (audio * 32767).short()
|
163 |
|
164 |
-
# Convert audio tensor to numpy array for Gradio
|
165 |
-
audio = audio.cpu().numpy().astype(np.int16)
|
166 |
-
|
167 |
return audio
|
168 |
|
169 |
|
|
|
129 |
return generator
|
130 |
|
131 |
|
|
|
|
|
|
|
132 |
def magnitude_to_complex_spectrogram(magnitude_spectrogram):
    """Pair a clamped magnitude spectrogram with an all-zero second channel.

    Args:
        magnitude_spectrogram: Tensor of magnitudes (any shape).

    Returns:
        A tensor with one extra trailing dimension of size 2. Index 0 of
        that dimension holds the magnitudes clamped to ``[1e-10, 1e5]`` and
        index 1 holds zeros (real part = magnitude, imaginary part = 0,
        i.e. zero phase).

    Raises:
        ValueError: If the input contains NaN values.
    """
    # Clip values to avoid extreme values or potential invalid inputs.
    # Infinities are pulled into range here; NaNs pass through the clamp
    # unchanged and are rejected below.
    clamped = torch.clamp(magnitude_spectrogram, min=1e-10, max=1e5)

    # The zero channel can never be NaN, so checking the clamped magnitudes
    # is equivalent to checking the stacked result, without scanning a
    # tensor twice the size.
    if torch.isnan(clamped).any():
        raise ValueError("Complex spectrogram contains NaN values.")

    zero_phase = torch.zeros_like(clamped)
    return torch.stack([clamped, zero_phase], dim=-1)
|
144 |
|
145 |
+
|
146 |
def spectrogram_to_audio(magnitude_spectrogram):
    """Convert a magnitude spectrogram into a 16-bit PCM audio tensor.

    Steps: undo the log scaling with ``expm1``, attach a zero phase via
    ``magnitude_to_complex_spectrogram``, run the inverse STFT, peak-normalize
    to [-1, 1], and scale to int16.

    ``n_fft`` and ``hop_length`` are read from module-level names defined
    outside this function.

    Args:
        magnitude_spectrogram: Tensor of scaled magnitudes. Presumably
            produced with ``log1p`` since ``expm1`` is used to invert it --
            confirm against the preprocessing code.

    Returns:
        A ``torch.int16`` tensor of audio samples in [-32767, 32767].

    Raises:
        ValueError: If any stage fails or produces NaN values. The message
            is prefixed with the stage that failed and the underlying
            exception (if any) is chained as ``__cause__``.
    """
    # Perform inverse log scaling. Only the operation itself is inside the
    # ``try`` so that our own NaN error is not caught and re-wrapped.
    try:
        magnitude_spectrogram = torch.expm1(magnitude_spectrogram)
    except Exception as e:
        raise ValueError(f"Error in expm1 operation: {e}") from e
    if torch.isnan(magnitude_spectrogram).any():
        raise ValueError(
            "Error in expm1 operation: "
            "NaN values found in magnitude_spectrogram after expm1."
        )

    # Convert magnitude-only spectrogram to complex format
    try:
        complex_spectrogram = magnitude_to_complex_spectrogram(magnitude_spectrogram)
    except Exception as e:
        raise ValueError(f"Error in complex spectrogram creation: {e}") from e
    # Defensive: kept in case the helper stops validating its own output.
    if torch.isnan(complex_spectrogram).any():
        raise ValueError(
            "Error in complex spectrogram creation: "
            "Complex spectrogram contains NaN values after conversion."
        )

    # Inverse STFT to convert the spectrogram back to audio.
    # torch.istft requires a complex-dtype input on PyTorch >= 2.0; the
    # helper returns a real tensor with a trailing (real, imag) dimension,
    # so reinterpret it as complex first.
    # NOTE(review): no ``window`` is passed, so istft uses its default
    # rectangular window -- confirm this matches the forward STFT.
    try:
        if not torch.is_complex(complex_spectrogram):
            complex_spectrogram = torch.view_as_complex(
                complex_spectrogram.contiguous()
            )
        audio = torch.istft(complex_spectrogram, n_fft=n_fft, hop_length=hop_length)
    except Exception as e:
        raise ValueError(f"Error in istft operation: {e}") from e
    if torch.isnan(audio).any():
        raise ValueError(
            "Error in istft operation: "
            "Generated audio contains NaN values after istft."
        )

    # Normalize audio to the range [-1, 1] (standard audio range).
    # The peak is computed once; an all-zero signal is left untouched to
    # avoid dividing by zero.
    try:
        peak = torch.max(torch.abs(audio))
        if peak != 0:
            audio = audio / peak
    except Exception as e:
        raise ValueError(f"Error in audio normalization: {e}") from e

    # Clip the audio to ensure it fits in the range [-1, 1]
    audio = torch.clamp(audio, min=-1, max=1)

    # Convert to 16-bit PCM format by scaling and casting to int16
    audio = (audio * 32767).short()

    return audio
|
185 |
|
186 |
|