Vedansh-7 commited on
Commit
8cc9c66
·
verified ·
1 Parent(s): fdbdf55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -67
app.py CHANGED
@@ -10,7 +10,7 @@ import traceback
10
 
11
  # Constants
12
  IMG_SIZE = 128
13
- TIMESTEPS = 300
14
  NUM_CLASSES = 2
15
 
16
  # Global Cancellation Flag
@@ -135,47 +135,56 @@ class DiffusionModel(nn.Module):
135
  self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))
136
 
137
  @torch.no_grad()
138
- def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
139
- # Initialize with noise
140
- x_t = torch.randn((num_images, 3, img_size, img_size), device=device, dtype=torch.float32)
141
-
142
- # Convert labels to proper format
143
- if labels.ndim == 1:
144
- labels_one_hot = torch.zeros(num_images, num_classes, device=device)
145
- labels_one_hot[torch.arange(num_images), labels] = 1
146
- labels = labels_one_hot
147
- else:
148
- labels = labels.to(device)
149
 
150
- for i in reversed(range(0, self.timesteps)):
151
- if cancel_event.is_set():
152
- return None
153
-
154
- t = torch.full((num_images,), i, device=device, dtype=torch.long)
155
 
156
- # Model prediction with type stability
157
- pred_noise = self.model(x_t, labels, t.float())
158
-
159
- # Calculate diffusion parameters
160
- beta_t = self.betas[t].view(-1, 1, 1, 1).to(device)
161
- alpha_t = self.alphas[t].view(-1, 1, 1, 1).to(device)
162
- alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1, 1).to(device)
163
-
164
- # Improved denoising step (Fix 2)
165
- if i > 0:
166
- noise = torch.randn_like(x_t)
167
- else:
168
- noise = torch.zeros_like(x_t)
169
-
170
- x_t = (x_t - (1 - alpha_t)/torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_t)
171
- x_t += noise * torch.sqrt(beta_t)
172
 
173
- if progress_callback:
174
- progress_callback((self.timesteps - i) / self.timesteps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
- # Fix 3: Simplified scaling
177
- x_t = torch.clamp(x_t, -1., 1.)
178
- return (x_t + 1) / 2 # Scale to [0,1]
179
 
180
  def load_model(model_path, device):
181
  unet = UNet(num_classes=NUM_CLASSES).to(device)
@@ -231,7 +240,7 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
231
  if label_str not in label_map:
232
  raise gr.Error("Invalid condition selected")
233
 
234
- labels = torch.zeros(num_images, NUM_CLASSES, device=device, dtype=torch.float32)
235
  labels[:, label_map[label_str]] = 1
236
 
237
  try:
@@ -242,10 +251,9 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
242
 
243
  with torch.no_grad():
244
  print(f"Generating {num_images} images for {label_str}")
245
- print(f"Labels shape: {labels.shape}, device: {labels.device}")
246
-
247
  images = loaded_model.sample(
248
  num_images=num_images,
 
249
  img_size=IMG_SIZE,
250
  num_classes=NUM_CLASSES,
251
  labels=labels,
@@ -256,17 +264,15 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
256
  if images is None:
257
  return None, None
258
 
259
- # Diagnostic print
260
  print(f"Generated images range: {images.min().item():.3f}, {images.max().item():.3f}")
261
 
262
  processed_images = []
263
  for img in images:
264
- # Fix 3: Improved image conversion
265
- img_np = (img.cpu().numpy().transpose(1, 2, 0) * 255).clip(0, 255).astype(np.uint8)
266
- print(f"Image range after conversion: {img_np.min()}, {img_np.max()}")
267
 
268
- if img_np.shape[2] == 1: # Handle grayscale if needed
269
- img_np = img_np.squeeze(-1)
270
  pil_img = Image.fromarray(img_np)
271
  processed_images.append(pil_img)
272
 
@@ -276,30 +282,11 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
276
  else:
277
  return None, processed_images
278
 
279
- except torch.cuda.OutOfMemoryError:
280
- torch.cuda.empty_cache()
281
- raise gr.Error("Out of GPU memory - try generating fewer images")
282
  except Exception as e:
283
  traceback.print_exc()
284
- if str(e) != "Generation was cancelled by user":
285
- raise gr.Error(f"Generation failed: {str(e)}")
286
- return None, None
287
  finally:
288
  torch.cuda.empty_cache()
289
-
290
- # Load model
291
- MODEL_NAME = "model_weights.pth" # Updated to look in root folder
292
- model_path = MODEL_NAME
293
- print("Loading model...")
294
- try:
295
- loaded_model = load_model(model_path, device)
296
- print("Model loaded successfully!")
297
- except Exception as e:
298
- print(f"Failed to load model: {e}")
299
- # Create a dummy model for demo purposes
300
- print("Creating dummy model for demonstration")
301
- loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES)).to(device)
302
-
303
 
