Vedansh-7 committed on
Commit
bb3aba9
·
1 Parent(s): ff9cac6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -52
app.py CHANGED
@@ -125,70 +125,64 @@ class DiffusionModel(nn.Module):
125
  self.model = model
126
  self.timesteps = timesteps
127
 
128
- # Improved noise schedule
129
  beta_start = 0.0001
130
  beta_end = 0.02
131
  self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
132
 
133
  self.alphas = 1. - self.betas
134
- self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
135
-
136
- # Pre-calculate values for sampling
137
- self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
138
- self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod))
139
  self.register_buffer('sqrt_recip_alphas', torch.sqrt(1. / self.alphas))
140
-
141
- # Calculations for posterior q(x_{t-1} | x_t, x_0)
142
- posterior_variance = self.betas * (1. - self.alphas_cumprod[:-1]) / (1. - self.alphas_cumprod[1:])
143
- self.register_buffer('posterior_variance', posterior_variance)
144
-
145
- # Blur kernel for post-processing
146
- self.register_buffer('blur_kernel', torch.tensor([
147
- [0.05, 0.1, 0.05],
148
- [0.1, 0.4, 0.1],
149
- [0.05, 0.1, 0.05]
150
- ], dtype=torch.float32).view(1, 1, 3, 3).repeat(3, 1, 1, 1))
151
 
152
  @torch.no_grad()
153
  def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
154
- """Generate samples from the model"""
155
- shape = (num_images, 3, img_size, img_size)
156
- x_t = torch.randn(shape, device=device) * 0.7 # Slightly reduced initial noise
157
 
 
158
  if labels.ndim == 1:
159
  labels = torch.zeros(num_images, num_classes, device=device).scatter_(1, labels.unsqueeze(1), 1)
160
-
 
 
161
  for t in reversed(range(self.timesteps)):
162
  if cancel_event.is_set():
163
  return None
164
 
165
- t_batch = torch.full((num_images,), t, device=device, dtype=torch.long)
166
- pred_noise = self.model(x_t, labels, t_batch.float())
167
 
168
- alpha_bar_t = self.alphas_cumprod[t]
169
- alpha_bar_t_prev = self.alphas_cumprod[t-1] if t > 0 else torch.tensor(1.0)
170
 
171
- # Calculate predicted x0
172
- pred_x0 = (x_t - torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_bar_t)
 
 
173
 
174
- # Calculate direction pointing to x_t
175
- pred_dir = torch.sqrt(1 - alpha_bar_t_prev) * pred_noise
 
176
 
177
- # Dynamic noise scaling
178
  if t > 0:
179
- noise_scale = 0.3 * (t / self.timesteps)
180
- noise = torch.randn_like(x_t) * noise_scale
181
  else:
182
  noise = torch.zeros_like(x_t)
183
 
184
- # Update x_t
185
- x_t = torch.sqrt(alpha_bar_t_prev) * pred_x0 + pred_dir + noise
186
 
187
  if progress_callback:
188
  progress_callback((self.timesteps - t) / self.timesteps)
189
 
190
- # Post-processing
191
- x_t = self._post_process(x_t)
 
 
 
 
 
 
 
192
  return x_t
193
 
194
  def _post_process(self, images):
@@ -275,7 +269,8 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
275
  if label_str not in label_map:
276
  raise gr.Error("Invalid condition selected")
277
 
278
- labels = torch.zeros(num_images, NUM_CLASSES, device=device)
 
279
  labels[:, label_map[label_str]] = 1
280
 
281
  try:
@@ -286,29 +281,24 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
286
 
287
  with torch.no_grad():
288
  images = loaded_model.sample(
289
- num_images=num_images,
290
- img_size=IMG_SIZE,
291
- num_classes=NUM_CLASSES,
292
- labels=labels,
293
- device=device,
294
- progress_callback=progress_callback
295
  )
296
-
297
  if images is None:
298
  return None, None
299
-
300
  processed_images = []
