import torch
import torchaudio
import gradio as gr
from PIL import Image
import torchvision.transforms as transforms
import torchaudio.transforms as T
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
# Set device to 'cpu' or 'cuda' if available
device = torch.device('cpu')
# Parameters
sample_rate = 44100 # 44.1kHz stereo sounds
n_fft = 4096 # FFT size
hop_length = 2048 # Hop length for STFT
duration = 5 # Duration of the sound files (5 seconds)
n_channels = 2 # Stereo sound
output_time_frames = duration * sample_rate // hop_length # Number of time frames in the spectrogram
stft_transform = T.Spectrogram(n_fft=n_fft, hop_length=hop_length, win_length=n_fft)
image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
])
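# Note (illustrative sanity check, not part of the original pipeline): with these settings,
# torchaudio's Spectrogram maps a 5-second 44.1 kHz stereo clip to a tensor of shape
# (n_channels, n_fft // 2 + 1, duration * sample_rate // hop_length + 1) = (2, 2049, 108),
# which is the spectrogram size the decoder below upsamples to.
#
#   dummy_audio = torch.randn(n_channels, sample_rate * duration)
#   print(stft_transform(dummy_audio).shape)  # expected: torch.Size([2, 2049, 108])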
# Image Encoder (for the Generator)
class ImageEncoder(nn.Module):
    def __init__(self):
        super(ImageEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.fc = nn.Linear(512 * 16 * 16, 512)

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)  # Flatten to (batch, 512 * 16 * 16)
        return self.fc(x)
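# Shape check (illustrative): a 256x256 input is halved by each of the four stride-2
# convolutions (256 -> 128 -> 64 -> 32 -> 16), so the flattened feature map has
# 512 * 16 * 16 elements, matching the Linear layer above.
#
#   encoder_check = ImageEncoder()
#   print(encoder_check(torch.randn(1, 3, 256, 256)).shape)  # expected: torch.Size([1, 512])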
# Sound Decoder (for the Generator)
class SoundDecoder(nn.Module):
    def __init__(self, output_time_frames):
        super(SoundDecoder, self).__init__()
        self.fc = nn.Linear(512, 512 * 8 * 8)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, n_channels, kernel_size=4, stride=2, padding=1),
        )
        # Upsample to exactly match the real spectrogram size (n_fft // 2 + 1 frequency bins, 108 time frames)
        self.upsample = nn.Upsample(size=(n_fft // 2 + 1, 108), mode='bilinear', align_corners=True)

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 512, 8, 8)
        x = self.decoder(x)
        x = self.upsample(x)
        # Debugging shape
        print(f'Generated spectrogram shape: {x.shape}')
        return x
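# Shape check (illustrative): the decoder starts from an 8x8 map and doubles it with
# each of the four transposed convolutions (8 -> 16 -> 32 -> 64 -> 128), giving
# (n_channels, 128, 128); the final Upsample then stretches this to the target
# spectrogram size of (n_fft // 2 + 1, 108) = (2049, 108).
#
#   decoder_check = SoundDecoder(output_time_frames)
#   print(decoder_check(torch.randn(1, 512)).shape)  # expected: torch.Size([1, 2, 2049, 108])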
# Generator model
class Generator(nn.Module):
    def __init__(self, output_time_frames):
        super(Generator, self).__init__()
        self.encoder = ImageEncoder()
        self.decoder = SoundDecoder(output_time_frames)

    def forward(self, img):
        # Debugging: Image encoder
        encoded_features = self.encoder(img)
        print(f"Encoded features shape (from Image Encoder): {encoded_features.shape}")
        # Debugging: Sound decoder
        generated_spectrogram = self.decoder(encoded_features)
        print(f"Generated spectrogram shape (from Sound Decoder): {generated_spectrogram.shape}")
        return generated_spectrogram
# Generate and save audio from a test image using the pre-trained GAN generator
def test_model(generator, test_img_path, output_audio_path, device):
    # Load and preprocess the test image
    test_img = Image.open(test_img_path).convert('RGB')
    test_img = image_transform(test_img).unsqueeze(0).to(device)  # Add batch dimension
    # Generate a sound spectrogram from the image
    with torch.no_grad():  # Disable gradient calculation for inference
        generated_spectrogram = generator(test_img)
    # Debugging: check the generated spectrogram shape
    print(f"Generated spectrogram shape: {generated_spectrogram.shape}")
    # Convert the generated spectrogram to audio and save it to disk
    generated_audio = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())  # Remove batch dimension
    torchaudio.save(output_audio_path, generated_audio, sample_rate)
    return generated_audio
# Load the pre-trained GAN generator weights
def load_gan_model(generator, model_path, device):
    generator.load_state_dict(torch.load(model_path, map_location=device))
    generator.eval()  # Set the model to evaluation mode
    return generator
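# Offline usage sketch (illustrative; 'example.jpg' and 'example.wav' are hypothetical
# file names, and model_path is the same placeholder used in main()):
#
#   gen = Generator(output_time_frames).to(device)
#   gen = load_gan_model(gen, '/path/to/your/model/gan_model_100e_16b.pth', device)
#   test_model(gen, 'example.jpg', 'example.wav', device)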
# Convert a magnitude-only spectrogram to a complex spectrogram by assuming zero phase
def magnitude_to_complex_spectrogram(magnitude_spectrogram):
    zero_phase = torch.zeros_like(magnitude_spectrogram)
    # torch.istft expects a complex tensor, so build one with a zero imaginary part
    complex_spectrogram = torch.complex(magnitude_spectrogram, zero_phase)
    return complex_spectrogram

# Convert a spectrogram back to audio using the inverse STFT
def spectrogram_to_audio(magnitude_spectrogram):
    # expm1 inverts a log1p compression (assumed to have been applied to the training spectrograms)
    magnitude_spectrogram = torch.expm1(magnitude_spectrogram)
    complex_spectrogram = magnitude_to_complex_spectrogram(magnitude_spectrogram)
    audio = torch.istft(complex_spectrogram, n_fft=n_fft, hop_length=hop_length, win_length=n_fft)
    return audio
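# Note (optional alternative, not used above): discarding the phase and inverting with
# zero phase gives only a rough reconstruction. torchaudio's Griffin-Lim transform can
# instead estimate a phase iteratively from the (power) spectrogram, assuming the
# generated spectrograms match the power convention of T.Spectrogram above:
#
#   griffin_lim = T.GriffinLim(n_fft=n_fft, hop_length=hop_length, win_length=n_fft)
#   audio = griffin_lim(torch.expm1(magnitude_spectrogram))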
# Generate audio from an uploaded image (used by the Gradio interface)
def generate_audio_from_image(image):
    test_img = image_transform(image).unsqueeze(0).to(device)  # Preprocess image and add batch dimension
    # Generate a sound spectrogram from the image using the loaded generator
    with torch.no_grad():
        generated_spectrogram = generator(test_img)
    # Convert the generated spectrogram to audio
    generated_audio = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())
    # Gradio expects numpy audio as (samples,) or (samples, channels), so transpose the (channels, samples) tensor
    return (sample_rate, generated_audio.numpy().T)
# Gradio Interface
def main():
    global generator  # The loaded generator is used by generate_audio_from_image
    # Instantiate the Generator model
    generator = Generator(output_time_frames).to(device)
    # Load the pre-trained model
    model_path = '/path/to/your/model/gan_model_100e_16b.pth'  # Change this path
    generator = load_gan_model(generator, model_path, device)
    # Gradio interface: allow users to upload an image and generate audio
    iface = gr.Interface(
        fn=generate_audio_from_image,
        inputs=gr.Image(type="pil"),
        outputs=gr.Audio(type="numpy", label="Generated Audio"),
    )
    iface.launch()
if __name__ == "__main__":
    main()