Vedansh-7 committed on
Commit
49fdbe4
·
verified ·
1 Parent(s): 3886050

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -44
app.py CHANGED
@@ -125,7 +125,6 @@ class DiffusionModel(nn.Module):
125
  self.model = model
126
  self.timesteps = timesteps
127
 
128
- # Use the exact same noise schedule as Colab
129
  beta_start = 0.0001
130
  beta_end = 0.02
131
  self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
@@ -134,29 +133,28 @@ class DiffusionModel(nn.Module):
134
 
135
  @torch.no_grad()
136
  def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
137
- """Identical implementation to Colab version"""
138
- # Start with random noise (same scale)
139
- x_t = torch.randn((num_images, 3, img_size, img_size), device=device)
140
-
141
- # Identical label handling
142
  if labels.ndim == 1:
143
- labels = torch.zeros(num_images, num_classes, device=device).scatter_(1, labels.unsqueeze(1), 1)
144
- labels = labels.to(device)
 
 
 
145
 
146
- # Same sampling loop
147
  for t in reversed(range(self.timesteps)):
148
  if cancel_event.is_set():
149
  return None
150
 
151
- t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
 
152
  predicted_noise = self.model(x_t, labels, t_tensor)
153
 
154
- # Identical coefficients calculation
155
  beta_t = self.betas[t].to(device)
156
  alpha_t = self.alphas[t].to(device)
157
  alpha_bar_t = self.alpha_bars[t].to(device)
158
 
159
- # Same mean/variance calculation
160
  mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise)
161
  variance = beta_t
162
 
@@ -170,35 +168,15 @@ class DiffusionModel(nn.Module):
170
  if progress_callback:
171
  progress_callback((self.timesteps - t) / self.timesteps)
172
 
173
- # Identical denormalization
174
- x_t = torch.clamp(x_t, -1., 1.)
175
  mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
176
  std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
177
- x_t = std * x_t + mean
178
- x_t = torch.clamp(x_t, 0., 1.)
179
-
180
- return x_t
181
 
182
- def _post_process(self, images):
183
- """Apply post-processing to reduce noise and enhance contrast"""
184
- # Normalize to [0,1]
185
- images = torch.clamp(images, -1, 1)
186
- images = (images + 1) / 2
187
-
188
- # Apply mild blur (convert NHWC to NCHW for conv2d)
189
- if images.dim() == 4 and images.shape[-1] != 3: # NCHW format
190
- images = images.permute(0, 2, 3, 1)
191
-
192
- x = images.permute(0, 3, 1, 2) # NHWC to NCHW
193
- x = torch.nn.functional.conv2d(x, self.blur_kernel, padding=1, groups=3)
194
- images = x.permute(0, 2, 3, 1) # NCHW to NHWC
195
-
196
- # Contrast adjustment
197
- mean_val = images.mean(dim=(1,2,3), keepdim=True)
198
- images = (images - mean_val) * 1.2 + mean_val
199
-
200
- return torch.clamp(images, 0, 1)
201
-
202
  def load_model(model_path, device):
203
  unet = UNet(num_classes=NUM_CLASSES).to(device)
204
  diffusion_model = DiffusionModel(unet).to(device)
@@ -207,20 +185,17 @@ def load_model(model_path, device):
207
  try:
208
  checkpoint = torch.load(model_path, map_location=device)
209
 
210
- # Handle both full model and state_dict loading
211
  if 'model_state_dict' in checkpoint:
212
  state_dict = checkpoint['model_state_dict']
213
  else:
214
  state_dict = checkpoint
215
 
216
- # Handle both prefixed and non-prefixed state dicts
217
  if all(k.startswith('model.') for k in state_dict.keys()):
218
  state_dict = {k[6:]: v for k, v in state_dict.items()}
219
 
220
  unet.load_state_dict(state_dict, strict=False)
221
  print("Model loaded successfully")
222
 
223
- # Verify model loading
224
  test_input = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
225
  test_labels = torch.zeros(1, NUM_CLASSES).to(device)
226
  test_time = torch.tensor([1]).to(device)
