import torch
import torchaudio
import gradio as gr
import numpy as np
from PIL import Image
from torch import nn
import torchvision.transforms as transforms
import torchaudio.transforms as T
device = torch.device('cpu')  # switch to 'cuda' if a GPU is available

# Parameters
sample_rate = 44100  # 44.1 kHz stereo sound
n_fft = 4096         # FFT size
hop_length = 2048    # Hop length for STFT
duration = 5         # Duration of the sound files (5 seconds)
n_channels = 2       # Stereo sound
# Number of time frames in the spectrogram; the centered STFT adds one extra frame
output_time_frames = duration * sample_rate // hop_length + 1  # 108

stft_transform = T.Spectrogram(n_fft=n_fft, hop_length=hop_length, win_length=n_fft)

image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
])
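# Quick shape check -- a sketch for verification only, not part of the app's runtime
# path; it assumes T.Spectrogram's default center=True, which pads the signal and
# yields n_samples // hop_length + 1 = 108 frames for a 5-second 44.1 kHz clip:
# dummy_wave = torch.randn(n_channels, duration * sample_rate)
# print(stft_transform(dummy_wave).shape)                      # torch.Size([2, 2049, 108])
# print(image_transform(Image.new('RGB', (512, 512))).shape)   # torch.Size([3, 256, 256])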
# Image Encoder (for the Generator)
class ImageEncoder(nn.Module):
    def __init__(self):
        super(ImageEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.fc = nn.Linear(512 * 16 * 16, 512)

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
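# Shape walkthrough (for the fixed 256x256 input produced by image_transform):
# each stride-2 convolution halves the spatial size, 256 -> 128 -> 64 -> 32 -> 16,
# so the flattened feature map is 512 * 16 * 16, matching self.fc above.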
# Sound Decoder (for the Generator)
class SoundDecoder(nn.Module):
    def __init__(self, output_time_frames):
        super(SoundDecoder, self).__init__()
        self.fc = nn.Linear(512, 512 * 8 * 8)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, n_channels, kernel_size=4, stride=2, padding=1),
        )
        # Upsample to exactly match the real spectrogram size:
        # (n_fft // 2 + 1) = 2049 frequency bins by output_time_frames = 108 time frames
        self.upsample = nn.Upsample(size=(n_fft // 2 + 1, output_time_frames),
                                    mode='bilinear', align_corners=True)

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 512, 8, 8)
        x = self.decoder(x)
        x = self.upsample(x)
        print(f'Generated spectrogram shape: {x.shape}')  # debugging
        return x
# Generator model
class Generator(nn.Module):
    def __init__(self, output_time_frames):
        super(Generator, self).__init__()
        self.encoder = ImageEncoder()
        self.decoder = SoundDecoder(output_time_frames)

    def forward(self, img):
        encoded_features = self.encoder(img)
        print(f"Encoded features shape (from Image Encoder): {encoded_features.shape}")  # debugging
        generated_spectrogram = self.decoder(encoded_features)
        print(f"Generated spectrogram shape (from Sound Decoder): {generated_spectrogram.shape}")  # debugging
        return generated_spectrogram
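# Usage sketch (shapes assume the 256x256 RGB input enforced by image_transform):
# g = Generator(output_time_frames).to(device)
# spec = g(torch.randn(1, 3, 256, 256, device=device))
# spec.shape  # torch.Size([1, 2, 2049, 108])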
# Generate audio from a test image with the pre-trained GAN model and save it to disk
def test_model(generator, test_img_path, output_audio_path, device):
    # Load and preprocess the test image
    test_img = Image.open(test_img_path).convert('RGB')
    test_img = image_transform(test_img).unsqueeze(0).to(device)  # add batch dimension
    # Generate a sound spectrogram from the image
    with torch.no_grad():  # disable gradient calculation for inference
        generated_spectrogram = generator(test_img)
    print(f"Generated spectrogram shape: {generated_spectrogram.shape}")  # debugging
    # Convert the generated spectrogram to audio and write it out
    generated_audio = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())  # (channels, samples)
    torchaudio.save(output_audio_path, generated_audio, sample_rate)
# Load the pre-trained GAN model
def load_gan_model(generator, model_path, device):
    generator.load_state_dict(torch.load(model_path, map_location=device))
    generator.eval()  # Set the model to evaluation mode
    return generator
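# Note (assumes gan_model.pth holds a plain state_dict): on recent PyTorch versions,
# torch.load(model_path, map_location=device, weights_only=True) is the safer call,
# since loading a state_dict needs no arbitrary-object unpickling.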
def magnitude_to_complex_spectrogram(magnitude_spectrogram):
    # Create a zero-phase tensor with the same shape as the magnitude spectrogram
    zero_phase = torch.zeros_like(magnitude_spectrogram)
    # Combine magnitude and zero phase into a complex-valued spectrogram
    return torch.complex(magnitude_spectrogram, zero_phase)
def spectrogram_to_audio(magnitude_spectrogram):
    # Convert the magnitude-only spectrogram to complex format
    complex_spectrogram = magnitude_to_complex_spectrogram(magnitude_spectrogram)
    # Provide a rectangular window to suppress torch.istft's missing-window warning
    window = torch.ones(n_fft, device=complex_spectrogram.device)
    # Inverse STFT to convert the spectrogram back to audio
    return torch.istft(complex_spectrogram, n_fft=n_fft, hop_length=hop_length, window=window)
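# Note: reconstructing with an all-zero phase is crude and tends to sound metallic.
# A hedged alternative sketch (not used by the app as written): torchaudio's
# Griffin-Lim transform estimates phase iteratively, and its default power=2 matches
# T.Spectrogram's default, so it could be swapped in here:
# griffin_lim = T.GriffinLim(n_fft=n_fft, hop_length=hop_length, win_length=n_fft)
# audio = griffin_lim(magnitude_spectrogram)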
def generate_audio_from_image(image):
    if image is None:
        raise ValueError("The uploaded image is 'None'. Please check the Gradio input.")
    # Preprocess the image and add a batch dimension
    test_img = image_transform(image).unsqueeze(0).to(device)
    # Generate a sound spectrogram from the image using the loaded generator
    with torch.no_grad():
        generated_spectrogram = generator(test_img)
    # Convert the generated spectrogram to audio
    generated_audio = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())
    # Convert to a NumPy array and normalize to [-1, 1] for proper playback
    generated_audio = generated_audio.numpy()
    max_value = np.abs(generated_audio).max()
    if max_value > 0:
        generated_audio = generated_audio / max_value
    generated_audio = generated_audio.astype(np.float32)
    # Transpose from (channels, samples) to (samples, channels) for stereo
    generated_audio = generated_audio.T
    # Gradio's "numpy" audio format expects a (sample_rate, data) tuple, in that order
    return int(sample_rate), generated_audio
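# Local test sketch (hypothetical file name; bypasses Gradio entirely):
# sr, samples = generate_audio_from_image(Image.open('example.jpg').convert('RGB'))
# samples.shape  # (n_samples, 2) float32 stereo at sr == 44100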
# Gradio Interface
def main():
    global generator  # the Gradio callback above reads this
    # Instantiate the Generator model
    generator = Generator(output_time_frames).to(device)
    # Load the pre-trained weights
    model_path = './gan_model.pth'  # ensure the model file is at this relative path
    generator = load_gan_model(generator, model_path, device)
    iface = gr.Interface(
        fn=generate_audio_from_image,
        inputs=gr.Image(type="pil"),
        outputs=gr.Audio(type="numpy", label="Generated Audio"),
    )
    iface.launch()

if __name__ == "__main__":
    main()