Vedansh-7 commited on
Commit
d61d482
·
verified ·
1 Parent(s): e604841

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -36
app.py CHANGED
@@ -5,16 +5,21 @@ from PIL import Image
5
  import numpy as np
6
  import math
7
  import os
 
 
8
 
9
- # Constants (update these to match your training config)
10
  IMG_SIZE = 128
11
  TIMESTEPS = 300
12
  NUM_CLASSES = 2
13
 
14
- # Define the device
 
 
 
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
- # Define the SinusoidalPositionEmbeddings class
18
  class SinusoidalPositionEmbeddings(nn.Module):
19
  def __init__(self, dim):
20
  super().__init__()
@@ -34,11 +39,9 @@ class SinusoidalPositionEmbeddings(nn.Module):
34
  output = torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
35
  return output
36
 
37
- # Define the UNet class
38
  class UNet(nn.Module):
39
  def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
40
  super().__init__()
41
-
42
  self.num_classes = num_classes
43
  self.label_embedding = nn.Embedding(num_classes, time_dim)
44
 
@@ -121,7 +124,6 @@ class UNet(nn.Module):
121
  output = self.outc(x)
122
  return output
123
 
124
- # Define the DiffusionModel class
125
  class DiffusionModel(nn.Module):
126
  def __init__(self, model, timesteps=500, time_dim=256):
127
  super().__init__()
@@ -154,7 +156,7 @@ class DiffusionModel(nn.Module):
154
  return predicted_noise, noise, t
155
 
156
  @torch.no_grad()
157
- def sample(model, num_images, timesteps, img_size, num_classes, labels, device):
158
  x_t = torch.randn(num_images, 3, img_size, img_size).to(device)
159
 
160
  if labels.ndim == 1:
@@ -165,6 +167,9 @@ def sample(model, num_images, timesteps, img_size, num_classes, labels, device):
165
  labels = labels.to(device)
166
 
167
  for t in reversed(range(timesteps)):
 
 
 
168
  t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
169
 
170
  predicted_noise = model.model(x_t, labels, t_tensor)
@@ -182,6 +187,9 @@ def sample(model, num_images, timesteps, img_size, num_classes, labels, device):
182
  noise = torch.zeros_like(x_t)
183
 
184
  x_t = mean + torch.sqrt(variance) * noise
 
 
 
185
 
186
  x_0 = torch.clamp(x_t, -1., 1.)
187
 
@@ -192,14 +200,12 @@ def sample(model, num_images, timesteps, img_size, num_classes, labels, device):
192
 
193
  return x_0
194
 
195
- # Load the trained model with improved error handling
196
  def load_model(model_path, device):
197
  unet_model = UNet(num_classes=NUM_CLASSES).to(device)
198
  diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
199
 
200
  try:
201
  checkpoint = torch.load(model_path, map_location=device)
202
- # Handle both full model and state_dict loading
203
  if 'model_state_dict' in checkpoint:
204
  diffusion_model.model.load_state_dict(checkpoint['model_state_dict'])
205
  else:
@@ -212,19 +218,20 @@ def load_model(model_path, device):
212
  diffusion_model.eval()
213
  return diffusion_model
214
 
215
- # Improved image generation function
216
- def generate_image(label_str):
 
 
 
217
  label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
218
  try:
219
  label_index = label_map[label_str]
220
  except KeyError:
221
  raise gr.Error(f"Invalid label '{label_str}'. Please select either 'Pneumonia' or 'Pneumothorax'.")
222
 
223
- # Create one-hot encoded label
224
  labels = torch.zeros(1, NUM_CLASSES, device=device)
225
  labels[0, label_index] = 1
226
 
227
- # Generate image
228
  with torch.no_grad():
229
  generated_image = sample(
230
  model=loaded_model,
@@ -236,39 +243,155 @@ def generate_image(label_str):
236
  device=device
237
  )
238
 
239
- # Convert to PIL Image
240
  img_np = generated_image.squeeze(0).cpu().permute(1, 2, 0).numpy()
241
- img_np = np.clip(img_np, 0, 1) # Ensure proper range
242
  img_pil = Image.fromarray((img_np * 255).astype(np.uint8))
243
 
244
  return img_pil
245
 
246
- # Model paths (update these for your deployment)
247
- MODEL_DIR = "models"
248
- MODEL_NAME = "diffusion_unet_xray.pth" # Update with your actual filename
249
- model_path = os.path.join(MODEL_DIR, MODEL_NAME)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
  # Load model
 
 
 
252
  print("Loading model...")
253
  loaded_model = load_model(model_path, device)
254
  print("Model loaded successfully!")
