Update app.py
app.py
CHANGED
@@ -1,212 +1,58 @@
-import torch
-import torchaudio
 import gradio as gr
+from keras.models import load_model
+from tensorflow.keras.utils import img_to_array
+from tensorflow.keras.utils import load_img
+from numpy import expand_dims
 from PIL import Image
-import
-import torchaudio.transforms as T
-from torch import nn, optim
-import torchvision.transforms as transforms
-from torch.utils.data import Dataset, DataLoader
-from PIL import Image
-import os
+import librosa
 import numpy as np
+import soundfile as sf
+import os
 
-#
-# … (old lines 14–58 did not survive the page extraction; they defined the
-#    image_transform, the ImageEncoder model, and the audio constants —
-#    n_channels, n_fft, hop_length, sample_rate, output_time_frames, device —
-#    referenced below)
-#
-class SoundDecoder(nn.Module):
-    def __init__(self, output_time_frames):
-        super(SoundDecoder, self).__init__()
-        self.fc = nn.Linear(512, 512 * 8 * 8)
-
-        self.decoder = nn.Sequential(
-            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
-            nn.BatchNorm2d(256),
-            nn.ReLU(),
-            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
-            nn.BatchNorm2d(128),
-            nn.ReLU(),
-            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
-            nn.BatchNorm2d(64),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, n_channels, kernel_size=4, stride=2, padding=1),
-        )
-
-        # Modify the upsample to exactly match the real spectrogram size (108 time frames)
-        self.upsample = nn.Upsample(size=(n_fft // 2 + 1, 108), mode='bilinear', align_corners=True)
-
-    def forward(self, x):
-        x = self.fc(x)
-        x = x.view(x.size(0), 512, 8, 8)
-        x = self.decoder(x)
-        x = self.upsample(x)
-        # Debugging shape
-        print(f'Generated spectrogram shape: {x.shape}')
-        return x
-
-# Generator model
-class Generator(nn.Module):
-    def __init__(self, output_time_frames):
-        super(Generator, self).__init__()
-        self.encoder = ImageEncoder()
-        self.decoder = SoundDecoder(output_time_frames)
-
-    def forward(self, img):
-        # Debugging: Image encoder
-        encoded_features = self.encoder(img)
-        print(f"Encoded features shape (from Image Encoder): {encoded_features.shape}")
-
-        # Debugging: Sound decoder
-        generated_spectrogram = self.decoder(encoded_features)
-        print(f"Generated spectrogram shape (from Sound Decoder): {generated_spectrogram.shape}")
-
-        return generated_spectrogram
-
-
-# Function to generate and save audio from a test image using the pre-trained GAN model
-def test_model(generator, test_img_path, output_audio_path, device):
-    # Load and preprocess test image
-    test_img = Image.open(test_img_path).convert('RGB')
-    test_img = image_transform(test_img).unsqueeze(0).to(device)  # Add batch dimension
-
-    # Generate sound spectrogram from the image
-    with torch.no_grad():  # Disable gradient calculation for inference
-        generated_spectrogram = generator(test_img)
-
-    # Debugging: Check generated spectrogram shape
-    print(f"Generated spectrogram shape: {generated_spectrogram.shape}")
-
-    # Convert the generated spectrogram to audio
-    generated_audio = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())  # Remove batch dimension
-
-# Load the pre-trained GAN model
-def load_gan_model(generator, model_path, device):
-    generator.load_state_dict(torch.load(model_path, map_location=device))
-    generator.eval()  # Set the model to evaluation mode
-    return generator
-
-
-def magnitude_to_complex_spectrogram(magnitude_spectrogram):
-    # Clip values to avoid extreme values or potential invalid inputs
-    magnitude_spectrogram = torch.clamp(magnitude_spectrogram, min=1e-10, max=1e5)
-
-    zero_phase = torch.zeros_like(magnitude_spectrogram)
-    complex_spectrogram = torch.stack([magnitude_spectrogram, zero_phase], dim=-1)
-
-    # Check for NaNs in the complex spectrogram
-    if torch.isnan(complex_spectrogram).any():
-        raise ValueError("Complex spectrogram contains NaN values.")
-
-    return complex_spectrogram
-
-
-def spectrogram_to_audio(magnitude_spectrogram):
-    # Perform inverse log scaling to undo any log scaling
-    magnitude_spectrogram = torch.expm1(magnitude_spectrogram)
-
-    # Convert magnitude-only spectrogram to complex format (real part and zero imaginary)
-    zero_phase = torch.zeros_like(magnitude_spectrogram)
-    complex_spectrogram = torch.stack([magnitude_spectrogram, zero_phase], dim=-1)
-
-    # Inverse STFT to convert the spectrogram back to time-domain audio
-    audio = torch.istft(complex_spectrogram, n_fft=n_fft, hop_length=hop_length)
-
-    # Handle NaNs or Infs in the audio and replace them with zeros
-    audio = torch.nan_to_num(audio, nan=0.0, posinf=0.0, neginf=0.0)
-
-    # Normalize the audio to the range [-1, 1]
-    if torch.max(torch.abs(audio)) != 0:
-        audio = audio / torch.max(torch.abs(audio))
-
-    # Clip the audio to ensure no values fall outside the range [-1, 1]
-    audio = torch.clamp(audio, min=-1, max=1)
-
-    # Convert to 16-bit PCM format by scaling and casting to int16
-    audio = (audio * 32767).short()
-
-    # Ensure the audio is in the valid range for int16 [-32768, 32767]
-    audio = torch.clamp(audio, min=-32768, max=32767)
-
-    # Convert the audio to a NumPy array of int16
-    audio_numpy = audio.cpu().numpy().astype(np.int16)
-
-    return audio_numpy
-
-
-
-
-def generate_audio_from_image(image):
-    test_img = image_transform(image).unsqueeze(0).to(device)  # Preprocess the image
-
-    # Generate a sound spectrogram from the image using the loaded generator
-    with torch.no_grad():
-        generated_spectrogram = generator(test_img)
-
-    # Convert the generated spectrogram to time-domain audio
-    generated_audio_numpy = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())
-
-    # Return the sample rate and the audio in the correct format for Gradio
-    return (sample_rate, generated_audio_numpy)
-
-
-
-# Gradio Interface
-def main():
-    global generator  # Declare the generator object globally
-    # Instantiate your Generator model
-    generator = Generator(output_time_frames).to(device)
-
-    # Load the pre-trained model
-    model_path = './gan_model.pth'  # Change this path
-    generator = load_gan_model(generator, model_path, device)
-
-    # Gradio interface: allow users to upload an image and generate audio
-    iface = gr.Interface(fn=generate_audio_from_image, inputs=gr.Image(type="pil"), outputs=gr.Audio(type="numpy", label="Generated Audio"))
-    iface.launch()
-
-if __name__ == "__main__":
-    main()
-
+# Load your Pix2Pix model (make sure the path is correct)
+model = load_model('./model_022600.h5', compile=False)
+
+# Function to process the input image and convert to audio
+def process_image(input_image):
+    # Load and preprocess the input image
+    def load_image(image, size=(256, 256)):
+        image = image.resize(size)
+        pixels = img_to_array(image)
+        pixels = (pixels - 127.5) / 127.5
+        pixels = expand_dims(pixels, 0)
+        return pixels
+
+    # Preprocess the input
+    src_image = load_image(input_image)
+
+    # Generate output using the Pix2Pix model
+    gen_image = model.predict(src_image)
+    gen_image = (gen_image + 1) / 2.0  # scale to [0, 1]
+
+    # Resize the generated image to original spectrogram size
+    orig_size = (1293, 512)
+    gen_image_resized = Image.fromarray((gen_image[0] * 255).astype('uint8')).resize(orig_size).convert('F')
+
+    # Convert the image to a numpy array (spectrogram)
+    img = np.array(gen_image_resized)
+
+    # Convert the spectrogram back to audio using librosa
+    wav = librosa.feature.inverse.mel_to_audio(img, sr=44100, n_fft=2048, hop_length=512)
+
+    # Save the audio file to a temporary location
+    audio_file = "generated_audio.wav"
+    sf.write(audio_file, wav, samplerate=44100)
+
+    return audio_file
+
+# Create a Gradio interface
+interface = gr.Interface(
+    fn=process_image,
+    inputs=gr.Image(type="pil"),  # Input is an image
+    outputs=gr.Audio(type="file"),  # Output is an audio file
+    title="Image to Audio Generator",  # App title
+    description="Upload an image (preferably a spectrogram), and get an audio file generated using Pix2Pix.",
+)

+# Launch the interface
+interface.launch()
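The key conversion step in the new version is librosa.feature.inverse.mel_to_audio, which inverts the mel filter bank back to a linear-frequency magnitude spectrogram and then estimates phase with Griffin-Lim. A minimal round-trip sketch of that call, using the same parameters process_image passes (sr=44100, n_fft=2048, hop_length=512, and 512 mel bins to match the resized image height); "example.wav" is a placeholder filename, not a file from this repo:

import librosa
import soundfile as sf

# Round-trip sketch: audio -> 512-bin mel spectrogram -> audio, with the
# parameters the app uses. "example.wav" is a placeholder; any mono WAV works.
y, sr = librosa.load("example.wav", sr=44100)
mel = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=2048,
                                     hop_length=512, n_mels=512)
# mel_to_audio inverts the mel basis and runs Griffin-Lim internally.
y_rec = librosa.feature.inverse.mel_to_audio(mel, sr=sr, n_fft=2048,
                                             hop_length=512)
sf.write("mel_roundtrip.wav", y_rec, samplerate=sr)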
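One behavioral difference from the removed version: the old script built and launched the interface inside main() behind an if __name__ == "__main__": guard, while the new script calls interface.launch() at import time. If the module ever needs to be imported without starting the server (e.g. to test process_image directly), a guard in the old script's style would restore that; a one-line sketch:

# Sketch: defer the launch to direct execution, as the removed version did.
if __name__ == "__main__":
    interface.launch()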