Vedansh-7 commited on
Commit
ff9cac6
·
1 Parent(s): dd9af11

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -79
app.py CHANGED
@@ -125,99 +125,91 @@ class DiffusionModel(nn.Module):
125
  self.model = model
126
  self.timesteps = timesteps
127
 
128
- # More conservative noise schedule
129
- scale = 1000 / timesteps
130
  beta_start = 0.0001
131
  beta_end = 0.02
132
- self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)**1.5
133
 
134
  self.alphas = 1. - self.betas
135
  self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
 
 
136
  self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
137
  self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod))
138
-
139
- @torch.no_grad()
140
- def sample(self, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
141
- # Initialize with reduced noise scale
142
- x_t = torch.randn((num_images, 3, img_size, img_size), device=device) * 0.7
143
-
144
- # Convert labels if needed
145
- if labels.ndim == 1:
146
- labels_one_hot = torch.zeros(num_images, num_classes, device=device)
147
- labels_one_hot[torch.arange(num_images), labels] = 1
148
- labels = labels_one_hot
149
-
150
- for t in reversed(range(timesteps)):
151
- if cancel_event.is_set():
152
- return None
153
-
154
- t_tensor = torch.full((num_images,), t, device=device, dtype=torch.long)
155
 
156
- # Predict noise with model
157
- pred_noise = self.model(x_t, labels, t_tensor.float())
 
158
 
159
- # Get current alpha values
160
- alpha_t = self.alphas[t]
161
- alpha_bar_t = self.alphas_cumprod[t]
162
- alpha_bar_t_prev = self.alphas_cumprod[t-1] if t > 0 else torch.tensor(1.0)
 
 
 
 
 
 
 
 
163
 
164
- # Calculate predicted x0 with more stable equations
165
- pred_x0 = (x_t - torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_bar_t)
166
 
167
- # Direction pointing to x_t with reduced noise impact
168
- pred_dir = torch.sqrt(1 - alpha_bar_t_prev) * pred_noise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
- # Dynamic noise scaling based on timestep
171
- if t > 0:
172
- noise_scale = 0.3 * (t / timesteps) # Reduce noise as we get closer to final image
173
- noise = torch.randn_like(x_t) * noise_scale
174
- else:
175
- noise = torch.zeros_like(x_t)
176
 
177
- # Update x_t with more stable combination
178
- x_t = torch.sqrt(alpha_bar_t_prev) * pred_x0 + pred_dir + noise
 
179
 
180
- # Progress callback
181
- if progress_callback:
182
- progress_callback((timesteps - t) / timesteps)
183
-
184
- # Enhanced normalization with contrast adjustment
185
- x_t = torch.clamp(x_t, -1, 1)
186
- x_t = (x_t + 1) / 2 # Scale to [0,1]
187
-
188
- # Post-processing directly in the tensor
189
- x_t = self._post_process(x_t)
190
-
191
- return x_t
192
-
193
- def _post_process(self, image_tensor):
194
- """Apply simple post-processing to reduce noise"""
195
- # Contrast adjustment
196
- mean_val = image_tensor.mean()
197
- image_tensor = (image_tensor - mean_val) * 1.2 + mean_val
198
-
199
- # Mild Gaussian blur (implemented as depthwise convolution)
200
- if hasattr(self, '_blur_kernel'):
201
- blur_kernel = self._blur_kernel.to(image_tensor.device)
202
- else:
203
- blur_kernel = torch.tensor([
204
- [0.05, 0.1, 0.05],
205
- [0.1, 0.4, 0.1],
206
- [0.05, 0.1, 0.05]
207
- ], dtype=torch.float32).view(1, 1, 3, 3).repeat(3, 1, 1, 1)
208
- self._blur_kernel = blur_kernel
209
 
210
- # Apply blur to each channel
211
- padding = (1, 1, 1, 1)
212
- image_tensor = torch.nn.functional.conv2d(
213
- image_tensor.permute(0, 3, 1, 2), # NHWC to NCHW
214
- blur_kernel,
215
- padding=1,
216
- groups=3
217
- ).permute(0, 2, 3, 1) # Back to NHWC
218
-
219
- return torch.clamp(image_tensor, 0, 1)
220
-
221
 
222
  def load_model(model_path, device):
223
  unet = UNet(num_classes=NUM_CLASSES).to(device)
@@ -295,7 +287,6 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
295
  with torch.no_grad():
296
  images = loaded_model.sample(
297
  num_images=num_images,
298
- timesteps=int(TIMESTEPS * 1.5), # More timesteps for cleaner images
299
  img_size=IMG_SIZE,
300
  num_classes=NUM_CLASSES,
301
  labels=labels,
 
125
  self.model = model
126
  self.timesteps = timesteps
127
 
128
+ # Improved noise schedule
 
129
  beta_start = 0.0001
130
  beta_end = 0.02
131
+ self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
132
 
133
  self.alphas = 1. - self.betas
134
  self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
135
+
136
+ # Pre-calculate values for sampling
137
  self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
138
  self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod))