@@ -244,7 +219,6 @@ try:
244
  print("Model loaded successfully!")
245
  except Exception as e:
246
  print(f"Failed to load model: {e}")
247
- # Create a dummy model if loading fails
248
  print("Creating dummy model for demonstration")
249
  loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES)).to(device)
250
 
@@ -263,7 +237,6 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
263
  if label_str not in label_map:
264
  raise gr.Error("Invalid condition selected")
265
 
266
- # Create one-hot encoded labels
267
  labels = torch.zeros(num_images, NUM_CLASSES)
268
  labels[:, label_map[label_str]] = 1
269
 
@@ -288,7 +261,6 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
288
 
289
  processed_images = []
290
  for img in images:
291
- # Convert to numpy and permute dimensions (C,H,W) -> (H,W,C)
292
  img_np = img.cpu().permute(1, 2, 0).numpy()
293
  img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
294
  pil_img = Image.fromarray(img_np)
 
125
  self.model = model
126
  self.timesteps = timesteps
127
 
 
128
  beta_start = 0.0001
129
  beta_end = 0.02
130
  self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
 
133
 
134
  @torch.no_grad()
135
  def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
136
+ """Your exact sampling function from Colab"""
137
+ x_t = torch.randn(num_images, 3, img_size, img_size).to(device)
138
+
 
 
139
  if labels.ndim == 1:
140
+ labels_one_hot = torch.zeros(num_images, num_classes).to(device)
141
+ labels_one_hot[torch.arange(num_images), labels] = 1
142
+ labels = labels_one_hot
143
+ else:
144
+ labels = labels.to(device)
145
 
 
146
  for t in reversed(range(self.timesteps)):
147
  if cancel_event.is_set():
148
  return None
149
 
150
+ t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float) # Pass time as float
151
+
152
  predicted_noise = self.model(x_t, labels, t_tensor)
153
 
 
154
  beta_t = self.betas[t].to(device)
155
  alpha_t = self.alphas[t].to(device)
156
  alpha_bar_t = self.alpha_bars[t].to(device)
157
 
 
158
  mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise)
159
  variance = beta_t
160
 
 
168
  if progress_callback:
169
  progress_callback((self.timesteps - t) / self.timesteps)
170
 
171
+ x_0 = torch.clamp(x_t, -1., 1.)
172
+
173
  mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
174
  std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
175
+ x_0 = std * x_0 + mean
176
+ x_0 = torch.clamp(x_0, 0., 1.)
177
+
178
+ return x_0
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  def load_model(model_path, device):
181
  unet = UNet(num_classes=NUM_CLASSES).to(device)
182
  diffusion_model = DiffusionModel(unet).to(device)
 
185
  try:
186
  checkpoint = torch.load(model_path, map_location=device)
187
 
 
188
  if 'model_state_dict' in checkpoint:
189
  state_dict = checkpoint['model_state_dict']
190
  else:
191
  state_dict = checkpoint
192
 
 
193
  if all(k.startswith('model.') for k in state_dict.keys()):
194
  state_dict = {k[6:]: v for k, v in state_dict.items()}
195
 
196
  unet.load_state_dict(state_dict, strict=False)
197
  print("Model loaded successfully")
198
 
 
199
  test_input = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
200
  test_labels = torch.zeros(1, NUM_CLASSES).to(device)
201
  test_time = torch.tensor([1]).to(device)
 
219
  print("Model loaded successfully!")
220
  except Exception as e:
221
  print(f"Failed to load model: {e}")
 
222
  print("Creating dummy model for demonstration")
223
  loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES)).to(device)
224
 
 
237
  if label_str not in label_map:
238
  raise gr.Error("Invalid condition selected")
239
 
 
240
  labels = torch.zeros(num_images, NUM_CLASSES)
241
  labels[:, label_map[label_str]] = 1
242
 
 
261
 
262
  processed_images = []
263
  for img in images:
 
264
  img_np = img.cpu().permute(1, 2, 0).numpy()
265
  img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
266
  pil_img = Image.fromarray(img_np)