Vedansh-7 committed on
Commit
dd9af11
·
1 Parent(s): 9378bae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -67
app.py CHANGED
@@ -127,74 +127,98 @@ class DiffusionModel(nn.Module):
127
 
128
  # More conservative noise schedule
129
  scale = 1000 / timesteps
130
- beta_start = scale * 0.0001
131
- beta_end = scale * 0.02
132
- self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
133
 
134
  self.alphas = 1. - self.betas
135
  self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
136
  self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
137
  self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod))
138
 
139
- @torch.no_grad()
140
- def sample(self, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
141
- # Initialize with standard normal distribution (scale=1.0)
142
- x_t = torch.randn((num_images, 3, img_size, img_size), device=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
- if labels.ndim == 1:
145
- labels_one_hot = torch.zeros(num_images, num_classes, device=device)
146
- labels_one_hot[torch.arange(num_images), labels] = 1
147
- labels = labels_one_hot
 
 
 
 
 
 
 
 
 
 
 
148
  else:
149
- labels = labels.float().to(device)
150
-
151
- for t in reversed(range(timesteps)):
152
- if cancel_event.is_set():
153
- return None
154
-
155
- t_tensor = torch.full((num_images,), t, device=device, dtype=torch.long)
156
-
157
- # Predict noise with model
158
- pred_noise = self.model(x_t, labels, t_tensor.float())
159
-
160
- # Get current alpha values
161
- alpha_t = self.alphas[t]
162
- alpha_bar_t = self.alphas_cumprod[t]
163
- alpha_bar_t_prev = self.alphas_cumprod[t-1] if t > 0 else torch.tensor(1.0)
164
-
165
- # Calculate coefficients
166
- beta_t = self.betas[t]
167
- sqrt_recip_alpha_t = torch.sqrt(1.0 / alpha_t)
168
- sqrt_one_minus_alpha_bar_t = torch.sqrt(1.0 - alpha_bar_t)
169
-
170
- # Calculate predicted x0
171
- pred_x0 = (x_t - sqrt_one_minus_alpha_bar_t * pred_noise) * sqrt_recip_alpha_t
172
-
173
- # Calculate direction pointing to x_t
174
- pred_dir = torch.sqrt(1.0 - alpha_bar_t_prev) * pred_noise
175
-
176
- # Noise for next step
177
- if t > 0:
178
- noise = torch.randn_like(x_t) * 0.5
179
- else:
180
- noise = torch.zeros_like(x_t)
181
-
182
- # Update x_t with stability checks
183
- x_t = torch.sqrt(alpha_bar_t_prev) * pred_x0 + pred_dir + noise * torch.sqrt(beta_t)
184
 
185
- # Numerical stability check
186
- if torch.isnan(x_t).any() or torch.isinf(x_t).any():
187
- x_t = torch.randn_like(x_t) * 0.1
188
-
189
- if progress_callback:
190
- progress_callback((timesteps - t) / timesteps)
191
-
192
- # Gentle normalization
193
- x_t = (x_t - x_t.min()) / (x_t.max() - x_t.min() + 1e-8) # [0, 1]
194
- x_t = torch.clamp(x_t, 0, 1) # Final safety clamp
195
 
196
- return x_t
 
 
197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  def load_model(model_path, device):
199
  unet = UNet(num_classes=NUM_CLASSES).to(device)
200
  diffusion_model = DiffusionModel(unet).to(device)
@@ -270,24 +294,30 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
270
 
271
  with torch.no_grad():
272
  images = loaded_model.sample(
273
- num_images=num_images,
274
- timesteps=TIMESTEPS,
275
- img_size=IMG_SIZE,
276
- num_classes=NUM_CLASSES,
277
- labels=labels,
278
- device=device,
279
- progress_callback=progress_callback
280
  )
281
-
282
  if images is None:
283
  return None, None
284
-
285
  processed_images = []
286
  for img in images:
287
- img_np = img.cpu().numpy().transpose(1, 2, 0)
 
 
288
  img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
289
  pil_img = Image.fromarray(img_np)
 
 
 
290
  processed_images.append(pil_img)
 
291
 
292
  if num_images == 1:
293
  return processed_images[0], processed_images
 
127
 
128
  # More conservative noise schedule
129
  scale = 1000 / timesteps
130
+ beta_start = 0.0001
131
+ beta_end = 0.02
132
+ self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)**1.5
133
 
134
  self.alphas = 1. - self.betas
135
  self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
136
  self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
137
  self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod))
138
 
