Vedansh-7 commited on
Commit
fcd2735
·
1 Parent(s): fcc13fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -46
app.py CHANGED
@@ -5,37 +5,42 @@ from PIL import Image
5
  import numpy as np
6
  import math
7
  import os
 
 
8
 
9
  # Constants
10
  IMG_SIZE = 128
11
- TIMESTEPS = 300
12
  NUM_CLASSES = 2
13
 
14
- # Device configuration
 
 
 
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
- # 1. Sinusoidal Embeddings
18
  class SinusoidalPositionEmbeddings(nn.Module):
19
  def __init__(self, dim):
20
  super().__init__()
21
  self.dim = dim
22
  half_dim = dim // 2
23
  emb = math.log(10000) / (half_dim - 1)
24
- emb = torch.exp(torch.arange(half_dim) * -emb)
25
  self.register_buffer('embeddings', emb)
26
 
27
  def forward(self, time):
28
- device = time.device
29
  embeddings = self.embeddings.to(device)
30
- embeddings = time[:, None] * embeddings[None, :]
31
  return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
32
 
33
- # 2. UNet Model (matches your original architecture exactly)
34
  class UNet(nn.Module):
35
  def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
36
  super().__init__()
37
  self.num_classes = num_classes
38
  self.label_embedding = nn.Embedding(num_classes, time_dim)
 
39
  self.time_mlp = nn.Sequential(
40
  SinusoidalPositionEmbeddings(time_dim),
41
  nn.Linear(time_dim, time_dim),
@@ -43,16 +48,16 @@ class UNet(nn.Module):
43
  nn.Linear(time_dim, time_dim)
44
  )
45
 
46
- # Encoder (matches your original channel sizes)
47
  self.inc = self.double_conv(in_channels, 64)
48
  self.down1 = self.down(64 + time_dim * 2, 128)
49
  self.down2 = self.down(128 + time_dim * 2, 256)
50
  self.down3 = self.down(256 + time_dim * 2, 512)
51
 
52
- # Bottleneck (matches your original)
53
  self.bottleneck = self.double_conv(512 + time_dim * 2, 1024)
54
 
55
- # Decoder (matches your original upsampling structure)
56
  self.up1 = nn.ConvTranspose2d(1024, 256, kernel_size=2, stride=2)
57
  self.upconv1 = self.double_conv(256 + 256 + time_dim * 2, 256)
58
 
@@ -79,7 +84,6 @@ class UNet(nn.Module):
79
  )
80
 
81
  def forward(self, x, labels, time):
82
- # Matches your original forward pass exactly
83
  label_indices = torch.argmax(labels, dim=1)
84
  label_emb = self.label_embedding(label_indices)
85
  t_emb = self.time_mlp(time)
@@ -118,15 +122,14 @@ class UNet(nn.Module):
118
 
119
  return self.outc(x)
120
 
121
- # 3. Diffusion Model (matches your original implementation)
122
  class DiffusionModel(nn.Module):
123
- def __init__(self, model, timesteps=500, time_dim=256):
124
  super().__init__()
125
  self.model = model
126
  self.timesteps = timesteps
127
  self.time_dim = time_dim
128
 
129
- # Linear beta schedule (matches your original)
130
  scale = 1000 / timesteps
131
  beta_start = scale * 0.0001
132
  beta_end = scale * 0.02
@@ -149,8 +152,7 @@ class DiffusionModel(nn.Module):
149
  return predicted_noise, noise, t
150
 
151
  @torch.no_grad()
152
- def sample(self, num_images, img_size, num_classes, labels, device):
153
- # Matches your original sampling exactly
154
  x_t = torch.randn(num_images, 3, img_size, img_size).to(device)
155
 
156
  if labels.ndim == 1:
@@ -161,6 +163,9 @@ class DiffusionModel(nn.Module):
161
  labels = labels.to(device)
162
 
163
  for t in reversed(range(self.timesteps)):
 
 
 
164
  t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
165
  predicted_noise = self.model(x_t, labels, t_tensor)
