Vedansh-7 commited on
Commit
47535f8
·
verified ·
1 Parent(s): 542a20e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -50
app.py CHANGED
@@ -7,6 +7,7 @@ import math
7
  import os
8
  from threading import Event
9
  import traceback
 
10
 
11
  # Constants
12
  IMG_SIZE = 128
@@ -153,75 +154,74 @@ class DiffusionModel(nn.Module):
153
 
154
  @torch.no_grad()
155
  def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
156
- # Constants
157
- NOISE_SCALE = 0.9
158
- NOISE_MIN_FACTOR = 0.6
159
- SHARPEN_STRENGTH = 1.4
160
- EDGE_BOOST = 0.15
161
- EPS = 1e-8
162
-
163
- # Initialize with scaled noise
164
- x_t = torch.randn(num_images, 3, img_size, img_size, device=device) * NOISE_SCALE
165
 
166
- # Label processing
167
  if labels.ndim == 1:
168
- labels = torch.zeros(num_images, num_classes, device=device).scatter_(1, labels.unsqueeze(1), 1)
 
 
169
  else:
170
  labels = labels.to(device)
171
 
172
- # Reverse diffusion process
173
  for t in reversed(range(self.timesteps)):
174
  if cancel_event.is_set():
175
  return None
176
 
177
- t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float32)
178
  predicted_noise = self.model(x_t, labels, t_tensor)
179
 
180
- beta_t = self.betas[t].to(device).float()
181
- alpha_t = self.alphas[t].to(device).float()
182
- alpha_bar_t = self.alpha_bars[t].to(device).float()
 
183
 
184
- # Stable mean calculation
185
- mean = (1 / (torch.sqrt(alpha_t) + EPS)) * (
186
- x_t - (beta_t / (torch.sqrt(1 - alpha_bar_t) + EPS)) * predicted_noise
187
- )
188
-
189
- # Dynamic noise scaling
190
  if t > 0:
191
- noise_factor = NOISE_MIN_FACTOR + (1 - NOISE_MIN_FACTOR) * (t / self.timesteps)
192
- noise = torch.randn_like(x_t) * noise_factor
193
  else:
194
  noise = torch.zeros_like(x_t)
195
 
196
- x_t = mean + torch.sqrt(beta_t) * noise
197
 
198
- if progress_callback is not None:
199
  progress_callback((self.timesteps - t) / self.timesteps)
200
 
201
- # Post-processing
202
- x_0 = self._post_process(x_t, device)
203
- return x_0
204
- def _post_process(self, x_t, device):
205
- """Apply denormalization and image enhancement"""
206
- # Denormalization
207
- norm_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
208
- norm_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
209
- x_0 = torch.clamp(norm_std * torch.clamp(x_t, -1., 1.) + norm_mean, 0., 1.)
210
-
211
- # Edge-preserving smoothing
212
- blurred = torch.nn.functional.avg_pool2d(x_0, kernel_size=5, stride=1, padding=2)
213
- mask = torch.abs(x_0 - blurred) < 0.1
214
- x_0 = torch.where(mask, 0.7*x_0 + 0.3*blurred, x_0)
215
-
216
- # Adaptive sharpening
217
- low_pass = torch.nn.functional.avg_pool2d(x_0, kernel_size=3, stride=1, padding=1)
218
- x_0 = torch.clamp((1 + self.SHARPEN_STRENGTH) * x_0 - self.SHARPEN_STRENGTH * low_pass, 0, 1)
219
-
220
- # Edge boost
221
- edges = x_0 - torch.nn.functional.avg_pool2d(x_0, kernel_size=5, stride=1, padding=2)
222
- return torch.clamp(x_0 + edges * self.EDGE_BOOST, 0, 1)
223
 
224
-
 
 
 
 
 
 
 
 
 
 
 
 
225
  def load_model(model_path, device):
226
  unet_model = UNet(num_classes=NUM_CLASSES).to(device)