304
  # Gradio UI
305
  with gr.Blocks(theme=gr.themes.Soft(
 
10
 
11
  # Constants
12
  IMG_SIZE = 128
13
+ TIMESTEPS = 500
14
  NUM_CLASSES = 2
15
 
16
  # Global Cancellation Flag
 
135
  self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))
136
 
137
  @torch.no_grad()
138
+ def sample(model, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
139
+ # Initialize with properly scaled noise
140
+ x_t = torch.randn((num_images, 3, img_size, img_size), device=device) * 0.5 # Reduced initial noise scale
141
+
142
+ # Convert labels to proper format
143
+ if labels.ndim == 1:
144
+ labels_one_hot = torch.zeros(num_images, num_classes, device=device)
145
+ labels_one_hot[torch.arange(num_images), labels] = 1
146
+ labels = labels_one_hot
147
+ else:
148
+ labels = labels.float().to(device)
149
 
150
+ # Reverse diffusion process
151
+ for t in reversed(range(timesteps)):
152
+ if cancel_event.is_set():
153
+ return None
 
154
 
155
+ t_tensor = torch.full((num_images,), t, device=device, dtype=torch.long)
156
+
157
+ # Model prediction with proper scaling
158
+ pred_noise = model.model(x_t, labels, t_tensor.float())
159
+
160
+ # Calculate diffusion parameters
161
+ alpha_t = model.alphas[t].to(device)
162
+ alpha_bar_t = model.alpha_bars[t].to(device)
163
+ beta_t = model.betas[t].to(device)
164
+
165
+ # Improved denoising step
166
+ if t > 0:
167
+ noise = torch.randn_like(x_t) * 0.5 # Reduced noise scale
168
+ else:
169
+ noise = torch.zeros_like(x_t)
 
170
 
171
+ # More stable prediction
172
+ x_t = (x_t - (1 - alpha_t)/torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_t)
173
+ x_t = x_t + noise * torch.sqrt(beta_t)
174
+
175
+ if progress_callback:
176
+ progress_callback((timesteps - t) / timesteps)
177
+
178
+ # Better image normalization
179
+ x_t = torch.clamp(x_t, -1., 1.)
180
+
181
+ # Alternative normalization approach
182
+ min_val = x_t.min()
183
+ max_val = x_t.max()
184
+ x_t = (x_t - min_val) / (max_val - min_val + 1e-8) # Ensure we don't divide by zero
185
+
186
+ return x_t
187
 
 
 
 
188
 
189
  def load_model(model_path, device):
190
  unet = UNet(num_classes=NUM_CLASSES).to(device)
 
240
  if label_str not in label_map:
241
  raise gr.Error("Invalid condition selected")
242
 
243
+ labels = torch.zeros(num_images, NUM_CLASSES, device=device)
244
  labels[:, label_map[label_str]] = 1
245
 
246
  try:
 
251
 
252
  with torch.no_grad():
253
  print(f"Generating {num_images} images for {label_str}")
 
 
254
  images = loaded_model.sample(
255
  num_images=num_images,
256
+ timesteps=TIMESTEPS,
257
  img_size=IMG_SIZE,
258
  num_classes=NUM_CLASSES,
259
  labels=labels,
 
264
  if images is None:
265
  return None, None
266
 
 
267
  print(f"Generated images range: {images.min().item():.3f}, {images.max().item():.3f}")
268
 
269
  processed_images = []
270
  for img in images:
271
+ # Convert to numpy and ensure proper range
272
+ img_np = img.cpu().numpy().transpose(1, 2, 0)
273
+ img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
274
 
275
+ # Convert to PIL Image
 
276
  pil_img = Image.fromarray(img_np)
277
  processed_images.append(pil_img)
278
 
 
282
  else:
283
  return None, processed_images
284
 
 
 
 
285
  except Exception as e:
286
  traceback.print_exc()
287
+ raise gr.Error(f"Generation failed: {str(e)}")
 
 
288
  finally:
289
  torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
 
291
  # Gradio UI
292
  with gr.Blocks(theme=gr.themes.Soft(