musdfakoc committed
Commit ebb57ae · verified · 1 Parent(s): 5a6b109

Update app.py

Files changed (1)
  1. app.py +54 -208
app.py CHANGED
@@ -1,212 +1,58 @@
- import torch
- import torchaudio
  import gradio as gr
  from PIL import Image
- import torchvision.transforms as transforms
- import torchaudio.transforms as T
- from torch import nn, optim
- import torchvision.transforms as transforms
- from torch.utils.data import Dataset, DataLoader
- from PIL import Image
- import os
  import numpy as np

- # Set device to 'cpu' or 'cuda' if available
- device = torch.device('cpu')
-
- # Parameters
- sample_rate = 44100 # 44.1kHz stereo sounds
- n_fft = 4096 # FFT size
- hop_length = 2048 # Hop length for STFT
- duration = 5 # Duration of the sound files (5 seconds)
- n_channels = 2 # Stereo sound
- output_time_frames = duration * sample_rate // hop_length # Number of time frames in the spectrogram
-
- stft_transform = T.Spectrogram(n_fft=n_fft, hop_length=hop_length, win_length=n_fft)
-
- image_transform = transforms.Compose([
-     transforms.Resize((256, 256)),
-     transforms.ToTensor(),
-     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # Normalize to [-1, 1]
- ])
-
- # Image Encoder (for the Generator)
- class ImageEncoder(nn.Module):
-     def __init__(self):
-         super(ImageEncoder, self).__init__()
-         self.encoder = nn.Sequential(
-             nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
-             nn.BatchNorm2d(64),
-             nn.ReLU(),
-             nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
-             nn.BatchNorm2d(128),
-             nn.ReLU(),
-             nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
-             nn.BatchNorm2d(256),
-             nn.ReLU(),
-             nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
-             nn.BatchNorm2d(512),
-             nn.ReLU()
-         )
-         self.fc = nn.Linear(512 * 16 * 16, 512)
-
-     def forward(self, x):
-         x = self.encoder(x)
-         x = x.view(x.size(0), -1)
-         return self.fc(x)
-
-
- # Sound Decoder (for the Generator)
- class SoundDecoder(nn.Module):
-     def __init__(self, output_time_frames):
-         super(SoundDecoder, self).__init__()
-         self.fc = nn.Linear(512, 512 * 8 * 8)
-
-         self.decoder = nn.Sequential(
-             nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
-             nn.BatchNorm2d(256),
-             nn.ReLU(),
-             nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
-             nn.BatchNorm2d(128),
-             nn.ReLU(),
-             nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
-             nn.BatchNorm2d(64),
-             nn.ReLU(),
-             nn.ConvTranspose2d(64, n_channels, kernel_size=4, stride=2, padding=1),
-         )
-
-         # Modify the upsample to exactly match the real spectrogram size (108 time frames)
-         self.upsample = nn.Upsample(size=(n_fft // 2 + 1, 108), mode='bilinear', align_corners=True)
-
-     def forward(self, x):
-         x = self.fc(x)
-         x = x.view(x.size(0), 512, 8, 8)
-         x = self.decoder(x)
-         x = self.upsample(x)
-         # Debugging shape
-         print(f'Generated spectrogram shape: {x.shape}')
-         return x
-
- # Generator model
- class Generator(nn.Module):
-     def __init__(self, output_time_frames):
-         super(Generator, self).__init__()
-         self.encoder = ImageEncoder()
-         self.decoder = SoundDecoder(output_time_frames)
-
-     def forward(self, img):
-         # Debugging: Image encoder
-         encoded_features = self.encoder(img)
-         print(f"Encoded features shape (from Image Encoder): {encoded_features.shape}")
-
-         # Debugging: Sound decoder
-         generated_spectrogram = self.decoder(encoded_features)
-         print(f"Generated spectrogram shape (from Sound Decoder): {generated_spectrogram.shape}")
-
-         return generated_spectrogram
-
-
- # Function to generate and save audio from a test image using the pre-trained GAN model
- def test_model(generator, test_img_path, output_audio_path, device):
-     # Load and preprocess test image
-     test_img = Image.open(test_img_path).convert('RGB')
-     test_img = image_transform(test_img).unsqueeze(0).to(device) # Add batch dimension
-
-     # Generate sound spectrogram from the image
-     with torch.no_grad(): # Disable gradient calculation for inference
-         generated_spectrogram = generator(test_img)
-
-     # Debugging: Check generated spectrogram shape
-     print(f"Generated spectrogram shape: {generated_spectrogram.shape}")
-
-     # Convert the generated spectrogram to audio
-     generated_audio = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu()) # Remove batch dimension
-
- # Load the pre-trained GAN model
- def load_gan_model(generator, model_path, device):
-     generator.load_state_dict(torch.load(model_path, map_location=device))
-     generator.eval() # Set the model to evaluation mode
-     return generator
-
-
- def magnitude_to_complex_spectrogram(magnitude_spectrogram):
-     # Clip values to avoid extreme values or potential invalid inputs
-     magnitude_spectrogram = torch.clamp(magnitude_spectrogram, min=1e-10, max=1e5)
-
-     zero_phase = torch.zeros_like(magnitude_spectrogram)
-     complex_spectrogram = torch.stack([magnitude_spectrogram, zero_phase], dim=-1)
-
-     # Check for NaNs in the complex spectrogram
-     if torch.isnan(complex_spectrogram).any():
-         raise ValueError("Complex spectrogram contains NaN values.")
-
-     return complex_spectrogram
-
-
- def spectrogram_to_audio(magnitude_spectrogram):
-     # Perform inverse log scaling to undo any log scaling
-     magnitude_spectrogram = torch.expm1(magnitude_spectrogram)
-
-     # Convert magnitude-only spectrogram to complex format (real part and zero imaginary)
-     zero_phase = torch.zeros_like(magnitude_spectrogram)
-     complex_spectrogram = torch.stack([magnitude_spectrogram, zero_phase], dim=-1)
-
-     # Inverse STFT to convert the spectrogram back to time-domain audio
-     audio = torch.istft(complex_spectrogram, n_fft=n_fft, hop_length=hop_length)
-
-     # Handle NaNs or Infs in the audio and replace them with zeros
-     audio = torch.nan_to_num(audio, nan=0.0, posinf=0.0, neginf=0.0)
-
-     # Normalize the audio to the range [-1, 1]
-     if torch.max(torch.abs(audio)) != 0:
-         audio = audio / torch.max(torch.abs(audio))
-
-     # Clip the audio to ensure no values fall outside the range [-1, 1]
-     audio = torch.clamp(audio, min=-1, max=1)
-
-     # Convert to 16-bit PCM format by scaling and casting to int16
-     audio = (audio * 32767).short()
-
-     # Ensure the audio is in the valid range for int16 [-32768, 32767]
-     audio = torch.clamp(audio, min=-32768, max=32767)
-
-     # Convert the audio to a NumPy array of int16
-     audio_numpy = audio.cpu().numpy().astype(np.int16)
-
-     return audio_numpy
-
-
-
-
- def generate_audio_from_image(image):
-     test_img = image_transform(image).unsqueeze(0).to(device) # Preprocess the image
-
-     # Generate a sound spectrogram from the image using the loaded generator
-     with torch.no_grad():
-         generated_spectrogram = generator(test_img)
-
-     # Convert the generated spectrogram to time-domain audio
-     generated_audio_numpy = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())
-
-     # Return the sample rate and the audio in the correct format for Gradio
-     return (sample_rate, generated_audio_numpy)
-
-
-
- # Gradio Interface
- def main():
-     global generator # Declare the generator object globally
-     # Instantiate your Generator model
-     generator = Generator(output_time_frames).to(device)
-
-     # Load the pre-trained model
-     model_path = './gan_model.pth' # Change this path
-     generator = load_gan_model(generator, model_path, device)
-
-     # Gradio interface: allow users to upload an image and generate audio
-     iface = gr.Interface(fn=generate_audio_from_image, inputs=gr.Image(type="pil"), outputs=gr.Audio(type="numpy", label="Generated Audio"))
-     iface.launch()
-
- if __name__ == "__main__":
-     main()
-
  import gradio as gr
+ from keras.models import load_model
+ from tensorflow.keras.utils import img_to_array
+ from tensorflow.keras.utils import load_img
+ from numpy import expand_dims
  from PIL import Image
+ import librosa
  import numpy as np
+ import soundfile as sf
+ import os

+ # Load your Pix2Pix model (make sure the path is correct)
+ model = load_model('./model_022600.h5', compile=False)
+
+ # Function to process the input image and convert to audio
+ def process_image(input_image):
+     # Load and preprocess the input image
+     def load_image(image, size=(256, 256)):
+         image = image.resize(size)
+         pixels = img_to_array(image)
+         pixels = (pixels - 127.5) / 127.5
+         pixels = expand_dims(pixels, 0)
+         return pixels
+
+     # Preprocess the input
+     src_image = load_image(input_image)
+
+     # Generate output using the Pix2Pix model
+     gen_image = model.predict(src_image)
+     gen_image = (gen_image + 1) / 2.0 # scale to [0, 1]
+
+     # Resize the generated image to original spectrogram size
+     orig_size = (1293, 512)
+     gen_image_resized = Image.fromarray((gen_image[0] * 255).astype('uint8')).resize(orig_size).convert('F')
+
+     # Convert the image to a numpy array (spectrogram)
+     img = np.array(gen_image_resized)
+
+     # Convert the spectrogram back to audio using librosa
+     wav = librosa.feature.inverse.mel_to_audio(img, sr=44100, n_fft=2048, hop_length=512)
+
+     # Save the audio file to a temporary location
+     audio_file = "generated_audio.wav"
+     sf.write(audio_file, wav, samplerate=44100)
+
+     return audio_file
+
+ # Create a Gradio interface
+ interface = gr.Interface(
+     fn=process_image,
+     inputs=gr.Image(type="pil"), # Input is an image
+     outputs=gr.Audio(type="file"), # Output is an audio file
+     title="Image to Audio Generator", # App title
+     description="Upload an image (preferably a spectrogram), and get an audio file generated using Pix2Pix.",
+ )
+
+ # Launch the interface
+ interface.launch()
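
For context on the librosa.feature.inverse.mel_to_audio call introduced above: it treats the resized generator output (a 512 x 1293 array) as a mel power spectrogram and reconstructs a waveform with Griffin-Lim phase estimation. The following is a minimal standalone sketch using the same parameters as the commit; the random array is only a stand-in for the model output and is not part of app.py.

import numpy as np
import librosa

# Stand-in for the resized Pix2Pix output: 512 mel bands x 1293 time frames (illustrative only).
mel = np.random.rand(512, 1293).astype(np.float32)

# Same call and parameters as in app.py; Griffin-Lim runs internally to estimate phase.
wav = librosa.feature.inverse.mel_to_audio(mel, sr=44100, n_fft=2048, hop_length=512)
print(wav.shape)  # roughly 1293 * 512 mono samples, i.e. about 15 s of audio at 44.1 kHz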