Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -128,35 +128,25 @@ def load_gan_model(generator, model_path, device):
|
|
128 |
generator.eval() # Set the model to evaluation mode
|
129 |
return generator
|
130 |
|
|
|
|
|
|
|
|
|
131 |
def magnitude_to_complex_spectrogram(magnitude_spectrogram):
    """Build a complex spectrogram from magnitudes by assuming zero phase.

    Args:
        magnitude_spectrogram: Real floating-point tensor of magnitudes
            (``torch.complex`` requires a floating-point input).

    Returns:
        A complex tensor with the same shape as the input whose real part is
        the input and whose imaginary part is zero.
    """
    # Create a zero-phase tensor with the same shape as the magnitude spectrogram
    zero_phase = torch.zeros_like(magnitude_spectrogram)

    # Create a complex-valued spectrogram using the magnitude and zero phase
    complex_spectrogram = torch.complex(magnitude_spectrogram, zero_phase)

    return complex_spectrogram
|
139 |
|
|
|
140 |
def spectrogram_to_audio(magnitude_spectrogram):
    """Convert a magnitude spectrogram to a waveform with an inverse STFT.

    Relies on the module-level ``n_fft`` and ``hop_length`` settings, which
    are defined outside this block.

    Args:
        magnitude_spectrogram: Real tensor of spectrogram magnitudes.

    Returns:
        The audio tensor produced by ``torch.istft``.
    """
    complex_spectrogram = magnitude_to_complex_spectrogram(magnitude_spectrogram)

    # Provide a rectangular window to suppress the warning
    window = torch.ones(n_fft, device=complex_spectrogram.device)

    # Inverse STFT to convert the spectrogram back to audio
    audio = torch.istft(complex_spectrogram, n_fft=n_fft, hop_length=hop_length, window=window)

    return audio
|
151 |
|
152 |
-
|
153 |
-
|
154 |
def generate_audio_from_image(image):
|
155 |
-
|
156 |
-
raise ValueError("The uploaded image is 'None'. Please check the Gradio input.")
|
157 |
-
|
158 |
-
# Preprocess the image
|
159 |
-
test_img = image_transform(image).unsqueeze(0).to(device)
|
160 |
|
161 |
# Generate sound spectrogram from the image using the loaded generator
|
162 |
with torch.no_grad():
|
@@ -165,36 +155,8 @@ def generate_audio_from_image(image):
|
|
165 |
# Convert the generated spectrogram to audio
|
166 |
generated_audio = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())
|
167 |
|
168 |
-
# Convert
|
169 |
-
|
170 |
-
|
171 |
-
# Normalize the audio between -1 and 1
|
172 |
-
max_value = np.abs(generated_audio).max()
|
173 |
-
if max_value > 0:
|
174 |
-
generated_audio = generated_audio / max_value
|
175 |
-
|
176 |
-
# Convert the audio to 16-bit integer format
|
177 |
-
generated_audio = np.int16(generated_audio * 32767)
|
178 |
-
|
179 |
-
# Ensure audio is in stereo format (samples, channels)
|
180 |
-
if generated_audio.ndim == 1: # If mono, make it stereo
|
181 |
-
generated_audio = np.expand_dims(generated_audio, axis=-1)
|
182 |
-
|
183 |
-
# Transpose to ensure the shape is (samples, channels)
|
184 |
-
generated_audio = generated_audio.T
|
185 |
-
|
186 |
-
# Convert sample_rate to a scalar integer
|
187 |
-
sample_rate_scalar = int(sample_rate)
|
188 |
-
|
189 |
-
# Debug: Ensure everything is correct before returning
|
190 |
-
print(f"Returning audio data of shape {generated_audio.shape}, dtype {generated_audio.dtype}")
|
191 |
-
print(f"Returning sample rate: {sample_rate_scalar}, dtype {type(sample_rate_scalar)}")
|
192 |
-
|
193 |
-
# Return the tuple (sample_rate, audio_data)
|
194 |
-
return (sample_rate_scalar, generated_audio)
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
|
199 |
# Gradio Interface
|
200 |
def main():
|
@@ -203,16 +165,13 @@ def main():
|
|
203 |
generator = Generator(output_time_frames).to(device)
|
204 |
|
205 |
# Load the pre-trained model
|
206 |
-
model_path = '
|
207 |
generator = load_gan_model(generator, model_path, device)
|
208 |
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
iface.launch()
|
215 |
|
216 |
-
|
217 |
if __name__ == "__main__":
|
218 |
main()
|
|
|
|
128 |
generator.eval() # Set the model to evaluation mode
|
129 |
return generator
|
130 |
|
131 |
+
|
132 |
+
# Generator model class definitions remain the same as in your original code.
|
133 |
+
|
134 |
+
# Convert magnitude-only spectrogram to complex format by assuming zero phase
|
135 |
def magnitude_to_complex_spectrogram(magnitude_spectrogram):
    """Convert a magnitude-only spectrogram to complex form with zero phase.

    Args:
        magnitude_spectrogram: Real floating-point tensor of magnitudes
            (``torch.complex`` requires a floating-point input).

    Returns:
        A complex tensor with the same shape as the input; the real part is
        the magnitude and the imaginary part is zero.
    """
    zero_phase = torch.zeros_like(magnitude_spectrogram)
    # Build a genuinely complex tensor rather than stacking real/imag parts
    # on a trailing axis: torch.istft rejects real-valued (..., 2) input in
    # PyTorch >= 2.0.
    complex_spectrogram = torch.complex(magnitude_spectrogram, zero_phase)
    return complex_spectrogram