139
@torch.no_grad()
def sample(self, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
    """Generate a batch of images by running the reverse diffusion loop.

    Starts from Gaussian noise (scaled by 0.7) and walks from the last
    timestep down to 0. At every step the wrapped network predicts the
    noise, a clean-image estimate is reconstructed from it, and the sample
    is moved to the previous timestep with a small, decaying amount of
    fresh noise.

    Args:
        num_images: Batch size, i.e. how many images to generate.
        timesteps: Requested number of reverse steps. Capped at the length
            of ``self.alphas_cumprod`` because the schedule tensors cannot
            be indexed beyond the steps they were built for.
        img_size: Height and width of the square images.
        num_classes: Width of the one-hot label vectors.
        labels: Either a 1-D tensor of integer class indices (converted to
            one-hot here) or a 2-D tensor of already-encoded label vectors.
        device: Device on which sampling runs.
        progress_callback: Optional callable that receives the completed
            fraction after every step.

    Returns:
        ``None`` if ``cancel_event`` is set while sampling; otherwise the
        result of ``self._post_process`` applied to the final sample after
        it has been clamped to [-1, 1] and rescaled to [0, 1].
    """
    # A request longer than the schedule would index past the end of
    # ``self.alphas_cumprod`` and raise IndexError, so cap it.
    steps = min(timesteps, len(self.alphas_cumprod))

    # Initialize with reduced noise scale
    x_t = torch.randn((num_images, 3, img_size, img_size), device=device) * 0.7

    # Convert labels if needed
    if labels.ndim == 1:
        # Keep every index tensor on the target device to avoid mixed-device indexing.
        class_idx = labels.to(device=device, dtype=torch.long)
        labels_one_hot = torch.zeros(num_images, num_classes, device=device)
        labels_one_hot[torch.arange(num_images, device=device), class_idx] = 1
        labels = labels_one_hot
    else:
        # Already encoded: only make sure dtype and device match the one-hot path.
        labels = labels.float().to(device)

    # Cumulative alpha "before step 0" is 1 (no noise); built once, reused each step.
    one = torch.tensor(1.0)

    for t in reversed(range(steps)):
        if cancel_event.is_set():
            return None

        t_tensor = torch.full((num_images,), t, device=device, dtype=torch.long)

        # Predict noise with model
        pred_noise = self.model(x_t, labels, t_tensor.float())

        # Cumulative alphas for the current and the previous timestep
        alpha_bar_t = self.alphas_cumprod[t]
        alpha_bar_t_prev = self.alphas_cumprod[t - 1] if t > 0 else one

        # Predicted clean image: invert x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps
        pred_x0 = (x_t - torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_bar_t)

        # Direction pointing back towards x_t at the previous noise level
        pred_dir = torch.sqrt(1 - alpha_bar_t_prev) * pred_noise

        # Dynamic noise scaling: less fresh noise as we approach the final image
        if t > 0:
            noise_scale = 0.3 * (t / steps)
            noise = torch.randn_like(x_t) * noise_scale
        else:
            noise = torch.zeros_like(x_t)

        # Move the sample to the previous timestep
        x_t = torch.sqrt(alpha_bar_t_prev) * pred_x0 + pred_dir + noise

        # Progress callback
        if progress_callback:
            progress_callback((steps - t) / steps)

    # Map the sample from [-1, 1] to [0, 1]
    x_t = torch.clamp(x_t, -1, 1)
    x_t = (x_t + 1) / 2

    # Post-processing directly on the tensor
    x_t = self._post_process(x_t)

    return x_t
192
+
193
def _post_process(self, image_tensor):
    """Slightly boost contrast and apply a mild 3x3 blur to RGB images.

    Args:
        image_tensor: 4-D batch of three-channel images, either
            channels-first ``(N, 3, H, W)`` (the layout ``sample`` builds)
            or channels-last ``(N, H, W, 3)``.

    Returns:
        A channels-last tensor of shape ``(N, H, W, 3)`` with values
        clamped to [0, 1].
    """
    # Contrast adjustment: stretch values around the batch mean by 20%.
    mean_val = image_tensor.mean()
    image_tensor = (image_tensor - mean_val) * 1.2 + mean_val

    # conv2d needs channels-first input. Only permute when the batch is not
    # already channels-first; permuting an (N, 3, H, W) batch would put the
    # width on the channel axis and break the grouped convolution.
    if image_tensor.shape[1] == 3 and image_tensor.shape[-1] != 3:
        channels_first = image_tensor
    else:
        channels_first = image_tensor.permute(0, 3, 1, 2)  # NHWC to NCHW

    # Mild Gaussian-like blur (weights sum to 1), one kernel copy per channel.
    # The kernel is built once and cached on the instance.
    blur_kernel = getattr(self, '_blur_kernel', None)
    if blur_kernel is None:
        blur_kernel = torch.tensor([
            [0.05, 0.1, 0.05],
            [0.1, 0.4, 0.1],
            [0.05, 0.1, 0.05],
        ], dtype=torch.float32).view(1, 1, 3, 3).repeat(3, 1, 1, 1)
        self._blur_kernel = blur_kernel
    # Match the image on every call, including the first one after creation.
    blur_kernel = blur_kernel.to(device=channels_first.device, dtype=channels_first.dtype)

    # Depthwise convolution: groups=3 blurs each colour channel independently;
    # padding=1 keeps the spatial size unchanged.
    blurred = torch.nn.functional.conv2d(
        channels_first,
        blur_kernel,
        padding=1,
        groups=3,
    ).permute(0, 2, 3, 1)  # Back to NHWC

    return torch.clamp(blurred, 0, 1)
220
+
221
+
222
  def load_model(model_path, device):
223
  unet = UNet(num_classes=NUM_CLASSES).to(device)
224
  diffusion_model = DiffusionModel(unet).to(device)
 
294
 
295
  with torch.no_grad():
296
  images = loaded_model.sample(
297
+ num_images=num_images,
298
+ timesteps=int(TIMESTEPS * 1.5), # More timesteps for cleaner images
299
+ img_size=IMG_SIZE,
300
+ num_classes=NUM_CLASSES,
301
+ labels=labels,
302
+ device=device,
303
+ progress_callback=progress_callback
304
  )
305
+
306
  if images is None:
307
  return None, None
308
+
309
  processed_images = []
310
  for img in images:
311
+ img_np = img.cpu().numpy()
312
+
313
+ # Convert to PIL with enhanced contrast
314
  img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
315
  pil_img = Image.fromarray(img_np)
316
+
317
+ # Apply additional PIL-based enhancements
318
+ pil_img = pil_img.filter(ImageFilter.SMOOTH_MORE)
319
  processed_images.append(pil_img)
320
+
321
 
322
  if num_images == 1:
323
  return processed_images[0], processed_images