166
 
@@ -177,10 +182,13 @@ class DiffusionModel(nn.Module):
177
  noise = torch.zeros_like(x_t)
178
 
179
  x_t = mean + torch.sqrt(variance) * noise
 
 
 
180
 
181
  x_0 = torch.clamp(x_t, -1., 1.)
182
 
183
- # Normalization matching your original code
184
  mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
185
  std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
186
  x_0 = std * x_0 + mean
@@ -188,7 +196,6 @@ class DiffusionModel(nn.Module):
188
 
189
  return x_0
190
 
191
- # 4. Model Loading (with improved error handling)
192
  def load_model(model_path, device):
193
  unet_model = UNet(num_classes=NUM_CLASSES).to(device)
194
  diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
@@ -228,37 +235,164 @@ def load_model(model_path, device):
228
  diffusion_model.eval()
229
  return diffusion_model
230
 
231
- # 5. Gradio Interface (matches your original)
232
- def generate_image(label_str):
 
 
 
 
 
 
 
 
 
233
  label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
234
  if label_str not in label_map:
235
- raise gr.Error("Invalid label selected.")
236
-
237
- label_index = label_map[label_str]
238
- labels_to_generate = torch.zeros(1, 2).to(device)
239
- labels_to_generate[:, label_index] = 1
240
-
241
- generated_images_tensor = loaded_model.sample(
242
- 1, IMG_SIZE, NUM_CLASSES, labels_to_generate, device
243
- )
244
-
245
- img_np = generated_images_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
246
- img_pil = Image.fromarray((img_np * 255).astype(np.uint8), 'RGB')
247
- return img_pil
248
-
249
- # Main Execution
250
- if __name__ == "__main__":
251
- # Load model
252
- model_path = "model_weights.pth" # Match your filename
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  loaded_model = load_model(model_path, device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
 
255
- # Create interface
256
- iface = gr.Interface(
257
- fn=generate_image,
258
- inputs=gr.Dropdown(["Pneumonia", "Pneumothorax"], label="Select Condition"),
259
- outputs=gr.Image(type="pil", label="Generated X-ray Image"),
260
- title="CheXpert X-ray Image Generator",
261
- description="Generate synthetic chest X-ray images conditioned on selected conditions (Pneumonia or Pneumothorax) using a diffusion model."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  )
263
 
264
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import numpy as np
6
  import math
7
  import os
8
+ from threading import Event
9
+ import traceback
10
 
11
  # Constants
12
  IMG_SIZE = 128
13
+ TIMESTEPS = 300 # From second code
14
  NUM_CLASSES = 2
15
 
16
+ # Global Cancellation Flag
17
+ cancel_event = Event()
18
+
19
+ # Device Configuration
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
+ # --- Model Definitions ---
23
  class SinusoidalPositionEmbeddings(nn.Module):
24
  def __init__(self, dim):
25
  super().__init__()
26
  self.dim = dim
27
  half_dim = dim // 2
28
  emb = math.log(10000) / (half_dim - 1)
29
+ emb = torch.exp(torch.arange(half_dim) * -emb) # From second code (no dtype specified)
30
  self.register_buffer('embeddings', emb)
31
 
32
  def forward(self, time):
33
+ device = time.device # From second code
34
  embeddings = self.embeddings.to(device)
35
+ embeddings = time[:, None] * embeddings[None, :] # From second code
36
  return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
37
 
 
38
  class UNet(nn.Module):
39
  def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
40
  super().__init__()
41
  self.num_classes = num_classes
42
  self.label_embedding = nn.Embedding(num_classes, time_dim)
43
+
44
  self.time_mlp = nn.Sequential(
45
  SinusoidalPositionEmbeddings(time_dim),
46
  nn.Linear(time_dim, time_dim),
 
48
  nn.Linear(time_dim, time_dim)
49
  )
50
 
51
+ # Encoder
52
  self.inc = self.double_conv(in_channels, 64)
53
  self.down1 = self.down(64 + time_dim * 2, 128)
54
  self.down2 = self.down(128 + time_dim * 2, 256)
55
  self.down3 = self.down(256 + time_dim * 2, 512)
56
 
57
+ # Bottleneck
58
  self.bottleneck = self.double_conv(512 + time_dim * 2, 1024)
59
 
60
+ # Decoder
61
  self.up1 = nn.ConvTranspose2d(1024, 256, kernel_size=2, stride=2)
62
  self.upconv1 = self.double_conv(256 + 256 + time_dim * 2, 256)
63
 
 
84
  )
