Vedansh-7 commited on
Commit
f6ed1f7
·
1 Parent(s): 3aca900

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -56
app.py CHANGED
@@ -134,56 +134,46 @@ class DiffusionModel(nn.Module):
134
  self.alphas = 1. - self.betas
135
  self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))
136
 
137
- @torch.no_grad()
138
- def sample(model, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
139
- # Initialize with properly scaled noise
140
- x_t = torch.randn((num_images, 3, img_size, img_size), device=device) * 0.5 # Reduced initial noise scale
141
-
142
- # Convert labels to proper format
143
- if labels.ndim == 1:
144
- labels_one_hot = torch.zeros(num_images, num_classes, device=device)
145
- labels_one_hot[torch.arange(num_images), labels] = 1
146
- labels = labels_one_hot
147
- else:
148
- labels = labels.float().to(device)
149
-
150
- # Reverse diffusion process
151
- for t in reversed(range(timesteps)):
152
- if cancel_event.is_set():
153
- return None
154
-
155
- t_tensor = torch.full((num_images,), t, device=device, dtype=torch.long)
156
-
157
- # Model prediction with proper scaling
158
- pred_noise = model.model(x_t, labels, t_tensor.float())
159
-
160
- # Calculate diffusion parameters
161
- alpha_t = model.alphas[t].to(device)
162
- alpha_bar_t = model.alpha_bars[t].to(device)
163
- beta_t = model.betas[t].to(device)
164
 
165
- # Improved denoising step
166
- if t > 0:
167
- noise = torch.randn_like(x_t) * 0.5 # Reduced noise scale
 
168
  else:
169
- noise = torch.zeros_like(x_t)
 
 
 
 
 
 
170
 
171
- # More stable prediction
172
- x_t = (x_t - (1 - alpha_t)/torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_t)
173
- x_t = x_t + noise * torch.sqrt(beta_t)
174
-
175
- if progress_callback:
176
- progress_callback((timesteps - t) / timesteps)
 
 
 
 
 
 
 
 
 
 
177
 
178
- # Better image normalization
179
- x_t = torch.clamp(x_t, -1., 1.)
180
-
181
- # Alternative normalization approach
182
- min_val = x_t.min()
183
- max_val = x_t.max()
184
- x_t = (x_t - min_val) / (max_val - min_val + 1e-8) # Ensure we don't divide by zero
185
-
186
- return x_t
187
 
188
 
189
  def load_model(model_path, device):
@@ -210,7 +200,6 @@ def load_model(model_path, device):
210
  # Verify model loading
211
  test_input = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
212
  test_labels = torch.zeros(1, NUM_CLASSES).to(device)
213
- test_labels[0, 0] = 1
214
  test_time = torch.tensor([1]).to(device)
215
  output = unet(test_input, test_labels, test_time)
216
  print(f"Model test output shape: {output.shape}")
@@ -223,6 +212,18 @@ def load_model(model_path, device):
223
 
224
  diffusion_model.eval()
225
  return diffusion_model
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
  def cancel_generation():
228
  cancel_event.set()
@@ -232,7 +233,6 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
232
  global loaded_model
233
  cancel_event.clear()
234
 
235
- # Input validation
236
  if num_images < 1 or num_images > 10:
237
  raise gr.Error("Number of images must be between 1 and 10")
238
 
@@ -250,7 +250,6 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
250
  raise gr.Error("Generation was cancelled by user")
251
 
252
  with torch.no_grad():
253
- print(f"Generating {num_images} images for {label_str}")
254
  images = loaded_model.sample(
255
  num_images=num_images,
256
  timesteps=TIMESTEPS,
@@ -264,19 +263,13 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
264
  if images is None:
265
  return None, None
266
 
267
- print(f"Generated images range: {images.min().item():.3f}, {images.max().item():.3f}")
268
-
269
  processed_images = []
270
  for img in images:
271
- # Convert to numpy and ensure proper range
272
  img_np = img.cpu().numpy().transpose(1, 2, 0)
273
  img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
274
-
275
- # Convert to PIL Image
276
  pil_img = Image.fromarray(img_np)
277
  processed_images.append(pil_img)
278
 
279
- # Return appropriate outputs based on count
280
  if num_images == 1:
281
  return processed_images[0], processed_images
282
  else:
@@ -287,7 +280,7 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
287
  raise gr.Error(f"Generation failed: {str(e)}")
288
  finally:
289
  torch.cuda.empty_cache()
290
-
291
  # Gradio UI
292
  with gr.Blocks(theme=gr.themes.Soft(
293
  primary_hue="violet",
 
134
  self.alphas = 1. - self.betas
135
  self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))
136
 
137
+ @torch.no_grad()
138
+ def sample(self, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
139
+ x_t = torch.randn((num_images, 3, img_size, img_size), device=device) * 0.5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
+ if labels.ndim == 1:
142
+ labels_one_hot = torch.zeros(num_images, num_classes, device=device)
143
+ labels_one_hot[torch.arange(num_images), labels] = 1
144
+ labels = labels_one_hot
145
  else:
146
+ labels = labels.float().to(device)
147
+
148
+ for t in reversed(range(timesteps)):
149
+ if cancel_event.is_set():
150
+ return None
151
+
152
+ t_tensor = torch.full((num_images,), t, device=device, dtype=torch.long)
153
 
154
+ pred_noise = self.model(x_t, labels, t_tensor.float())
155
+
156
+ alpha_t = self.alphas[t].to(device)
157
+ alpha_bar_t = self.alpha_bars[t].to(device)
158
+ beta_t = self.betas[t].to(device)
159
+
160
+ if t > 0:
161
+ noise = torch.randn_like(x_t) * 0.5
162
+ else:
163
+ noise = torch.zeros_like(x_t)
164
+
165
+ x_t = (x_t - (1 - alpha_t)/torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_t)
166
+ x_t = x_t + noise * torch.sqrt(beta_t)
167
+
168
+ if progress_callback:
169
+ progress_callback((timesteps - t) / timesteps)
170
 
171
+ x_t = torch.clamp(x_t, -1., 1.)
172
+ min_val = x_t.min()
173
+ max_val = x_t.max()
174
+ x_t = (x_t - min_val) / (max_val - min_val + 1e-8)
175
+
176
+ return x_t
 
 
 
177
 
178
 
179
  def load_model(model_path, device):
 
200
  # Verify model loading
201
  test_input = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
202
  test_labels = torch.zeros(1, NUM_CLASSES).to(device)
 
203
  test_time = torch.tensor([1]).to(device)
204
  output = unet(test_input, test_labels, test_time)
205
  print(f"Model test output shape: {output.shape}")
 
212
 
213
  diffusion_model.eval()
214
  return diffusion_model
215
+
216
+ MODEL_NAME = "model_weights.pth"
217
+ model_path = MODEL_NAME
218
+ print("Loading model...")
219
+ try:
220
+ loaded_model = load_model(model_path, device)
221
+ print("Model loaded successfully!")
222
+ except Exception as e:
223
+ print(f"Failed to load model: {e}")
224
+ # Create a dummy model if loading fails
225
+ print("Creating dummy model for demonstration")
226
+ loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES)).to(device)
227
 
228
  def cancel_generation():
229
  cancel_event.set()
 
233
  global loaded_model
234
  cancel_event.clear()
235
 
 
236
  if num_images < 1 or num_images > 10:
237
  raise gr.Error("Number of images must be between 1 and 10")
238
 
 
250
  raise gr.Error("Generation was cancelled by user")
251
 
252
  with torch.no_grad():
 
253
  images = loaded_model.sample(
254
  num_images=num_images,
255
  timesteps=TIMESTEPS,
 
263
  if images is None:
264
  return None, None
265
 
 
 
266
  processed_images = []
267
  for img in images:
 
268
  img_np = img.cpu().numpy().transpose(1, 2, 0)
269
  img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
 
 
270
  pil_img = Image.fromarray(img_np)
271
  processed_images.append(pil_img)
272
 
 
273
  if num_images == 1:
274
  return processed_images[0], processed_images
275
  else:
 
280
  raise gr.Error(f"Generation failed: {str(e)}")
281
  finally:
282
  torch.cuda.empty_cache()
283
+
284
  # Gradio UI
285
  with gr.Blocks(theme=gr.themes.Soft(
286
  primary_hue="violet",