import torch
import torchaudio
import gradio as gr
from PIL import Image
import torchvision.transforms as transforms
import torchaudio.transforms as T
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
# Run inference on the CPU (change to 'cuda' manually if a GPU is available)
device = torch.device('cpu')

# Parameters
sample_rate = 44100  # 44.1 kHz stereo audio
n_fft = 4096         # FFT size
hop_length = 2048    # Hop length for the STFT
duration = 5         # Duration of the sound files (5 seconds)
n_channels = 2       # Stereo sound
output_time_frames = duration * sample_rate // hop_length  # Number of time frames in the spectrogram

stft_transform = T.Spectrogram(n_fft=n_fft, hop_length=hop_length, win_length=n_fft)

image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
])
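# With these settings the magnitude spectrogram has n_fft // 2 + 1 = 2049 frequency bins
# and, for 5 s of 44.1 kHz audio with hop_length = 2048 and centered frames, about 108
# time frames — which is why SoundDecoder below upsamples to (2049, 108).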
# Image Encoder (for the Generator)
class ImageEncoder(nn.Module):
    def __init__(self):
        super(ImageEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.fc = nn.Linear(512 * 16 * 16, 512)

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
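# Note on ImageEncoder: four stride-2 convolutions reduce the 256x256 input to 16x16
# feature maps with 512 channels, hence the 512 * 16 * 16 input size of the final
# linear layer that produces the 512-dimensional image embedding.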
# Sound Decoder (for the Generator)
class SoundDecoder(nn.Module):
    def __init__(self, output_time_frames):
        super(SoundDecoder, self).__init__()
        self.fc = nn.Linear(512, 512 * 8 * 8)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, n_channels, kernel_size=4, stride=2, padding=1),
        )
        # Upsample to exactly match the real spectrogram size (108 time frames)
        self.upsample = nn.Upsample(size=(n_fft // 2 + 1, 108), mode='bilinear', align_corners=True)

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 512, 8, 8)
        x = self.decoder(x)
        x = self.upsample(x)
        # Debugging shape
        print(f'Generated spectrogram shape: {x.shape}')
        return x
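# Note on SoundDecoder: the four transposed convolutions expand the 8x8 feature map to
# 128x128 with n_channels outputs; the final bilinear upsample then stretches it to the
# target (n_fft // 2 + 1, 108) spectrogram size.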
# Generator model
class Generator(nn.Module):
    def __init__(self, output_time_frames):
        super(Generator, self).__init__()
        self.encoder = ImageEncoder()
        self.decoder = SoundDecoder(output_time_frames)

    def forward(self, img):
        # Debugging: Image encoder
        encoded_features = self.encoder(img)
        print(f"Encoded features shape (from Image Encoder): {encoded_features.shape}")
        # Debugging: Sound decoder
        generated_spectrogram = self.decoder(encoded_features)
        print(f"Generated spectrogram shape (from Sound Decoder): {generated_spectrogram.shape}")
        return generated_spectrogram
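# Quick shape check (sketch, using a random tensor instead of a real image):
#     g = Generator(output_time_frames)
#     g(torch.randn(1, 3, 256, 256)).shape  # -> torch.Size([1, 2, 2049, 108])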
# Function to generate and save audio from a test image using the pre-trained GAN model
def test_model(generator, test_img_path, output_audio_path, device):
    # Load and preprocess the test image
    test_img = Image.open(test_img_path).convert('RGB')
    test_img = image_transform(test_img).unsqueeze(0).to(device)  # Add batch dimension

    # Generate a sound spectrogram from the image
    with torch.no_grad():  # Disable gradient calculation for inference
        generated_spectrogram = generator(test_img)

    # Debugging: Check generated spectrogram shape
    print(f"Generated spectrogram shape: {generated_spectrogram.shape}")

    # Convert the generated spectrogram to audio
    generated_audio = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())  # Remove batch dimension

    # Write the result to output_audio_path; generated_audio is an int16 NumPy array of
    # shape (channels, samples), which torchaudio.save writes as 16-bit PCM WAV
    torchaudio.save(output_audio_path, torch.from_numpy(generated_audio), sample_rate)
# Load the pre-trained GAN model
def load_gan_model(generator, model_path, device):
    generator.load_state_dict(torch.load(model_path, map_location=device))
    generator.eval()  # Set the model to evaluation mode
    return generator
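# Usage sketch (the checkpoint path is the one configured in main() below):
#     generator = load_gan_model(Generator(output_time_frames).to(device), './gan_model.pth', device)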
def magnitude_to_complex_spectrogram(magnitude_spectrogram):
    # Clip values to avoid extreme values or potential invalid inputs
    magnitude_spectrogram = torch.clamp(magnitude_spectrogram, min=1e-10, max=1e5)
    zero_phase = torch.zeros_like(magnitude_spectrogram)
    complex_spectrogram = torch.stack([magnitude_spectrogram, zero_phase], dim=-1)
    # Check for NaNs in the complex spectrogram
    if torch.isnan(complex_spectrogram).any():
        raise ValueError("Complex spectrogram contains NaN values.")
    return complex_spectrogram
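# Note: this helper is not called anywhere below; spectrogram_to_audio builds its own
# zero-phase complex spectrogram inline before the inverse STFT.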
def spectrogram_to_audio(magnitude_spectrogram):
    # Undo any log scaling applied to the magnitudes (expm1 is the inverse of log1p)
    magnitude_spectrogram = torch.expm1(magnitude_spectrogram)

    # Convert the magnitude-only spectrogram to a complex tensor with zero phase
    # (recent PyTorch versions require torch.istft to receive a complex input)
    zero_phase = torch.zeros_like(magnitude_spectrogram)
    complex_spectrogram = torch.complex(magnitude_spectrogram, zero_phase)

    # Inverse STFT to convert the spectrogram back to time-domain audio
    audio = torch.istft(complex_spectrogram, n_fft=n_fft, hop_length=hop_length)

    # Replace any NaNs or Infs in the audio with zeros
    audio = torch.nan_to_num(audio, nan=0.0, posinf=0.0, neginf=0.0)

    # Normalize the audio to the range [-1, 1]
    if torch.max(torch.abs(audio)) != 0:
        audio = audio / torch.max(torch.abs(audio))

    # Clip the audio so no values fall outside [-1, 1]
    audio = torch.clamp(audio, min=-1, max=1)

    # Convert to 16-bit PCM by scaling and casting to int16
    audio = (audio * 32767).short()

    # Ensure the audio stays in the valid int16 range [-32768, 32767]
    audio = torch.clamp(audio, min=-32768, max=32767)

    # Return the audio as a NumPy array of int16
    audio_numpy = audio.cpu().numpy().astype(np.int16)
    return audio_numpy
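# Zero-phase reconstruction is crude: all phase information is discarded, which tends to
# sound metallic. torchaudio.transforms.GriffinLim(n_fft=n_fft, hop_length=hop_length) is a
# common alternative for magnitude-only inversion, but it is not used here.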
def generate_audio_from_image(image):
    # Preprocess the uploaded image (convert('RGB') guards against RGBA or grayscale uploads)
    test_img = image_transform(image.convert('RGB')).unsqueeze(0).to(device)

    # Generate a sound spectrogram from the image using the loaded generator
    with torch.no_grad():
        generated_spectrogram = generator(test_img)

    # Convert the generated spectrogram to time-domain audio
    generated_audio_numpy = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())

    # Return the sample rate and the audio in the format Gradio expects
    return (sample_rate, generated_audio_numpy)
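# Note: generated_audio_numpy has shape (channels, samples). Gradio's numpy audio output
# generally expects mono (samples,) or stereo (samples, channels), so a transpose
# (generated_audio_numpy.T) may be needed depending on the Gradio version in use.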
# Gradio Interface
def main():
    global generator  # Make the generator available to generate_audio_from_image

    # Instantiate the Generator model
    generator = Generator(output_time_frames).to(device)

    # Load the pre-trained model
    model_path = './gan_model.pth'  # Change this path if the checkpoint lives elsewhere
    generator = load_gan_model(generator, model_path, device)

    # Gradio interface: let users upload an image and generate audio from it
    iface = gr.Interface(
        fn=generate_audio_from_image,
        inputs=gr.Image(type="pil"),
        outputs=gr.Audio(type="numpy", label="Generated Audio"),
    )
    iface.launch()

if __name__ == "__main__":
    main()