Vedansh-7 committed on
Commit 06a1915 · verified · 1 Parent(s): 48cbe75

Update app.py

Files changed (1): app.py (+99, -122)

app.py CHANGED
@@ -19,25 +19,20 @@ cancel_event = Event()
 # Device Configuration
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# --- Model Definitions (from second file) ---
+# --- Model Definitions ---
 class SinusoidalPositionEmbeddings(nn.Module):
     def __init__(self, dim):
         super().__init__()
         self.dim = dim
-        self.register_buffer('embeddings', self._precompute_embeddings(dim))
-
-    def _precompute_embeddings(self, dim):
         half_dim = dim // 2
         emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim) * -emb)
-        return emb
+        emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+        self.register_buffer('embeddings', emb)
 
     def forward(self, time):
-        device = time.device
-        embeddings = self.embeddings.to(device)
-        embeddings = time[:, None] * embeddings[None, :]
-        output = torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
-        return output
+        embeddings = self.embeddings.to(time.device)
+        embeddings = time.float()[:, None] * embeddings[None, :]
+        return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
 
 class UNet(nn.Module):
     def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
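Reviewer note on the SinusoidalPositionEmbeddings change: the frequency vector is now computed once in `__init__` and registered as a buffer instead of being rebuilt by a helper on every call, and `forward` casts `time` to float before broadcasting. A minimal standalone sketch of the embedding math (the dim and timestep values are illustrative):

```python
import math
import torch

dim = 256                                   # matches time_dim in the diff
half_dim = dim // 2
freq = math.log(10000) / (half_dim - 1)
freq = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -freq)  # [half_dim]

t = torch.tensor([0.0, 10.0, 499.0])        # a batch of timesteps
emb = t[:, None] * freq[None, :]            # broadcast to [3, half_dim]
emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
print(emb.shape)                            # torch.Size([3, 256])
```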
@@ -125,95 +120,97 @@ class UNet(nn.Module):
         return output
 
 class DiffusionModel(nn.Module):
-    def __init__(self, model, timesteps=500, time_dim=256):
+    def __init__(self, model, timesteps=TIMESTEPS, time_dim=256):
         super().__init__()
         self.model = model
         self.timesteps = timesteps
         self.time_dim = time_dim
 
-        self.betas = self.linear_schedule(timesteps)
-        self.alphas = 1. - self.betas
-        self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0).float())
-
-    def linear_schedule(self, timesteps):
+        # Fix 1: Ensure consistent float32 types
         scale = 1000 / timesteps
         beta_start = scale * 0.0001
         beta_end = scale * 0.02
-        return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
-
-    def forward_diffusion(self, x_0, t, noise):
-        x_0 = x_0.float()
-        noise = noise.float()
-        alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1, 1)
-        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1. - alpha_bar_t) * noise
-        return x_t
-
-    def forward(self, x_0, labels):
-        t = torch.randint(0, self.timesteps, (x_0.shape[0],), device=x_0.device).long()
-        noise = torch.randn_like(x_0)
-        x_t = self.forward_diffusion(x_0, t, noise)
-        predicted_noise = self.model(x_t, labels, t.float())
-        return predicted_noise, noise, t
-
-@torch.no_grad()
-def sample(model, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
-    x_t = torch.randn(num_images, 3, img_size, img_size).to(device)
-
-    if labels.ndim == 1:
-        labels_one_hot = torch.zeros(num_images, num_classes).to(device)
-        labels_one_hot[torch.arange(num_images), labels] = 1
-        labels = labels_one_hot
-    else:
-        labels = labels.to(device)
-
-    for t in reversed(range(timesteps)):
-        if cancel_event.is_set():
-            return None
-
-        t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
-
-        predicted_noise = model.model(x_t, labels, t_tensor)
-
-        beta_t = model.betas[t].to(device)
-        alpha_t = model.alphas[t].to(device)
-        alpha_bar_t = model.alpha_bars[t].to(device)
-
-        mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise)
-        variance = beta_t
-
-        if t > 0:
-            noise = torch.randn_like(x_t)
-        else:
-            noise = torch.zeros_like(x_t)
-
-        x_t = mean + torch.sqrt(variance) * noise
-
-        if progress_callback:
-            progress_callback((timesteps - t) / timesteps)
-
-    x_0 = torch.clamp(x_t, -1., 1.)
-
-    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
-    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
-    x_0 = std * x_0 + mean
-    x_0 = torch.clamp(x_0, 0., 1.)
-
-    return x_0
+        self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
+        self.alphas = 1. - self.betas
+        self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))
+
+    @torch.no_grad()
+    def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
+        # Initialize with noise
+        x_t = torch.randn((num_images, 3, img_size, img_size), device=device, dtype=torch.float32)
+
+        # Convert labels to proper format
+        if labels.ndim == 1:
+            labels_one_hot = torch.zeros(num_images, num_classes, device=device)
+            labels_one_hot[torch.arange(num_images), labels] = 1
+            labels = labels_one_hot
+        else:
+            labels = labels.to(device)
+
+        for i in reversed(range(0, self.timesteps)):
+            if cancel_event.is_set():
+                return None
+
+            t = torch.full((num_images,), i, device=device, dtype=torch.long)
+
+            # Model prediction with type stability
+            pred_noise = self.model(x_t, labels, t.float())
+
+            # Calculate diffusion parameters
+            beta_t = self.betas[t].view(-1, 1, 1, 1).to(device)
+            alpha_t = self.alphas[t].view(-1, 1, 1, 1).to(device)
+            alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1, 1).to(device)
+
+            # Improved denoising step (Fix 2)
+            if i > 0:
+                noise = torch.randn_like(x_t)
+            else:
+                noise = torch.zeros_like(x_t)
+
+            x_t = (x_t - (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_t)
+            x_t += noise * torch.sqrt(beta_t)
+
+            if progress_callback:
+                progress_callback((self.timesteps - i) / self.timesteps)
+
+        # Fix 3: Simplified scaling
+        x_t = torch.clamp(x_t, -1., 1.)
+        return (x_t + 1) / 2  # Scale to [0,1]
 
 def load_model(model_path, device):
-    unet_model = UNet(num_classes=NUM_CLASSES).to(device)
-    diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
+    unet = UNet(num_classes=NUM_CLASSES).to(device)
+    diffusion_model = DiffusionModel(unet).to(device)
 
-    try:
-        checkpoint = torch.load(model_path, map_location=device)
-        if 'model_state_dict' in checkpoint:
-            diffusion_model.model.load_state_dict(checkpoint['model_state_dict'])
-        else:
-            diffusion_model.model.load_state_dict(checkpoint)
-        print(f"Successfully loaded model from {model_path}")
-    except Exception as e:
-        print(f"Error loading model: {e}")
-        print("Using randomly initialized weights")
+    if os.path.exists(model_path):
+        try:
+            checkpoint = torch.load(model_path, map_location=device)
+
+            # Handle both full model and state_dict loading
+            if 'model_state_dict' in checkpoint:
+                state_dict = checkpoint['model_state_dict']
+            else:
+                state_dict = checkpoint
+
+            # Handle both prefixed and non-prefixed state dicts
+            if all(k.startswith('model.') for k in state_dict.keys()):
+                state_dict = {k[6:]: v for k, v in state_dict.items()}
+
+            unet.load_state_dict(state_dict, strict=False)
+            print("Model loaded successfully")
+
+            # Verify model loading
+            test_input = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
+            test_labels = torch.zeros(1, NUM_CLASSES).to(device)
+            test_labels[0, 0] = 1
+            test_time = torch.tensor([1]).to(device)
+            output = unet(test_input, test_labels, test_time)
+            print(f"Model test output shape: {output.shape}")
+
+        except Exception as e:
+            traceback.print_exc()
+            raise ValueError(f"Error loading model: {str(e)}")
+    else:
+        raise FileNotFoundError(f"Model weights not found at {model_path}")
 
     diffusion_model.eval()
     return diffusion_model
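Reviewer note on Fix 2: since `beta_t = 1 - alpha_t`, the rewritten update is algebraically the same DDPM ancestral step as the old `mean + sqrt(variance) * noise` form; the substantive changes are the per-sample broadcasting via `.view(-1, 1, 1, 1)` and the output scaling, which now maps [-1, 1] to [0, 1] instead of undoing an ImageNet mean/std normalization. A small standalone check of the equivalence (the schedule values are illustrative):

```python
import torch

betas = torch.linspace(0.0002, 0.04, 500, dtype=torch.float32)  # linear schedule
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

t = 250
x_t = torch.randn(1, 3, 8, 8)
eps = torch.randn_like(x_t)

# Old form: (1/sqrt(alpha_t)) * (x_t - beta_t/sqrt(1 - alpha_bar_t) * eps)
old = (1 / torch.sqrt(alphas[t])) * (x_t - (betas[t] / torch.sqrt(1 - alpha_bars[t])) * eps)
# New form: (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
new = (x_t - (1 - alphas[t]) / torch.sqrt(1 - alpha_bars[t]) * eps) / torch.sqrt(alphas[t])

print(torch.allclose(old, new))  # True: beta_t == 1 - alpha_t
```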
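Reviewer note on the checkpoint handling: state dicts saved from the DiffusionModel wrapper prefix every UNet parameter with 'model.', so the loader strips that prefix before loading into the bare UNet; `k[6:]` drops exactly len('model.') characters. A tiny sketch of the renaming (the keys are illustrative):

```python
import torch

state_dict = {
    'model.conv.weight': torch.zeros(1),  # as saved from the wrapper
    'model.conv.bias': torch.zeros(1),
}
if all(k.startswith('model.') for k in state_dict):
    state_dict = {k[6:]: v for k, v in state_dict.items()}  # 6 == len('model.')
print(list(state_dict))  # ['conv.weight', 'conv.bias']
```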
@@ -222,32 +219,6 @@ def cancel_generation():
     cancel_event.set()
     return "Generation cancelled"
 
-def generate_single_image(label_str):
-    label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
-    try:
-        label_index = label_map[label_str]
-    except KeyError:
-        raise gr.Error(f"Invalid label '{label_str}'. Please select either 'Pneumonia' or 'Pneumothorax'.")
-
-    labels = torch.zeros(1, NUM_CLASSES, device=device)
-    labels[0, label_index] = 1
-
-    with torch.no_grad():
-        generated_image = sample(
-            model=loaded_model,
-            num_images=1,
-            timesteps=TIMESTEPS,
-            img_size=IMG_SIZE,
-            num_classes=NUM_CLASSES,
-            labels=labels,
-            device=device
-        )
-
-    img_np = generated_image.squeeze(0).cpu().permute(1, 2, 0).numpy()
-    img_np = np.clip(img_np, 0, 1)
-    img_pil = Image.fromarray((img_np * 255).astype(np.uint8))
-
-    return img_pil
 def generate_images(label_str, num_images, progress=gr.Progress()):
     global loaded_model
     cancel_event.clear()
@@ -260,7 +231,7 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
     if label_str not in label_map:
         raise gr.Error("Invalid condition selected")
 
-    labels = torch.zeros(num_images, NUM_CLASSES, device=device)
+    labels = torch.zeros(num_images, NUM_CLASSES, device=device, dtype=torch.float32)
     labels[:, label_map[label_str]] = 1
 
     try:
@@ -270,10 +241,11 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
         raise gr.Error("Generation was cancelled by user")
 
     with torch.no_grad():
-        images = sample(
-            model=loaded_model,
+        print(f"Generating {num_images} images for {label_str}")
+        print(f"Labels shape: {labels.shape}, device: {labels.device}")
+
+        images = loaded_model.sample(
             num_images=num_images,
-            timesteps=TIMESTEPS,
             img_size=IMG_SIZE,
             num_classes=NUM_CLASSES,
             labels=labels,
@@ -284,15 +256,21 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
     if images is None:
         return None, None
 
-    # Process all generated images
+    # Diagnostic print
+    print(f"Generated images range: {images.min().item():.3f}, {images.max().item():.3f}")
+
     processed_images = []
     for img in images:
-        img_np = img.cpu().permute(1, 2, 0).numpy()
-        img_np = np.clip(img_np, 0, 1)
-        pil_img = Image.fromarray((img_np * 255).astype(np.uint8))
+        # Fix 3: Improved image conversion
+        img_np = (img.cpu().numpy().transpose(1, 2, 0) * 255).clip(0, 255).astype(np.uint8)
+        print(f"Image range after conversion: {img_np.min()}, {img_np.max()}")
+
+        if img_np.shape[2] == 1:  # Handle grayscale if needed
+            img_np = img_np.squeeze(-1)
+        pil_img = Image.fromarray(img_np)
         processed_images.append(pil_img)
 
-    # Return both single image and gallery based on count
+    # Return appropriate outputs based on count
     if num_images == 1:
         return processed_images[0], processed_images
     else:
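Reviewer note on the conversion change: `sample` now returns tensors already in [0, 1], so the PIL path multiplies by 255 before clipping, whereas the old path clipped to [0, 1] first because its range came from the de-normalization step. A one-image sketch of the new conversion (the random tensor stands in for a sampled image):

```python
import numpy as np
import torch
from PIL import Image

img = torch.rand(3, 64, 64)  # stand-in for one sampled CHW image in [0, 1]
arr = (img.cpu().numpy().transpose(1, 2, 0) * 255).clip(0, 255).astype(np.uint8)
pil = Image.fromarray(arr)   # HWC uint8 -> RGB PIL image
print(pil.size)              # (64, 64)
```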
@@ -317,7 +295,7 @@ print("Loading model...")
 loaded_model = load_model(model_path, device)
 print("Model loaded successfully!")
 
-# Unified Gradio UI
+# Gradio UI
 with gr.Blocks(theme=gr.themes.Soft(
     primary_hue="violet",
     neutral_hue="slate",
@@ -356,7 +334,6 @@ with gr.Blocks(theme=gr.themes.Soft(
         """)
 
     with gr.Column(scale=2):
-        # Unified output display that adapts to single/batch
        with gr.Tabs():
            with gr.TabItem("Output", id="output_tab"):
                single_image = gr.Image(
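For reference, a sketch of how the reworked API is driven after this commit: `sample` is now a method on DiffusionModel, so callers drop the old `model=` and `timesteps=` arguments. This assumes app.py's definitions are in scope; the constants and weight path here are illustrative:

```python
import torch

TIMESTEPS, IMG_SIZE, NUM_CLASSES = 500, 64, 2      # assumed app.py constants
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = load_model("model.pth", device)            # path is illustrative
labels = torch.zeros(4, NUM_CLASSES, device=device, dtype=torch.float32)
labels[:, 0] = 1                                   # condition on class 0 ('Pneumonia')

with torch.no_grad():
    images = model.sample(
        num_images=4,
        img_size=IMG_SIZE,
        num_classes=NUM_CLASSES,
        labels=labels,
        device=device,
    )
# images: [4, 3, IMG_SIZE, IMG_SIZE] tensor in [0, 1], or None if generation was cancelled
```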
 