85
 
86
  def forward(self, x, labels, time):
 
87
  label_indices = torch.argmax(labels, dim=1)
88
  label_emb = self.label_embedding(label_indices)
89
  t_emb = self.time_mlp(time)
 
122
 
123
  return self.outc(x)
124
 
 
125
  class DiffusionModel(nn.Module):
126
+ def __init__(self, model, timesteps=TIMESTEPS, time_dim=256):
127
  super().__init__()
128
  self.model = model
129
  self.timesteps = timesteps
130
  self.time_dim = time_dim
131
 
132
+ # Linear beta schedule with scaling from second code
133
  scale = 1000 / timesteps
134
  beta_start = scale * 0.0001
135
  beta_end = scale * 0.02
 
152
  return predicted_noise, noise, t
153
 
154
  @torch.no_grad()
155
+ def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
 
156
  x_t = torch.randn(num_images, 3, img_size, img_size).to(device)
157
 
158
  if labels.ndim == 1:
 
163
  labels = labels.to(device)
164
 
165
  for t in reversed(range(self.timesteps)):
166
+ if cancel_event.is_set():
167
+ return None
168
+
169
  t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
170
  predicted_noise = self.model(x_t, labels, t_tensor)
171
 
 
182
  noise = torch.zeros_like(x_t)
183
 
184
  x_t = mean + torch.sqrt(variance) * noise
185
+
186
+ if progress_callback:
187
+ progress_callback((self.timesteps - t) / self.timesteps)
188
 
189
  x_0 = torch.clamp(x_t, -1., 1.)
190
 
191
+ # Normalization
192
  mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
193
  std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
194
  x_0 = std * x_0 + mean
 
196
 
197
  return x_0
198
 
 
199
  def load_model(model_path, device):
200
  unet_model = UNet(num_classes=NUM_CLASSES).to(device)
201
  diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
 
235
  diffusion_model.eval()
236
  return diffusion_model
237
 
238
+ def cancel_generation():
239
+ cancel_event.set()
240
+ return "Generation cancelled"
241
+
242
+ def generate_images(label_str, num_images, progress=gr.Progress()):
243
+ global loaded_model
244
+ cancel_event.clear()
245
+
246
+ if num_images < 1 or num_images > 10:
247
+ raise gr.Error("Number of images must be between 1 and 10")
248
+
249
  label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
250
  if label_str not in label_map:
251
+ raise gr.Error("Invalid condition selected")
252
+
253
+ labels = torch.zeros(num_images, NUM_CLASSES)
254
+ labels[:, label_map[label_str]] = 1
255
+
256
+ try:
257
+ def progress_callback(progress_val):
258
+ progress(progress_val, desc="Generating...")
259
+ if cancel_event.is_set():
260
+ raise gr.Error("Generation was cancelled by user")
261
+
262
+ with torch.no_grad():
263
+ images = loaded_model.sample(
264
+ num_images=num_images,
265
+ img_size=IMG_SIZE,
266
+ num_classes=NUM_CLASSES,
267
+ labels=labels,
268
+ device=device,
269
+ progress_callback=progress_callback
270
+ )
271
+
272
+ if images is None:
273
+ return None, None
274
+
275
+ processed_images = []
276
+ for img in images:
277
+ img_np = img.cpu().permute(1, 2, 0).numpy()
278
+ img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
279
+ pil_img = Image.fromarray(img_np)
280
+ processed_images.append(pil_img)
281
+
282
+ if num_images == 1:
283
+ return processed_images[0], processed_images
284
+ else:
285
+ return None, processed_images
286
+
287
+ except Exception as e:
288
+ traceback.print_exc()
289
+ raise gr.Error(f"Generation failed: {str(e)}")
290
+ finally:
291
+ torch.cuda.empty_cache()
292
+
293
+ # Load model
294
+ MODEL_NAME = "model_weights.pth"
295
+ model_path = MODEL_NAME
296
+ print("Loading model...")
297
+ try:
298
  loaded_model = load_model(model_path, device)
