import torch
import torchaudio
import gradio as gr
from PIL import Image
import torchvision.transforms as transforms
import torchaudio.transforms as T
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
# Set device to 'cpu' or 'cuda' if available
device = torch.device('cpu')
# Parameters
sample_rate = 44100 # 44.1kHz stereo sounds
n_fft = 4096 # FFT size
hop_length = 2048 # Hop length for STFT
duration = 5 # Duration of the sound files (5 seconds)
n_channels = 2 # Stereo sound
output_time_frames = duration * sample_rate // hop_length # Approximate number of time frames; the centered STFT of a 5 s clip actually yields 108 frames (hence the fixed upsample size below)
stft_transform = T.Spectrogram(n_fft=n_fft, hop_length=hop_length, win_length=n_fft)
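# Shape note (sketch, assuming a 5 s stereo clip at 44.1 kHz, i.e. a (2, 220500) waveform):
# stft_transform produces a spectrogram of shape (2, n_fft // 2 + 1, 108) = (2, 2049, 108),
# which is the target size the SoundDecoder below upsamples to. For example:
#   waveform, sr = torchaudio.load('example.wav')  # 'example.wav' is a placeholder path
#   spec = stft_transform(waveform)                # -> torch.Size([2, 2049, 108]) for 5 s audio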
image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
])
# Image Encoder (for the Generator)
class ImageEncoder(nn.Module):
    def __init__(self):
        super(ImageEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.fc = nn.Linear(512 * 16 * 16, 512)

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
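# Shape walk-through (for reference): a (B, 3, 256, 256) image is halved spatially by each of
# the four stride-2 convolutions -> (B, 512, 16, 16), flattened to (B, 512 * 16 * 16) and
# projected by self.fc to a (B, 512) latent vector.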
# Sound Decoder (for the Generator)
class SoundDecoder(nn.Module):
    def __init__(self, output_time_frames):
        super(SoundDecoder, self).__init__()
        self.fc = nn.Linear(512, 512 * 8 * 8)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, n_channels, kernel_size=4, stride=2, padding=1),
        )
        # Upsample to exactly match the real spectrogram size (108 time frames)
        self.upsample = nn.Upsample(size=(n_fft // 2 + 1, 108), mode='bilinear', align_corners=True)

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 512, 8, 8)
        x = self.decoder(x)
        x = self.upsample(x)
        # Debugging shape
        print(f'Generated spectrogram shape: {x.shape}')
        return x
# Generator model
class Generator(nn.Module):
    def __init__(self, output_time_frames):
        super(Generator, self).__init__()
        self.encoder = ImageEncoder()
        self.decoder = SoundDecoder(output_time_frames)

    def forward(self, img):
        # Debugging: Image encoder
        encoded_features = self.encoder(img)
        print(f"Encoded features shape (from Image Encoder): {encoded_features.shape}")
        # Debugging: Sound decoder
        generated_spectrogram = self.decoder(encoded_features)
        print(f"Generated spectrogram shape (from Sound Decoder): {generated_spectrogram.shape}")
        return generated_spectrogram
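# Quick shape sanity check (sketch; run manually if desired):
#   g = Generator(output_time_frames).to(device)
#   dummy_img = torch.randn(1, 3, 256, 256, device=device)
#   g(dummy_img).shape  # expected: torch.Size([1, 2, 2049, 108])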
# Function to generate and save audio from a test image using the pre-trained GAN model
def test_model(generator, test_img_path, output_audio_path, device):
    # Load and preprocess test image
    test_img = Image.open(test_img_path).convert('RGB')
    test_img = image_transform(test_img).unsqueeze(0).to(device)  # Add batch dimension

    # Generate sound spectrogram from the image
    with torch.no_grad():  # Disable gradient calculation for inference
        generated_spectrogram = generator(test_img)

    # Debugging: Check generated spectrogram shape
    print(f"Generated spectrogram shape: {generated_spectrogram.shape}")

    # Convert the generated spectrogram to audio
    generated_audio = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())  # Remove batch dimension

    # Save the audio to disk (channels-first tensor) and return it
    torchaudio.save(output_audio_path, generated_audio, sample_rate)
    return generated_audio
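# Example usage of test_model (sketch; the paths below are placeholders):
#   gen = load_gan_model(Generator(output_time_frames).to(device), './gan_model.pth', device)
#   test_model(gen, './test_image.jpg', './generated_audio.wav', device)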
# Load the pre-trained GAN model
def load_gan_model(generator, model_path, device):
    generator.load_state_dict(torch.load(model_path, map_location=device))
    generator.eval()  # Set the model to evaluation mode
    return generator
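# Note: on recent PyTorch releases torch.load may warn about pickle safety here; if your
# installed version supports it (an assumption), torch.load(model_path, map_location=device,
# weights_only=True) restricts loading to plain tensors.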
def magnitude_to_complex_spectrogram(magnitude_spectrogram):
    # Create a zero-phase tensor with the same shape as the magnitude spectrogram
    zero_phase = torch.zeros_like(magnitude_spectrogram)
    # Create a complex-valued spectrogram using the magnitude and zero phase
    complex_spectrogram = torch.complex(magnitude_spectrogram, zero_phase)
    return complex_spectrogram
def spectrogram_to_audio(magnitude_spectrogram):
    # Convert magnitude-only spectrogram to complex format
    complex_spectrogram = magnitude_to_complex_spectrogram(magnitude_spectrogram)
    # Provide a rectangular window to suppress the warning
    window = torch.ones(n_fft, device=complex_spectrogram.device)
    # Inverse STFT to convert the spectrogram back to audio
    audio = torch.istft(complex_spectrogram, n_fft=n_fft, hop_length=hop_length, window=window)
    return audio
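# Design note: reconstructing audio with an all-zero phase is crude and tends to sound metallic.
# A common alternative (not used here) is Griffin-Lim phase estimation from torchaudio, e.g.:
#   griffin_lim = T.GriffinLim(n_fft=n_fft, hop_length=hop_length, n_iter=32)
#   audio = griffin_lim(magnitude_spectrogram)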
def generate_audio_from_image(image):
    if image is None:
        raise ValueError("The uploaded image is 'None'. Please check the Gradio input.")

    # Ensure the image is in the right format
    test_img = image_transform(image).unsqueeze(0).to(device)  # Preprocess image

    # Generate sound spectrogram from the image using the loaded generator
    with torch.no_grad():
        generated_spectrogram = generator(test_img)

    # Convert the generated spectrogram to audio
    generated_audio = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())

    # Ensure the audio is a NumPy array and properly formatted
    generated_audio = generated_audio.numpy()

    # Normalize the audio to fit between -1 and 1 for proper playback
    max_value = np.abs(generated_audio).max()
    if max_value > 0:
        generated_audio = generated_audio / max_value

    # Convert to the required format (e.g., float32)
    generated_audio = generated_audio.astype(np.float32)

    # Transpose to (samples, channels) for stereo
    generated_audio = generated_audio.T

    # Gradio's numpy audio output expects a (sample_rate, data) tuple
    return int(sample_rate), generated_audio
# Gradio Interface
def main():
    global generator  # Declare the generator object globally
    # Instantiate your Generator model
    generator = Generator(output_time_frames).to(device)
    # Load the pre-trained model
    model_path = './gan_model.pth'  # Ensure the model is in the correct relative path
    generator = load_gan_model(generator, model_path, device)

    iface = gr.Interface(
        fn=generate_audio_from_image,
        inputs=gr.Image(type="pil"),
        outputs=gr.Audio(type="numpy", label="Generated Audio")
    )
    iface.launch()

if __name__ == "__main__":
    main()