import torch
import torchaudio
import torchaudio.transforms as T
import torchvision.transforms as transforms
import gradio as gr
import numpy as np
from PIL import Image
from torch import nn

# Set device to 'cuda' if available, otherwise fall back to 'cpu'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Parameters
sample_rate = 44100  # 44.1 kHz stereo sounds
n_fft = 4096         # FFT size
hop_length = 2048    # Hop length for STFT
duration = 5         # Duration of the sound files (5 seconds)
n_channels = 2       # Stereo sound
output_time_frames = duration * sample_rate // hop_length  # Number of time frames in the spectrogram

stft_transform = T.Spectrogram(n_fft=n_fft, hop_length=hop_length, win_length=n_fft)

image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
])

# Image Encoder (for the Generator)
class ImageEncoder(nn.Module):
    def __init__(self):
        super(ImageEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        # A 256x256 input passes through four stride-2 convs, leaving 512 feature maps of size 16x16
        self.fc = nn.Linear(512 * 16 * 16, 512)

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

# Sound Decoder (for the Generator)
class SoundDecoder(nn.Module):
    def __init__(self, output_time_frames):
        super(SoundDecoder, self).__init__()
        self.fc = nn.Linear(512, 512 * 8 * 8)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, n_channels, kernel_size=4, stride=2, padding=1),
        )
        # Upsample to exactly match the real spectrogram size: n_fft // 2 + 1 frequency bins and
        # 108 time frames (a centered STFT of 5 s at 44.1 kHz with hop 2048 gives 220500 // 2048 + 1 = 108 frames)
        self.upsample = nn.Upsample(size=(n_fft // 2 + 1, 108), mode='bilinear', align_corners=True)

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 512, 8, 8)
        x = self.decoder(x)
        x = self.upsample(x)
        # Debugging shape
        print(f'Generated spectrogram shape: {x.shape}')
        return x

# Generator model
class Generator(nn.Module):
    def __init__(self, output_time_frames):
        super(Generator, self).__init__()
        self.encoder = ImageEncoder()
        self.decoder = SoundDecoder(output_time_frames)

    def forward(self, img):
        # Debugging: Image encoder
        encoded_features = self.encoder(img)
        print(f"Encoded features shape (from Image Encoder): {encoded_features.shape}")
        # Debugging: Sound decoder
        generated_spectrogram = self.decoder(encoded_features)
        print(f"Generated spectrogram shape (from Sound Decoder): {generated_spectrogram.shape}")
        return generated_spectrogram

# Function to generate and save audio from a test image using the pre-trained GAN model
def test_model(generator, test_img_path, output_audio_path, device):
    # Load and preprocess test image
    test_img = Image.open(test_img_path).convert('RGB')
    test_img = image_transform(test_img).unsqueeze(0).to(device)  # Add batch dimension

    # Generate sound spectrogram from the image
    with torch.no_grad():  # Disable gradient calculation for inference
        generated_spectrogram = generator(test_img)
    # Debugging: Check generated spectrogram shape
    print(f"Generated spectrogram shape: {generated_spectrogram.shape}")

    # Convert the generated spectrogram to audio
    generated_audio = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())  # Remove batch dimension

    # Write the waveform to disk (torchaudio expects a (channels, time) tensor)
    torchaudio.save(output_audio_path, generated_audio, sample_rate)
    print(f"Generated audio saved to {output_audio_path}")

# Load the pre-trained GAN model
def load_gan_model(generator, model_path, device):
    generator.load_state_dict(torch.load(model_path, map_location=device))
    generator.eval()  # Set the model to evaluation mode
    return generator

def magnitude_to_complex_spectrogram(magnitude_spectrogram):
    # Create a zero-phase tensor with the same shape as the magnitude spectrogram
    zero_phase = torch.zeros_like(magnitude_spectrogram)
    # Create a complex-valued spectrogram using the magnitude and zero phase
    complex_spectrogram = torch.complex(magnitude_spectrogram, zero_phase)
    return complex_spectrogram

def spectrogram_to_audio(magnitude_spectrogram):
    # Convert magnitude-only spectrogram to complex format
    complex_spectrogram = magnitude_to_complex_spectrogram(magnitude_spectrogram)
    # Provide a rectangular window to suppress the warning
    window = torch.ones(n_fft, device=complex_spectrogram.device)
    # Inverse STFT to convert the spectrogram back to audio
    audio = torch.istft(complex_spectrogram, n_fft=n_fft, hop_length=hop_length, window=window)
    return audio

def generate_audio_from_image(image):
    if image is None:
        raise ValueError("The uploaded image is 'None'. Please check the Gradio input.")

    # Ensure the image is in the right format
    print(f"Image received: {type(image)}")  # Debugging: Check if image is received

    test_img = image_transform(image).unsqueeze(0).to(device)  # Preprocess image

    # Generate sound spectrogram from the image using the loaded generator
    with torch.no_grad():
        generated_spectrogram = generator(test_img)

    # Convert the generated spectrogram to audio
    generated_audio = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())

    # Ensure the audio is a NumPy array and properly formatted
    generated_audio = generated_audio.numpy()

    # Normalize the audio to fit between -1 and 1 for proper playback
    max_value = np.abs(generated_audio).max()
    if max_value > 0:
        generated_audio = generated_audio / max_value

    # Convert to the required format (e.g., float32)
    generated_audio = generated_audio.astype(np.float32)

    # Debug: Print the shape and type of the generated audio
    print(f"Generated audio shape: {generated_audio.shape}, type: {generated_audio.dtype}")

    # Gradio's "numpy" audio format is a (sample_rate, data) tuple with data shaped (samples, channels)
    return sample_rate, generated_audio.T

# Gradio Interface
def main():
    global generator  # Declare the generator object globally

    # Instantiate your Generator model
    generator = Generator(output_time_frames).to(device)

    # Load the pre-trained model
    model_path = './gan_model.pth'  # Ensure the model is in the correct relative path
    generator = load_gan_model(generator, model_path, device)

    iface = gr.Interface(fn=generate_audio_from_image,
                         inputs=gr.Image(type="pil"),
                         outputs=gr.Audio(type="numpy", label="Generated Audio"))
    iface.launch()

if __name__ == "__main__":
    main()
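
# --- Optional: Griffin-Lim phase reconstruction (illustrative sketch, not part of the original pipeline) ---
# The zero-phase iSTFT in spectrogram_to_audio() discards all phase information, which typically
# produces metallic-sounding audio. A common alternative is to estimate the phase iteratively with
# Griffin-Lim. The sketch below assumes the generator output is interpreted as a non-negative power
# spectrogram (GriffinLim's default power=2.0); the function name and n_iter value are illustrative
# choices, not taken from the original code.
def spectrogram_to_audio_griffinlim(magnitude_spectrogram, n_iter=32):
    griffin_lim = T.GriffinLim(n_fft=n_fft, hop_length=hop_length,
                               win_length=n_fft, n_iter=n_iter)
    # GriffinLim expects a non-negative spectrogram shaped (..., freq, time)
    return griffin_lim(magnitude_spectrogram.clamp(min=0.0))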