299
+ print("Model loaded successfully!")
300
+ except Exception as e:
301
+ print(f"Failed to load model: {e}")
302
+ print("Creating dummy model for demonstration")
303
+ loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES), timesteps=TIMESTEPS).to(device)
304
+
305
+ # Gradio UI (from first code)
306
+ with gr.Blocks(theme=gr.themes.Soft(
307
+ primary_hue="violet",
308
+ neutral_hue="slate",
309
+ font=[gr.themes.GoogleFont("Poppins")],
310
+ text_size="md"
311
+ )) as demo:
312
+ gr.Markdown("""
313
+ <center>
314
+ <h1>Synthetic X-ray Generator</h1>
315
+ <p><em>Generate synthetic chest X-rays conditioned on pathology</em></p>
316
+ </center>
317
+ """)
318
+
319
+ with gr.Row():
320
+ with gr.Column(scale=1):
321
+ condition = gr.Dropdown(
322
+ ["Pneumonia", "Pneumothorax"],
323
+ label="Select Condition",
324
+ value="Pneumonia",
325
+ interactive=True
326
+ )
327
+ num_images = gr.Slider(
328
+ 1, 10, value=1, step=1,
329
+ label="Number of Images",
330
+ interactive=True
331
+ )
332
+
333
+ with gr.Row():
334
+ submit_btn = gr.Button("Generate", variant="primary")
335
+ cancel_btn = gr.Button("Cancel", variant="stop")
336
+
337
+ gr.Markdown("""
338
+ <div style="text-align: center; margin-top: 10px;">
339
+ <small>Note: Generation may take several seconds per image</small>
340
+ </div>
341
+ """)
342
+
343
+ with gr.Column(scale=2):
344
+ with gr.Tabs():
345
+ with gr.TabItem("Output", id="output_tab"):
346
+ single_image = gr.Image(
347
+ label="Generated X-ray",
348
+ height=400,
349
+ visible=True
350
+ )
351
+ gallery = gr.Gallery(
352
+ label="Generated X-rays",
353
+ columns=3,
354
+ height="auto",
355
+ object_fit="contain",
356
+ visible=False
357
+ )
358
 
359
+ def update_ui_based_on_count(num_images):
360
+ if num_images == 1:
361
+ return {
362
+ single_image: gr.update(visible=True),
363
+ gallery: gr.update(visible=False)
364
+ }
365
+ else:
366
+ return {
367
+ single_image: gr.update(visible=False),
368
+ gallery: gr.update(visible=True)
369
+ }
370
+
371
+ num_images.change(
372
+ fn=update_ui_based_on_count,
373
+ inputs=num_images,
374
+ outputs=[single_image, gallery]
375
+ )
376
+
377
+ submit_btn.click(
378
+ fn=generate_images,
379
+ inputs=[condition, num_images],
380
+ outputs=[single_image, gallery]
381
  )
382
 
383
+ cancel_btn.click(
384
+ fn=cancel_generation,
385
+ outputs=None
386
+ )
387
+
388
+ demo.css = """
389
+ .gradio-container {
390
+ background: linear-gradient(135deg, #f5f7fa 0%, #e4e8f0 100%);
391
+ }
392
+ .gallery-container {
393
+ background-color: white !important;
394
+ }
395
+ """
396
+
397
+ if __name__ == "__main__":
398
+ demo.launch(server_name="0.0.0.0", server_port=7860)