|
139 |
|
140 |
+
# Convert spectrogram back to audio using inverse STFT
|
141 |
def spectrogram_to_audio(magnitude_spectrogram):
    """Convert a spectrogram back to audio using the inverse STFT.

    Relies on the module-level ``n_fft`` and ``hop_length`` settings, which
    are defined outside this block.

    Args:
        magnitude_spectrogram: Real tensor produced by the generator.

    Returns:
        The audio tensor produced by ``torch.istft``.
    """
    # NOTE(review): expm1 presumably undoes a log1p scaling applied to the
    # training spectrograms -- confirm against the preprocessing code.
    magnitude_spectrogram = torch.expm1(magnitude_spectrogram)
    complex_spectrogram = magnitude_to_complex_spectrogram(magnitude_spectrogram)

    # torch.istft only accepts complex input (PyTorch >= 2.0). If the helper
    # returned the legacy real layout with a trailing (real, imag) axis,
    # reinterpret it as a complex tensor before inverting.
    if not torch.is_complex(complex_spectrogram):
        complex_spectrogram = torch.view_as_complex(complex_spectrogram.contiguous())

    audio = torch.istft(complex_spectrogram, n_fft=n_fft, hop_length=hop_length)
    return audio
|
146 |
|
147 |
+
# Function to generate audio from an uploaded image
|
|
|
148 |
def generate_audio_from_image(image):
|
149 |
+
test_img = image_transform(image).unsqueeze(0).to(device) # Preprocess image
|
|
|
|
|
|
|
|
|
150 |
|
151 |
# Generate sound spectrogram from the image using the loaded generator
|
152 |
with torch.no_grad():
|
|
|
155 |
# Convert the generated spectrogram to audio
|
156 |
generated_audio = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())
|
157 |
|
158 |
+
# Convert audio tensor to numpy and return it for Gradio to handle
|
159 |
+
return (sample_rate, generated_audio.numpy())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
|
161 |
# Gradio Interface
|
162 |
def main():
|
|
|
165 |
generator = Generator(output_time_frames).to(device)
|
166 |
|
167 |
# Load the pre-trained model
|
168 |
+
model_path = '/path/to/your/model/gan_model_100e_16b.pth' # Change this path
|
169 |
generator = load_gan_model(generator, model_path, device)
|
170 |
|
171 |
+
# Gradio interface: allow users to upload an image and generate audio
|
172 |
+
iface = gr.Interface(fn=generate_audio_from_image, inputs=gr.Image(type="pil"), outputs=gr.Audio(type="numpy", label="Generated Audio"))
|
|
|
|
|
|
|
173 |
iface.launch()
|
174 |
|
|
|
175 |
if __name__ == "__main__":
|
176 |
main()
|
177 |
+
|