227
  diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
@@ -315,7 +315,7 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
315
  raise gr.Error(f"Generation failed: {str(e)}")
316
  finally:
317
  torch.cuda.empty_cache()
318
-
319
  # Load model
320
  MODEL_NAME = "model_weights.pth"
321
  model_path = MODEL_NAME
 
7
  import os
8
  from threading import Event
9
  import traceback
10
+ import cv2
11
 
12
  # Constants
13
  IMG_SIZE = 128
 
154
 
155
  @torch.no_grad()
156
  def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
157
+ # Start with random noise
158
+ x_t = torch.randn(num_images, 3, img_size, img_size).to(device)
 
 
 
 
 
 
 
159
 
160
+ # Label handling (one-hot if needed)
161
  if labels.ndim == 1:
162
+ labels_one_hot = torch.zeros(num_images, num_classes).to(device)
163
+ labels_one_hot[torch.arange(num_images), labels] = 1
164
+ labels = labels_one_hot
165
  else:
166
  labels = labels.to(device)
167
 
168
+ # ---- REVERTED SAMPLING LOOP WITH NOISE REDUCTION ----
169
  for t in reversed(range(self.timesteps)):
170
  if cancel_event.is_set():
171
  return None
172
 
173
+ t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
174
  predicted_noise = self.model(x_t, labels, t_tensor)
175
 
176
+ # Calculate coefficients
177
+ beta_t = self.betas[t].to(device)
178
+ alpha_t = self.alphas[t].to(device)
179
+ alpha_bar_t = self.alpha_bars[t].to(device)
180
 
181
+ mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise)
182
+ variance = beta_t
183
+
184
+ # Reduced noise injection with lower multiplier
 
 
185
  if t > 0:
186
+ noise = torch.randn_like(x_t) * 0.8 # Reduced noise by 20%
 
187
  else:
188
  noise = torch.zeros_like(x_t)
189
 
190
+ x_t = mean + torch.sqrt(variance) * noise
191
 
192
+ if progress_callback:
193
  progress_callback((self.timesteps - t) / self.timesteps)
194
 
195
+ # Clamp and denormalize
196
+ x_0 = torch.clamp(x_t, -1., 1.)
197
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
198
+ std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
199
+ x_0 = std * x_0 + mean
200
+ x_0 = torch.clamp(x_0, 0., 1.)
201
+
202
+ # ---- ENHANCED SHARPENING ----
203
+ # First apply mild bilateral filtering to reduce noise while preserving edges
204
+ x_np = x_0.cpu().permute(0, 2, 3, 1).numpy()
205
+ filtered = []
206
+ for img in x_np:
207
+ img = (img * 255).astype(np.uint8)
208
+ filtered_img = cv2.bilateralFilter(img, d=5, sigmaColor=15, sigmaSpace=15)
209
+ filtered.append(filtered_img / 255.0)
210
+ x_0 = torch.tensor(np.array(filtered), device=device).permute(0, 3, 1, 2)
 
 
 
 
 
 
211
 
212
+ # Then apply stronger unsharp masking
213
+ kernel = torch.ones(3, 1, 5, 5, device=device) / 75
214
+ kernel = kernel.to(x_0.dtype)
215
+ blurred = torch.nn.functional.conv2d(
216
+ x_0,
217
+ kernel,
218
+ padding=2,
219
+ groups=3
220
+ )
221
+ x_0 = torch.clamp(1.5 * x_0 - 0.5 * blurred, 0., 1.) # Increased sharpening factor
222
+
223
+ return x_0
224
+
225
  def load_model(model_path, device):
226
  unet_model = UNet(num_classes=NUM_CLASSES).to(device)
227
  diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
 
315
  raise gr.Error(f"Generation failed: {str(e)}")
316
  finally:
317
  torch.cuda.empty_cache()
318
+
319
  # Load model
320
  MODEL_NAME = "model_weights.pth"
321
  model_path = MODEL_NAME