import torch
import torchaudio
import gradio as gr
from PIL import Image
import torchvision.transforms as transforms
import torchaudio.transforms as T
from torch import nn

# Set device ('cpu' here; switch to 'cuda' if a GPU is available)
device = torch.device('cpu')
# Parameters
sample_rate = 44100 # 44.1kHz stereo sounds
n_fft = 4096 # FFT size
hop_length = 2048 # Hop length for STFT
duration = 5 # Duration of the sound files (5 seconds)
n_channels = 2 # Stereo sound
output_time_frames = duration * sample_rate // hop_length # Number of time frames in the spectrogram
stft_transform = T.Spectrogram(n_fft=n_fft, hop_length=hop_length, win_length=n_fft)
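# Shape bookkeeping (derived from the parameters above):
#   5 s * 44100 Hz = 220500 samples per clip
#   with n_fft=4096 and hop_length=2048 (centered STFT), the real spectrograms
#   have n_fft // 2 + 1 = 2049 frequency bins and 220500 // 2048 + 1 = 108 time
#   frames -- which is why the decoder below upsamples to a fixed (2049, 108) grid.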
image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
])
# Image Encoder (for the Generator)
class ImageEncoder(nn.Module):
    def __init__(self):
        super(ImageEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.fc = nn.Linear(512 * 16 * 16, 512)

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
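# Each stride-2 convolution halves the spatial resolution, so a 256x256 input
# image becomes a 512-channel 16x16 feature map (hence the 512 * 16 * 16 input
# size of the final linear layer), which is compressed to a 512-dim embedding.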
# Sound Decoder (for the Generator)
class SoundDecoder(nn.Module):
    def __init__(self, output_time_frames):
        super(SoundDecoder, self).__init__()
        self.fc = nn.Linear(512, 512 * 8 * 8)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, n_channels, kernel_size=4, stride=2, padding=1),
        )
        # Upsample to exactly match the real spectrogram size (n_fft // 2 + 1 = 2049 frequency bins x 108 time frames)
        self.upsample = nn.Upsample(size=(n_fft // 2 + 1, 108), mode='bilinear', align_corners=True)

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 512, 8, 8)
        x = self.decoder(x)
        x = self.upsample(x)
        # Debugging shape
        print(f'Generated spectrogram shape: {x.shape}')
        return x
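# The decoder mirrors the encoder: the 512-dim embedding is reshaped to a
# 512-channel 8x8 map, four stride-2 transposed convolutions grow it to
# n_channels x 128 x 128, and bilinear upsampling stretches it to the
# target spectrogram grid of 2049 x 108.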
# Generator model
class Generator(nn.Module):
    def __init__(self, output_time_frames):
        super(Generator, self).__init__()
        self.encoder = ImageEncoder()
        self.decoder = SoundDecoder(output_time_frames)

    def forward(self, img):
        # Debugging: Image encoder
        encoded_features = self.encoder(img)
        print(f"Encoded features shape (from Image Encoder): {encoded_features.shape}")
        # Debugging: Sound decoder
        generated_spectrogram = self.decoder(encoded_features)
        print(f"Generated spectrogram shape (from Sound Decoder): {generated_spectrogram.shape}")
        return generated_spectrogram
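# Quick shape check (a minimal sketch, kept commented out so it does not run
# inside the app; the expected sizes follow from the layer arithmetic above):
#   g = Generator(output_time_frames).to(device)
#   dummy = torch.randn(1, 3, 256, 256, device=device)
#   spec = g(dummy)   # expected shape: torch.Size([1, 2, 2049, 108])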
# Function to generate and save audio from a test image using the pre-trained GAN model
def test_model(generator, test_img_path, output_audio_path, device):
    # Load and preprocess test image
    test_img = Image.open(test_img_path).convert('RGB')
    test_img = image_transform(test_img).unsqueeze(0).to(device)  # Add batch dimension

    # Generate sound spectrogram from the image
    with torch.no_grad():  # Disable gradient calculation for inference
        generated_spectrogram = generator(test_img)

    # Debugging: Check generated spectrogram shape
    print(f"Generated spectrogram shape: {generated_spectrogram.shape}")

    # Convert the generated spectrogram to audio and write it to disk
    generated_audio = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())  # Remove batch dimension
    torchaudio.save(output_audio_path, generated_audio, sample_rate)

# Load the pre-trained GAN model
def load_gan_model(generator, model_path, device):
    generator.load_state_dict(torch.load(model_path, map_location=device))
    generator.eval()  # Set the model to evaluation mode
    return generator
# Convert a magnitude-only spectrogram to a complex spectrogram by assuming zero phase
def magnitude_to_complex_spectrogram(magnitude_spectrogram):
    # torch.istft expects a complex tensor, so build one with an all-zero imaginary part
    zero_phase = torch.zeros_like(magnitude_spectrogram)
    return torch.complex(magnitude_spectrogram, zero_phase)

# Convert spectrogram back to audio using inverse STFT
def spectrogram_to_audio(magnitude_spectrogram):
    magnitude_spectrogram = torch.expm1(magnitude_spectrogram)  # undo the log1p compression assumed at training time
    complex_spectrogram = magnitude_to_complex_spectrogram(magnitude_spectrogram)
    audio = torch.istft(complex_spectrogram, n_fft=n_fft, hop_length=hop_length)
    return audio
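# Note: an all-zero phase is a crude reconstruction and can introduce audible
# artifacts. If the output is too noisy, an iterative phase-recovery transform
# such as torchaudio.transforms.GriffinLim could be swapped in here instead
# (sketch only, not what this app currently does):
#   griffin_lim = T.GriffinLim(n_fft=n_fft, hop_length=hop_length)
#   audio = griffin_lim(torch.expm1(magnitude_spectrogram))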
# Function to generate audio from an uploaded image
def generate_audio_from_image(image):
    test_img = image_transform(image.convert('RGB')).unsqueeze(0).to(device)  # Preprocess image

    # Generate sound spectrogram from the image using the loaded generator
    with torch.no_grad():
        generated_spectrogram = generator(test_img)

    # Convert the generated spectrogram to audio
    generated_audio = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())

    # Gradio expects (sample_rate, ndarray) with the array shaped (samples, channels)
    return (sample_rate, generated_audio.numpy().T)
# Gradio Interface
def main():
    global generator  # Declare the generator object globally

    # Instantiate your Generator model
    generator = Generator(output_time_frames).to(device)

    # Load the pre-trained model
    model_path = '/path/to/your/model/gan_model_100e_16b.pth'  # Change this path
    generator = load_gan_model(generator, model_path, device)

    # Gradio interface: allow users to upload an image and generate audio
    iface = gr.Interface(
        fn=generate_audio_from_image,
        inputs=gr.Image(type="pil"),
        outputs=gr.Audio(type="numpy", label="Generated Audio"),
    )
    iface.launch()


if __name__ == "__main__":
    main()