301
  for img in images:
302
- img_np = img.cpu().numpy()
303
-
304
- # Convert to PIL with enhanced contrast
305
  img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
306
  pil_img = Image.fromarray(img_np)
307
-
308
- # Apply additional PIL-based enhancements
309
- pil_img = pil_img.filter(ImageFilter.SMOOTH_MORE)
310
  processed_images.append(pil_img)
311
-
312
 
313
  if num_images == 1:
314
  return processed_images[0], processed_images
 
125
  self.model = model
126
  self.timesteps = timesteps
127
 
128
+ # Noise schedule from working code
129
  beta_start = 0.0001
130
  beta_end = 0.02
131
  self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
132
 
133
  self.alphas = 1. - self.betas
134
+ self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))
135
+ self.register_buffer('sqrt_one_minus_alpha_bars', torch.sqrt(1. - self.alpha_bars))
 
 
 
136
  self.register_buffer('sqrt_recip_alphas', torch.sqrt(1. / self.alphas))
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  @torch.no_grad()
139
  def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
140
+ """Improved sampling method based on working code"""
141
+ x_t = torch.randn((num_images, 3, img_size, img_size), device=device)
 
142
 
143
+ # Handle labels (class indices or one-hot)
144
  if labels.ndim == 1:
145
  labels = torch.zeros(num_images, num_classes, device=device).scatter_(1, labels.unsqueeze(1), 1)
146
+ else:
147
+ labels = labels.float().to(device)
148
+
149
  for t in reversed(range(self.timesteps)):
150
  if cancel_event.is_set():
151
  return None
152
 
153
+ t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
 
154
 
155
+ # Predict noise with model
156
+ pred_noise = self.model(x_t, labels, t_tensor)
157
 
158
+ # Calculate coefficients from working code
159
+ beta_t = self.betas[t].to(device)
160
+ alpha_t = self.alphas[t].to(device)
161
+ alpha_bar_t = self.alpha_bars[t].to(device)
162
 
163
+ # Improved reverse diffusion step
164
+ mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * pred_noise)
165
+ variance = beta_t
166
 
 
167
  if t > 0:
168
+ noise = torch.randn_like(x_t)
 
169
  else:
170
  noise = torch.zeros_like(x_t)
171
 
172
+ x_t = mean + torch.sqrt(variance) * noise
 
173
 
174
  if progress_callback:
175
  progress_callback((self.timesteps - t) / self.timesteps)
176
 
177
+ # Improved normalization from working code
178
+ x_t = torch.clamp(x_t, -1., 1.)
179
+
180
+ # Denormalize using ImageNet stats (from working code)
181
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
182
+ std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
183
+ x_t = std * x_t + mean
184
+ x_t = torch.clamp(x_t, 0., 1.)
185
+
186
  return x_t
187
 
188
  def _post_process(self, images):
 
269
  if label_str not in label_map:
270
  raise gr.Error("Invalid condition selected")
271
 
272
+ # Create one-hot encoded labels
273
+ labels = torch.zeros(num_images, NUM_CLASSES)
274
  labels[:, label_map[label_str]] = 1
275
 
276
  try:
 
281
 
282
  with torch.no_grad():
283
  images = loaded_model.sample(
284
+ num_images=num_images,
285
+ img_size=IMG_SIZE,
286
+ num_classes=NUM_CLASSES,
287
+ labels=labels,
288
+ device=device,
289
+ progress_callback=progress_callback
290
  )
291
+
292
  if images is None:
293
  return None, None
294
+
295
  processed_images = []
296
  for img in images:
297
+ # Convert to numpy and permute dimensions (C,H,W) -> (H,W,C)
298
+ img_np = img.cpu().permute(1, 2, 0).numpy()
 
299
  img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
300
  pil_img = Image.fromarray(img_np)
 
 
 
301
  processed_images.append(pil_img)
 
302
 
303
  if num_images == 1:
304
  return processed_images[0], processed_images