musdfakoc committed on
Commit d076b8a · 1 Parent(s): dfc82e0

Add model files

Files changed (3)
  1. app.py +190 -0
  2. gan_model.pth +3 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,190 @@
+ import torch
+ import torchaudio
+ import gradio as gr
+ from PIL import Image
+ import torchvision.transforms as transforms
+ import torchaudio.transforms as T
+ from torch import nn
+
+ # Set device to 'cuda' if available, otherwise fall back to 'cpu'
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ # Parameters
+ sample_rate = 44100  # 44.1 kHz stereo sound
+ n_fft = 4096  # FFT size
+ hop_length = 2048  # Hop length for STFT
+ duration = 5  # Duration of the sound files (5 seconds)
+ n_channels = 2  # Stereo sound
+ output_time_frames = duration * sample_rate // hop_length  # Number of time frames in the spectrogram
+
+ stft_transform = T.Spectrogram(n_fft=n_fft, hop_length=hop_length, win_length=n_fft)
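+ # Shape check (derivable from the parameters above): a centered STFT of a
+ # 5 s clip at 44.1 kHz gives 220500 // 2048 + 1 = 108 time frames and
+ # n_fft // 2 + 1 = 2049 frequency bins; output_time_frames above floors to
+ # 107 because it omits the +1 centered frame, which is why the decoder
+ # below targets (2049, 108) explicitly.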
+
+ image_transform = transforms.Compose([
+     transforms.Resize((256, 256)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
+ ])
+
+ # Image Encoder (for the Generator)
+ class ImageEncoder(nn.Module):
+     def __init__(self):
+         super(ImageEncoder, self).__init__()
+         self.encoder = nn.Sequential(
+             nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
+             nn.BatchNorm2d(64),
+             nn.ReLU(),
+             nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
+             nn.BatchNorm2d(256),
+             nn.ReLU(),
+             nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
+             nn.BatchNorm2d(512),
+             nn.ReLU()
+         )
+         self.fc = nn.Linear(512 * 16 * 16, 512)
+
+     def forward(self, x):
+         x = self.encoder(x)
+         x = x.view(x.size(0), -1)
+         return self.fc(x)
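+ # Note: four stride-2 convolutions halve the 256x256 input four times,
+ # down to 16x16, so the flattened feature size is 512 * 16 * 16, matching
+ # the nn.Linear layer above.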
+
+
+ # Sound Decoder (for the Generator)
+ class SoundDecoder(nn.Module):
+     def __init__(self, output_time_frames):
+         super(SoundDecoder, self).__init__()
+         self.fc = nn.Linear(512, 512 * 8 * 8)
+
+         self.decoder = nn.Sequential(
+             nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
+             nn.BatchNorm2d(256),
+             nn.ReLU(),
+             nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
+             nn.BatchNorm2d(64),
+             nn.ReLU(),
+             nn.ConvTranspose2d(64, n_channels, kernel_size=4, stride=2, padding=1),
+         )
+
+         # Upsample to exactly match the real spectrogram size (108 time frames)
+         self.upsample = nn.Upsample(size=(n_fft // 2 + 1, 108), mode='bilinear', align_corners=True)
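+ # Note: the four transposed convolutions take the 8x8 map up to 128x128;
+ # the bilinear upsample then stretches that to the (2049, 108) spectrogram
+ # shape derived from the STFT parameters.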
+
+     def forward(self, x):
+         x = self.fc(x)
+         x = x.view(x.size(0), 512, 8, 8)
+         x = self.decoder(x)
+         x = self.upsample(x)
+         # Debugging shape
+         print(f'Generated spectrogram shape: {x.shape}')
+         return x
+
+
+ # Generator model
+ class Generator(nn.Module):
+     def __init__(self, output_time_frames):
+         super(Generator, self).__init__()
+         self.encoder = ImageEncoder()
+         self.decoder = SoundDecoder(output_time_frames)
+
+     def forward(self, img):
+         # Debugging: Image encoder
+         encoded_features = self.encoder(img)
+         print(f"Encoded features shape (from Image Encoder): {encoded_features.shape}")
+
+         # Debugging: Sound decoder
+         generated_spectrogram = self.decoder(encoded_features)
+         print(f"Generated spectrogram shape (from Sound Decoder): {generated_spectrogram.shape}")
+
+         return generated_spectrogram
+
+
+ # Function to save audio
+ def save_audio(audio, path, sample_rate=44100):
+     # Ensure audio is in stereo by checking the channels
+     if audio.dim() == 1:
+         audio = audio.unsqueeze(0).repeat(2, 1)  # Convert mono to stereo
+     elif audio.size(0) == 1:
+         audio = audio.repeat(2, 1)  # Convert mono to stereo
+
+     # Save audio to a file
+     torchaudio.save(path, audio, sample_rate)
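+ # Note: torchaudio.save expects a float tensor shaped (channels, frames),
+ # which is what the stereo handling above produces.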
+
+
+ # Function to generate and save audio from a test image using the pre-trained GAN model
+ def test_model(generator, test_img_path, output_audio_path, device):
+     # Load and preprocess test image
+     test_img = Image.open(test_img_path).convert('RGB')
+     test_img = image_transform(test_img).unsqueeze(0).to(device)  # Add batch dimension
+
+     # Generate sound spectrogram from the image
+     with torch.no_grad():  # Disable gradient calculation for inference
+         generated_spectrogram = generator(test_img)
+
+     # Debugging: Check generated spectrogram shape
+     print(f"Generated spectrogram shape: {generated_spectrogram.shape}")
+
+     # Convert the generated spectrogram to audio
+     generated_audio = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())  # Remove batch dimension
+
+     # Save the generated audio
+     save_audio(generated_audio, output_audio_path)
+
+     print(f"Generated audio saved to {output_audio_path}")
+
+
+ # Load the pre-trained GAN model
+ def load_gan_model(generator, model_path, device):
+     generator.load_state_dict(torch.load(model_path, map_location=device))
+     generator.eval()  # Set the model to evaluation mode
+     return generator
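+ # Note: on recent PyTorch versions (>= 1.13) you may prefer
+ # torch.load(model_path, map_location=device, weights_only=True)
+ # to avoid unpickling arbitrary objects from the checkpoint.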
+
+ # Convert a magnitude-only spectrogram to a complex spectrogram by assuming zero phase
+ def magnitude_to_complex_spectrogram(magnitude_spectrogram):
+     zero_phase = torch.zeros_like(magnitude_spectrogram)
+     # torch.istft requires a complex tensor, so build one with zero imaginary part
+     complex_spectrogram = torch.complex(magnitude_spectrogram, zero_phase)
+     return complex_spectrogram
+
+
+ # Convert spectrogram back to audio using inverse STFT
+ def spectrogram_to_audio(magnitude_spectrogram):
+     magnitude_spectrogram = torch.expm1(magnitude_spectrogram)  # expm1 inverts a log1p-compressed magnitude
+     complex_spectrogram = magnitude_to_complex_spectrogram(magnitude_spectrogram)
+     audio = torch.istft(complex_spectrogram, n_fft=n_fft, hop_length=hop_length, win_length=n_fft)
+     return audio
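+ # Note: zero-phase reconstruction is crude and tends to sound metallic;
+ # iterative phase recovery such as
+ # T.GriffinLim(n_fft=n_fft, hop_length=hop_length, power=1.0)
+ # applied to the magnitudes is a common alternative worth trying here.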
+
+
+ # Function to generate audio from an uploaded image
+ def generate_audio_from_image(image):
+     test_img = image_transform(image).unsqueeze(0).to(device)  # Preprocess image
+
+     # Generate sound spectrogram from the image using the loaded generator
+     with torch.no_grad():
+         generated_spectrogram = generator(test_img)
+
+     # Convert the generated spectrogram to audio
+     generated_audio = spectrogram_to_audio(generated_spectrogram.squeeze(0).cpu())
+
+     # Gradio's numpy audio format is (sample_rate, data) with data shaped
+     # (samples, channels), so transpose the (channels, samples) tensor
+     return sample_rate, generated_audio.numpy().T
+
+
+ # Gradio Interface
+ def main():
+     global generator  # Make the generator visible to generate_audio_from_image
+     # Instantiate the Generator model
+     generator = Generator(output_time_frames).to(device)
+
+     # Load the pre-trained model from the repository root (added in this commit)
+     model_path = 'gan_model.pth'
+     generator = load_gan_model(generator, model_path, device)
+
+     # Gradio interface: allow users to upload an image and generate audio
+     iface = gr.Interface(
+         fn=generate_audio_from_image,
+         inputs=gr.Image(type="pil"),
+         outputs=gr.Audio(type="numpy", label="Generated Audio")
+     )
+     iface.launch()
+
+
+ if __name__ == "__main__":
+     main()
gan_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f909a44210255efb3f4d85e91f28bdbcab9c9d098eb8c8bca61d6df41fa296d7
+ size 357763072
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ torchaudio
+ gradio
+ Pillow
+ torchvision