musdfakoc committed
Commit f36caa0 (verified) · 1 parent: f7e6aa6

Update app.py

Files changed (1):
  1. app.py +16 -57
app.py CHANGED
@@ -128,35 +128,25 @@ def load_gan_model(generator, model_path, device):
     generator.eval() # Set the model to evaluation mode
     return generator
 
+
+# Generator model class definitions remain the same as in your original code.
+
+# Convert magnitude-only spectrogram to complex format by assuming zero phase
 def magnitude_to_complex_spectrogram(magnitude_spectrogram):
-    # Create a zero-phase tensor with the same shape as the magnitude spectrogram
     zero_phase = torch.zeros_like(magnitude_spectrogram)
-
-    # Create a complex-valued spectrogram using the magnitude and zero phase
-    complex_spectrogram = torch.complex(magnitude_spectrogram, zero_phase)
-
+    complex_spectrogram = torch.stack([magnitude_spectrogram, zero_phase], dim=-1)
     return complex_spectrogram
 
+# Convert spectrogram back to audio using inverse STFT
 def spectrogram_to_audio(magnitude_spectrogram):
-    # Convert magnitude-only spectrogram to complex format
+    magnitude_spectrogram = torch.expm1(magnitude_spectrogram)
     complex_spectrogram = magnitude_to_complex_spectrogram(magnitude_spectrogram)
-
-    # Provide a rectangular window to suppress the warning
-    window = torch.ones(n_fft, device=complex_spectrogram.device)
-
-    # Inverse STFT to convert the spectrogram back to audio
-    audio = torch.istft(complex_spectrogram, n_fft=n_fft, hop_length=hop_length, window=window)
-
+    audio = torch.istft(complex_spectrogram, n_fft=n_fft, hop_length=hop_length)
     return audio
 
-
-
+# Function to generate audio from an uploaded image
 def generate_audio_from_image(image):
-    if image is None:
-        raise ValueError("The uploaded image is 'None'. Please check the Gradio input.")
-
-    # Preprocess the image
-    test_img = image_transform(image).unsqueeze(0).to(device)
+    test_img = image_transform(image).unsqueeze(0).to(device) # Preprocess image
 
     # Generate sound spectrogram from the image using the loaded generator
     with torch.no_grad():
@@ -165,36 +155,8 @@ def generate_audio_from_image(image):
     # Convert the generated spectrogram to audio
     generated_audio = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())
 
-    # Convert the audio to a NumPy array
-    generated_audio = generated_audio.numpy()
-
-    # Normalize the audio between -1 and 1
-    max_value = np.abs(generated_audio).max()
-    if max_value > 0:
-        generated_audio = generated_audio / max_value
-
-    # Convert the audio to 16-bit integer format
-    generated_audio = np.int16(generated_audio * 32767)
-
-    # Ensure audio is in stereo format (samples, channels)
-    if generated_audio.ndim == 1: # If mono, make it stereo
-        generated_audio = np.expand_dims(generated_audio, axis=-1)
-
-    # Transpose to ensure the shape is (samples, channels)
-    generated_audio = generated_audio.T
-
-    # Convert sample_rate to a scalar integer
-    sample_rate_scalar = int(sample_rate)
-
-    # Debug: Ensure everything is correct before returning
-    print(f"Returning audio data of shape {generated_audio.shape}, dtype {generated_audio.dtype}")
-    print(f"Returning sample rate: {sample_rate_scalar}, dtype {type(sample_rate_scalar)}")
-
-    # Return the tuple (sample_rate, audio_data)
-    return (sample_rate_scalar, generated_audio)
-
-
-
+    # Convert audio tensor to numpy and return it for Gradio to handle
+    return (sample_rate, generated_audio.numpy())
 
 # Gradio Interface
 def main():
@@ -203,16 +165,13 @@ def main():
     generator = Generator(output_time_frames).to(device)
 
     # Load the pre-trained model
-    model_path = './gan_model.pth' # Ensure the model is in the correct relative path
+    model_path = '/path/to/your/model/gan_model_100e_16b.pth' # Change this path
     generator = load_gan_model(generator, model_path, device)
 
-    iface = gr.Interface(fn=generate_audio_from_image, inputs=gr.Image(type="pil"), outputs=gr.Audio(type="numpy", label="Generated Audio"), title="Image to Sound Generation")
-
-
-
-
+    # Gradio interface: allow users to upload an image and generate audio
+    iface = gr.Interface(fn=generate_audio_from_image, inputs=gr.Image(type="pil"), outputs=gr.Audio(type="numpy", label="Generated Audio"))
     iface.launch()
 
-
 if __name__ == "__main__":
     main()
+
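Note: both sides of this diff call module-level names that are defined earlier in app.py and never appear in the changed hunks: n_fft, hop_length, sample_rate, image_transform, device, Generator, and output_time_frames. A minimal sketch of what those definitions could look like follows; every value in it is an assumption and must match whatever the GAN was trained with:

    import torch
    import numpy as np
    import gradio as gr
    from torchvision import transforms

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # STFT parameters (assumed values; must match training)
    n_fft = 1024
    hop_length = 256
    sample_rate = 22050

    # Image preprocessing (assumed; must match the generator's input size)
    image_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
    ])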
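The rewritten magnitude_to_complex_spectrogram returns a real tensor whose last dimension holds (real, imaginary) pairs. Recent PyTorch releases expect torch.istft to receive a complex-valued tensor, and the commit also drops the rectangular window that the removed code passed "to suppress the warning". If the stacked representation raises an error on a current PyTorch build, here is a hedged sketch of a compatible inversion (same assumed n_fft and hop_length as above):

    def spectrogram_to_audio(magnitude_spectrogram):
        # Undo the log1p compression the generator presumably emits
        # (torch.expm1 inverts torch.log1p), as in the committed code
        magnitude_spectrogram = torch.expm1(magnitude_spectrogram)
        # Zero phase -> complex tensor; view_as_complex needs a contiguous
        # float tensor whose last dimension has size 2
        stacked = torch.stack(
            [magnitude_spectrogram, torch.zeros_like(magnitude_spectrogram)],
            dim=-1,
        )
        complex_spectrogram = torch.view_as_complex(stacked.contiguous())
        # Rectangular window, matching the code this commit removes
        window = torch.ones(n_fft, device=complex_spectrogram.device)
        return torch.istft(complex_spectrogram, n_fft=n_fft,
                           hop_length=hop_length, window=window)

All-zero phase is inherently lossy; a Griffin-Lim pass (torchaudio.transforms.GriffinLim) would typically recover a more plausible phase estimate.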
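The new return path hands the raw float waveform straight to gr.Audio(type="numpy"), which expects a (sample_rate, numpy_array) tuple. If playback comes out silent or clipped on the Space, the peak normalization and 16-bit conversion that this commit deletes can be reinstated at the end of generate_audio_from_image; this sketch mirrors the removed lines:

    generated_audio = generated_audio.numpy()

    # Peak-normalize to [-1, 1], then convert to 16-bit PCM,
    # as the removed code did
    max_value = np.abs(generated_audio).max()
    if max_value > 0:
        generated_audio = generated_audio / max_value
    generated_audio = np.int16(generated_audio * 32767)

    return (int(sample_rate), generated_audio)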
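Finally, the new model_path is a literal placeholder ('/path/to/your/model/gan_model_100e_16b.pth') that will not resolve inside a deployed Space, whereas the removed relative path ('./gan_model.pth') pointed at a file in the repository root. For context, load_gan_model, whose tail is visible in the first hunk, is presumably the standard state-dict load; a sketch under that assumption:

    def load_gan_model(generator, model_path, device):
        # Checkpoint layout is assumed to be a bare state dict; adjust if the
        # file wraps the weights in a dict with extra keys
        state_dict = torch.load(model_path, map_location=device)
        generator.load_state_dict(state_dict)
        generator.eval() # Set the model to evaluation mode
        return generator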