255
 
256
- # Gradio interface
257
- iface = gr.Interface(
258
- fn=generate_image,
259
- inputs=gr.Dropdown(
260
- choices=["Pneumonia", "Pneumothorax"],
261
- label="Select Condition",
262
- value="Pneumonia" # Default value
263
- ),
264
- outputs=gr.Image(
265
- type="pil",
266
- label="Generated X-ray Image"
267
- ),
268
- title="Medical X-ray Image Generator",
269
- description="Generate synthetic chest X-ray images using a diffusion model. Select a condition to generate.",
270
- examples=[["Pneumonia"], ["Pneumothorax"]]
271
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
  if __name__ == "__main__":
274
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
5
  import numpy as np
6
  import math
7
  import os
8
+ from threading import Event
9
+ import traceback
10
 
11
+ # Constants
12
  IMG_SIZE = 128
13
  TIMESTEPS = 300
14
  NUM_CLASSES = 2
15
 
16
+ # Global Cancellation Flag
17
+ cancel_event = Event()
18
+
19
+ # Device Configuration
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
+ # --- Model Definitions (from second file) ---
23
  class SinusoidalPositionEmbeddings(nn.Module):
24
  def __init__(self, dim):
25
  super().__init__()
 
39
  output = torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
40
  return output
41
 
 
42
  class UNet(nn.Module):
43
  def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
44
  super().__init__()
 
45
  self.num_classes = num_classes
46
  self.label_embedding = nn.Embedding(num_classes, time_dim)
47
 
 
124
  output = self.outc(x)
125
  return output
126
 
 
127
  class DiffusionModel(nn.Module):
128
  def __init__(self, model, timesteps=500, time_dim=256):
129
  super().__init__()
 
156
  return predicted_noise, noise, t
157
 
158
  @torch.no_grad()
159
+ def sample(model, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
160
  x_t = torch.randn(num_images, 3, img_size, img_size).to(device)
161
 
162
  if labels.ndim == 1:
 
167
  labels = labels.to(device)
168
 
169
  for t in reversed(range(timesteps)):
170
+ if cancel_event.is_set():
171
+ return None
172
+
173
  t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
174
 
175
  predicted_noise = model.model(x_t, labels, t_tensor)
 
187
  noise = torch.zeros_like(x_t)
188
 
189
  x_t = mean + torch.sqrt(variance) * noise
190
+
191
+ if progress_callback:
192
+ progress_callback((timesteps - t) / timesteps)
193
 
194
  x_0 = torch.clamp(x_t, -1., 1.)
195
 
 
200
 
201
  return x_0
202
 
 
203
  def load_model(model_path, device):
204
  unet_model = UNet(num_classes=NUM_CLASSES).to(device)
205
  diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
206
 
207
  try:
208
  checkpoint = torch.load(model_path, map_location=device)
 
209
  if 'model_state_dict' in checkpoint:
210
  diffusion_model.model.load_state_dict(checkpoint['model_state_dict'])
211
  else:
 
218
  diffusion_model.eval()
219
  return diffusion_model
220
 
221
+ def cancel_generation():
222
+ cancel_event.set()
223
+ return "Generation cancelled"
224
+
225
+ def generate_single_image(label_str):
226
  label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
227
  try:
228
  label_index = label_map[label_str]
229
  except KeyError:
230
  raise gr.Error(f"Invalid label '{label_str}'. Please select either 'Pneumonia' or 'Pneumothorax'.")
231
 
 
232
  labels = torch.zeros(1, NUM_CLASSES, device=device)
233
  labels[0, label_index] = 1
234
 
 
235
  with torch.no_grad():
236
  generated_image = sample(
237
  model=loaded_model,
 
243
  device=device
244
  )
245
 
 
246
  img_np = generated_image.squeeze(0).cpu().permute(1, 2, 0).numpy()
247
+ img_np = np.clip(img_np, 0, 1)
248
  img_pil = Image.fromarray((img_np * 255).astype(np.uint8))
249
 
250
  return img_pil
251
 
252
+ def generate_batch_images(label_str, num_images, progress=gr.Progress()):
253
+ global loaded_model
254
+ cancel_event.clear()
255
+
256
+ if num_images < 1 or num_images > 10:
257
+ raise gr.Error("Number of images must be between 1 and 10")
258
+
259
+ label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
260
+ if label_str not in label_map:
261
+ raise gr.Error("Invalid condition selected")
262
+
263
+ labels = torch.zeros(num_images, NUM_CLASSES, device=device)
264
+ labels[:, label_map[label_str]] = 1
265
+
266
+ try:
267
+ def progress_callback(progress_val):
268
+ progress(progress_val, desc="Generating...")
269
+ if cancel_event.is_set():
270
+ raise gr.Error("Generation was cancelled by user")
271
+
272
+ with torch.no_grad():
273
+ images = sample(
274
+ model=loaded_model,
275
+ num_images=num_images,
276
+ timesteps=TIMESTEPS,
277
+ img_size=IMG_SIZE,
278
+ num_classes=NUM_CLASSES,
279
+ labels=labels,
280
+ device=device,
281
+ progress_callback=progress_callback
282
+ )
283
+
284
+ if images is None:
285
+ return None
286
+
287
+ processed_images = []
288
+ for img in images:
289
+ img_np = img.cpu().permute(1, 2, 0).numpy()
290
+ img_np = np.clip(img_np, 0, 1)
291
+ pil_img = Image.fromarray((img_np * 255).astype(np.uint8))
292
+ processed_images.append(pil_img)
293
+
294
+ return processed_images
295
+
296
+ except torch.cuda.OutOfMemoryError:
297
+ torch.cuda.empty_cache()
298
+ raise gr.Error("Out of GPU memory - try generating fewer images")
299
+ except Exception as e:
300
+ traceback.print_exc()
301
+ if str(e) != "Generation was cancelled by user":
302
+ raise gr.Error(f"Generation failed: {str(e)}")
303
+ return None
304
+ finally:
305
+ torch.cuda.empty_cache()
306
 
307
  # Load model
308
+ MODEL_DIR = "models"
309
+ MODEL_NAME = "diffusion_unet_xray.pth"
310
+ model_path = os.path.join(MODEL_DIR, MODEL_NAME)
311
  print("Loading model...")
312
  loaded_model = load_model(model_path, device)
313
  print("Model loaded successfully!")
314
 
315
+ # --- Gradio UI (from first file with modifications) ---
316
+ with gr.Blocks(theme=gr.themes.Soft(
317
+ primary_hue="violet",
318
+ neutral_hue="slate",
319
+ font=[gr.themes.GoogleFont("Poppins")],
320
+ text_size="md"
321
+ )) as demo:
322
+ gr.Markdown("""
323
+ <center>
324
+ <h1>Synthetic X-ray Generator</h1>
325
+ <p><em>Generate synthetic chest X-rays conditioned on pathology</em></p>
326
+ </center>
327
+ """)
328
+
329
+ with gr.Row():
330
+ with gr.Column(scale=1):
331
+ condition = gr.Dropdown(
332
+ ["Pneumonia", "Pneumothorax"],
333
+ label="Select Condition",
334
+ value="Pneumonia",
335
+ interactive=True
336
+ )
337
+ num_images = gr.Slider(
338
+ 1, 10, value=1, step=1,
339
+ label="Number of Images",
340
+ interactive=True
341
+ )
342
+
343
+ with gr.Row():
344
+ submit_btn = gr.Button("Generate", variant="primary")
345
+ cancel_btn = gr.Button("Cancel", variant="stop")
346
+
347
+ gr.Markdown("""
348
+ <div style="text-align: center; margin-top: 10px;">
349
+ <small>Note: Generation may take several seconds per image</small>
350
+ </div>
351
+ """)
352
+
353
+ with gr.Column(scale=2):
354
+ with gr.Tab("Single Image"):
355
+ single_image = gr.Image(
356
+ type="pil",
357
+ label="Generated X-ray",
358
+ height=400
359
+ )
360
+ with gr.Tab("Batch Images"):
361
+ gallery = gr.Gallery(
362
+ label="Generated X-rays",
363
+ columns=3,
364
+ height="auto",
365
+ object_fit="contain"
366
+ )
367
+
368
+ # Single image generation
369
+ condition.change(
370
+ fn=generate_single_image,
371
+ inputs=condition,
372
+ outputs=single_image
373
+ )
374
+
375
+ # Batch image generation
376
+ submit_btn.click(
377
+ fn=generate_batch_images,
378
+ inputs=[condition, num_images],
379
+ outputs=gallery
380
+ )
381
+
382
+ cancel_btn.click(
383
+ fn=cancel_generation,
384
+ outputs=None
385
+ )
386
+
387
+ demo.css = """
388
+ .gradio-container {
389
+ background: linear-gradient(135deg, #f5f7fa 0%, #e4e8f0 100%);
390
+ }
391
+ .gallery-container {
392
+ background-color: white !important;
393
+ }
394
+ """
395
 
396
  if __name__ == "__main__":
397
+ demo.launch(server_name="0.0.0.0", server_port=7860)