Vedansh-7 commited on
Commit
542a20e
·
verified ·
1 Parent(s): 27df47f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -24
app.py CHANGED
@@ -153,49 +153,75 @@ class DiffusionModel(nn.Module):
153
 
154
  @torch.no_grad()
155
  def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
156
- x_t = torch.randn(num_images, 3, img_size, img_size).to(device)
 
 
 
 
 
 
 
 
157
 
 
158
  if labels.ndim == 1:
159
- labels_one_hot = torch.zeros(num_images, num_classes).to(device)
160
- labels_one_hot[torch.arange(num_images), labels] = 1
161
- labels = labels_one_hot
162
  else:
163
  labels = labels.to(device)
164
 
 
165
  for t in reversed(range(self.timesteps)):
166
  if cancel_event.is_set():
167
  return None
168
 
169
- t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
170
  predicted_noise = self.model(x_t, labels, t_tensor)
171
 
172
- beta_t = self.betas[t].to(device)
173
- alpha_t = self.alphas[t].to(device)
174
- alpha_bar_t = self.alpha_bars[t].to(device)
175
-
176
- mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise)
177
- variance = beta_t
178
 
 
 
 
 
 
 
179
  if t > 0:
180
- noise = torch.randn_like(x_t)
 
181
  else:
182
  noise = torch.zeros_like(x_t)
183
 
184
- x_t = mean + torch.sqrt(variance) * noise
185
 
186
- if progress_callback:
187
  progress_callback((self.timesteps - t) / self.timesteps)
188
 
189
- x_0 = torch.clamp(x_t, -1., 1.)
190
-
191
- # Normalization
192
- mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
193
- std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
194
- x_0 = std * x_0 + mean
195
- x_0 = torch.clamp(x_0, 0., 1.)
196
-
197
  return x_0
198
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  def load_model(model_path, device):
200
  unet_model = UNet(num_classes=NUM_CLASSES).to(device)
201
  diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
@@ -289,7 +315,7 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
289
  raise gr.Error(f"Generation failed: {str(e)}")
290
  finally:
291
  torch.cuda.empty_cache()
292
-
293
  # Load model
294
  MODEL_NAME = "model_weights.pth"
295
  model_path = MODEL_NAME
 
153
 
154
  @torch.no_grad()
155
  def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
156
+ # Constants
157
+ NOISE_SCALE = 0.9
158
+ NOISE_MIN_FACTOR = 0.6
159
+ SHARPEN_STRENGTH = 1.4
160
+ EDGE_BOOST = 0.15
161
+ EPS = 1e-8
162
+
163
+ # Initialize with scaled noise
164
+ x_t = torch.randn(num_images, 3, img_size, img_size, device=device) * NOISE_SCALE
165
 
166
+ # Label processing
167
  if labels.ndim == 1:
168
+ labels = torch.zeros(num_images, num_classes, device=device).scatter_(1, labels.unsqueeze(1), 1)
 
 
169
  else:
170
  labels = labels.to(device)
171
 
172
+ # Reverse diffusion process
173
  for t in reversed(range(self.timesteps)):
174
  if cancel_event.is_set():
175
  return None
176
 
177
+ t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float32)
178
  predicted_noise = self.model(x_t, labels, t_tensor)
179
 
180
+ beta_t = self.betas[t].to(device).float()
181
+ alpha_t = self.alphas[t].to(device).float()
182
+ alpha_bar_t = self.alpha_bars[t].to(device).float()
 
 
 
183
 
184
+ # Stable mean calculation
185
+ mean = (1 / (torch.sqrt(alpha_t) + EPS)) * (
186
+ x_t - (beta_t / (torch.sqrt(1 - alpha_bar_t) + EPS)) * predicted_noise
187
+ )
188
+
189
+ # Dynamic noise scaling
190
  if t > 0:
191
+ noise_factor = NOISE_MIN_FACTOR + (1 - NOISE_MIN_FACTOR) * (t / self.timesteps)
192
+ noise = torch.randn_like(x_t) * noise_factor
193
  else:
194
  noise = torch.zeros_like(x_t)
195
 
196
+ x_t = mean + torch.sqrt(beta_t) * noise
197
 
198
+ if progress_callback is not None:
199
  progress_callback((self.timesteps - t) / self.timesteps)
200
 
201
+ # Post-processing
202
+ x_0 = self._post_process(x_t, device)
 
 
 
 
 
 
203
  return x_0
204
+ def _post_process(self, x_t, device):
205
+ """Apply denormalization and image enhancement"""
206
+ # Denormalization
207
+ norm_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
208
+ norm_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
209
+ x_0 = torch.clamp(norm_std * torch.clamp(x_t, -1., 1.) + norm_mean, 0., 1.)
210
+
211
+ # Edge-preserving smoothing
212
+ blurred = torch.nn.functional.avg_pool2d(x_0, kernel_size=5, stride=1, padding=2)
213
+ mask = torch.abs(x_0 - blurred) < 0.1
214
+ x_0 = torch.where(mask, 0.7*x_0 + 0.3*blurred, x_0)
215
+
216
+ # Adaptive sharpening
217
+ low_pass = torch.nn.functional.avg_pool2d(x_0, kernel_size=3, stride=1, padding=1)
218
+ x_0 = torch.clamp((1 + self.SHARPEN_STRENGTH) * x_0 - self.SHARPEN_STRENGTH * low_pass, 0, 1)
219
+
220
+ # Edge boost
221
+ edges = x_0 - torch.nn.functional.avg_pool2d(x_0, kernel_size=5, stride=1, padding=2)
222
+ return torch.clamp(x_0 + edges * self.EDGE_BOOST, 0, 1)
223
+
224
+
225
  def load_model(model_path, device):
226
  unet_model = UNet(num_classes=NUM_CLASSES).to(device)
227
  diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
 
315
  raise gr.Error(f"Generation failed: {str(e)}")
316
  finally:
317
  torch.cuda.empty_cache()
318
+
319
  # Load model
320
  MODEL_NAME = "model_weights.pth"
321
  model_path = MODEL_NAME