139
+ self.register_buffer('sqrt_recip_alphas', torch.sqrt(1. / self.alphas))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
+ # Calculations for posterior q(x_{t-1} | x_t, x_0)
142
+ posterior_variance = self.betas * (1. - self.alphas_cumprod[:-1]) / (1. - self.alphas_cumprod[1:])
143
+ self.register_buffer('posterior_variance', posterior_variance)
144
 
145
+ # Blur kernel for post-processing
146
+ self.register_buffer('blur_kernel', torch.tensor([
147
+ [0.05, 0.1, 0.05],
148
+ [0.1, 0.4, 0.1],
149
+ [0.05, 0.1, 0.05]
150
+ ], dtype=torch.float32).view(1, 1, 3, 3).repeat(3, 1, 1, 1))
151
+
152
+ @torch.no_grad()
153
+ def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
154
+ """Generate samples from the model"""
155
+ shape = (num_images, 3, img_size, img_size)
156
+ x_t = torch.randn(shape, device=device) * 0.7 # Slightly reduced initial noise
157
 
158
+ if labels.ndim == 1:
159
+ labels = torch.zeros(num_images, num_classes, device=device).scatter_(1, labels.unsqueeze(1), 1)
160
 
161
+ for t in reversed(range(self.timesteps)):
162
+ if cancel_event.is_set():
163
+ return None
164
+
165
+ t_batch = torch.full((num_images,), t, device=device, dtype=torch.long)
166
+ pred_noise = self.model(x_t, labels, t_batch.float())
167
+
168
+ alpha_bar_t = self.alphas_cumprod[t]
169
+ alpha_bar_t_prev = self.alphas_cumprod[t-1] if t > 0 else torch.tensor(1.0)
170
+
171
+ # Calculate predicted x0
172
+ pred_x0 = (x_t - torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_bar_t)
173
+
174
+ # Calculate direction pointing to x_t
175
+ pred_dir = torch.sqrt(1 - alpha_bar_t_prev) * pred_noise
176
+
177
+ # Dynamic noise scaling
178
+ if t > 0:
179
+ noise_scale = 0.3 * (t / self.timesteps)
180
+ noise = torch.randn_like(x_t) * noise_scale
181
+ else:
182
+ noise = torch.zeros_like(x_t)
183
+
184
+ # Update x_t
185
+ x_t = torch.sqrt(alpha_bar_t_prev) * pred_x0 + pred_dir + noise
186
+
187
+ if progress_callback:
188
+ progress_callback((self.timesteps - t) / self.timesteps)
189
+
190
+ # Post-processing
191
+ x_t = self._post_process(x_t)
192
+ return x_t
193
+
194
+ def _post_process(self, images):
195
+ """Apply post-processing to reduce noise and enhance contrast"""
196
+ # Normalize to [0,1]
197
+ images = torch.clamp(images, -1, 1)
198
+ images = (images + 1) / 2
199
 
200
+ # Apply mild blur (convert NHWC to NCHW for conv2d)
201
+ if images.dim() == 4 and images.shape[-1] != 3: # NCHW format
202
+ images = images.permute(0, 2, 3, 1)
 
 
 
203
 
204
+ x = images.permute(0, 3, 1, 2) # NHWC to NCHW
205
+ x = torch.nn.functional.conv2d(x, self.blur_kernel, padding=1, groups=3)
206
+ images = x.permute(0, 2, 3, 1) # NCHW to NHWC
207
 
208
+ # Contrast adjustment
209
+ mean_val = images.mean(dim=(1,2,3), keepdim=True)
210
+ images = (images - mean_val) * 1.2 + mean_val
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
+ return torch.clamp(images, 0, 1)
 
 
 
 
 
 
 
 
 
 
213
 
214
  def load_model(model_path, device):
215
  unet = UNet(num_classes=NUM_CLASSES).to(device)
 
287
  with torch.no_grad():
288
  images = loaded_model.sample(
289
  num_images=num_images,
 
290
  img_size=IMG_SIZE,
291
  num_classes=NUM_CLASSES,
292
  labels=labels,