Vedansh-7 committed on
Commit
190a6d4
·
verified ·
1 Parent(s): 49fdbe4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -208
app.py CHANGED
@@ -5,41 +5,37 @@ from PIL import Image
5
  import numpy as np
6
  import math
7
  import os
8
- from threading import Event
9
- import traceback
10
 
11
  # Constants
12
  IMG_SIZE = 128
13
- TIMESTEPS = 500
14
  NUM_CLASSES = 2
15
 
16
- # Global Cancellation Flag
17
- cancel_event = Event()
18
-
19
- # Device Configuration
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
- # --- Model Definitions ---
23
  class SinusoidalPositionEmbeddings(nn.Module):
24
  def __init__(self, dim):
25
  super().__init__()
26
  self.dim = dim
27
  half_dim = dim // 2
28
  emb = math.log(10000) / (half_dim - 1)
29
- emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
30
  self.register_buffer('embeddings', emb)
31
 
32
  def forward(self, time):
33
- embeddings = self.embeddings.to(time.device)
34
- embeddings = time.float()[:, None] * embeddings[None, :]
 
35
  return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
36
 
 
37
  class UNet(nn.Module):
38
  def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
39
  super().__init__()
40
  self.num_classes = num_classes
41
  self.label_embedding = nn.Embedding(num_classes, time_dim)
42
-
43
  self.time_mlp = nn.Sequential(
44
  SinusoidalPositionEmbeddings(time_dim),
45
  nn.Linear(time_dim, time_dim),
@@ -47,13 +43,16 @@ class UNet(nn.Module):
47
  nn.Linear(time_dim, time_dim)
48
  )
49
 
 
50
  self.inc = self.double_conv(in_channels, 64)
51
  self.down1 = self.down(64 + time_dim * 2, 128)
52
  self.down2 = self.down(128 + time_dim * 2, 256)
53
  self.down3 = self.down(256 + time_dim * 2, 512)
54
 
 
55
  self.bottleneck = self.double_conv(512 + time_dim * 2, 1024)
56
 
 
57
  self.up1 = nn.ConvTranspose2d(1024, 256, kernel_size=2, stride=2)
58
  self.upconv1 = self.double_conv(256 + 256 + time_dim * 2, 256)
59
 
@@ -80,6 +79,7 @@ class UNet(nn.Module):
80
  )
81
 
82
  def forward(self, x, labels, time):
 
83
  label_indices = torch.argmax(labels, dim=1)
84
  label_emb = self.label_embedding(label_indices)
85
  t_emb = self.time_mlp(time)
@@ -116,24 +116,41 @@ class UNet(nn.Module):
116
  x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
117
  x = self.upconv3(x)
118
 
119
- output = self.outc(x)
120
- return output
121
 
 
122
  class DiffusionModel(nn.Module):
123
- def __init__(self, model, timesteps=TIMESTEPS):
124
  super().__init__()
125
  self.model = model
126
  self.timesteps = timesteps
127
-
128
- beta_start = 0.0001
129
- beta_end = 0.02
130
- self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
 
 
 
131
  self.alphas = 1. - self.betas
132
- self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  @torch.no_grad()
135
- def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
136
- """Your exact sampling function from Colab"""
137
  x_t = torch.randn(num_images, 3, img_size, img_size).to(device)
138
 
139
  if labels.ndim == 1:
@@ -144,11 +161,7 @@ class DiffusionModel(nn.Module):
144
  labels = labels.to(device)
145
 
146
  for t in reversed(range(self.timesteps)):
147
- if cancel_event.is_set():
148
- return None
149
-
150
- t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float) # Pass time as float
151
-
152
  predicted_noise = self.model(x_t, labels, t_tensor)
153
 
154
  beta_t = self.betas[t].to(device)
@@ -164,12 +177,10 @@ class DiffusionModel(nn.Module):
164
  noise = torch.zeros_like(x_t)
165
 
166
  x_t = mean + torch.sqrt(variance) * noise
167
-
168
- if progress_callback:
169
- progress_callback((self.timesteps - t) / self.timesteps)
170
 
171
  x_0 = torch.clamp(x_t, -1., 1.)
172
 
 
173
  mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
174
  std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
175
  x_0 = std * x_0 + mean
@@ -177,197 +188,74 @@ class DiffusionModel(nn.Module):
177
 
178
  return x_0
179
 
 
180
  def load_model(model_path, device):
181
- unet = UNet(num_classes=NUM_CLASSES).to(device)
182
- diffusion_model = DiffusionModel(unet).to(device)
183
-
184
  if os.path.exists(model_path):
185
- try:
186
- checkpoint = torch.load(model_path, map_location=device)
187
-
188
- if 'model_state_dict' in checkpoint:
189
- state_dict = checkpoint['model_state_dict']
190
- else:
191
- state_dict = checkpoint
 
192
 
193
- if all(k.startswith('model.') for k in state_dict.keys()):
194
- state_dict = {k[6:]: v for k, v in state_dict.items()}
195
 
196
- unet.load_state_dict(state_dict, strict=False)
197
- print("Model loaded successfully")
198
 
199
- test_input = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
200
- test_labels = torch.zeros(1, NUM_CLASSES).to(device)
201
- test_time = torch.tensor([1]).to(device)
202
- output = unet(test_input, test_labels, test_time)
203
- print(f"Model test output shape: {output.shape}")
 
 
 
204
 
205
- except Exception as e:
206
- traceback.print_exc()
207
- raise ValueError(f"Error loading model: {str(e)}")
208
  else:
209
- raise FileNotFoundError(f"Model weights not found at {model_path}")
210
-
 
211
  diffusion_model.eval()
212
  return diffusion_model
213
 
214
- MODEL_NAME = "model_weights.pth"
215
- model_path = MODEL_NAME
216
- print("Loading model...")
217
- try:
218
- loaded_model = load_model(model_path, device)
219
- print("Model loaded successfully!")
220
- except Exception as e:
221
- print(f"Failed to load model: {e}")
222
- print("Creating dummy model for demonstration")
223
- loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES)).to(device)
224
-
225
- def cancel_generation():
226
- cancel_event.set()
227
- return "Generation cancelled"
228
-
229
- def generate_images(label_str, num_images, progress=gr.Progress()):
230
- global loaded_model
231
- cancel_event.clear()
232
-
233
- if num_images < 1 or num_images > 10:
234
- raise gr.Error("Number of images must be between 1 and 10")
235
-
236
  label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
237
  if label_str not in label_map:
238
- raise gr.Error("Invalid condition selected")
239
-
240
- labels = torch.zeros(num_images, NUM_CLASSES)
241
- labels[:, label_map[label_str]] = 1
242
-
243
- try:
244
- def progress_callback(progress_val):
245
- progress(progress_val, desc="Generating...")
246
- if cancel_event.is_set():
247
- raise gr.Error("Generation was cancelled by user")
248
-
249
- with torch.no_grad():
250
- images = loaded_model.sample(
251
- num_images=num_images,
252
- img_size=IMG_SIZE,
253
- num_classes=NUM_CLASSES,
254
- labels=labels,
255
- device=device,
256
- progress_callback=progress_callback
257
- )
258
-
259
- if images is None:
260
- return None, None
261
-
262
- processed_images = []
263
- for img in images:
264
- img_np = img.cpu().permute(1, 2, 0).numpy()
265
- img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
266
- pil_img = Image.fromarray(img_np)
267
- processed_images.append(pil_img)
268
-
269
- if num_images == 1:
270
- return processed_images[0], processed_images
271
- else:
272
- return None, processed_images
273
-
274
- except Exception as e:
275
- traceback.print_exc()
276
- raise gr.Error(f"Generation failed: {str(e)}")
277
- finally:
278
- torch.cuda.empty_cache()
279
-
280
- # Gradio UI
281
- with gr.Blocks(theme=gr.themes.Soft(
282
- primary_hue="violet",
283
- neutral_hue="slate",
284
- font=[gr.themes.GoogleFont("Poppins")],
285
- text_size="md"
286
- )) as demo:
287
- gr.Markdown("""
288
- <center>
289
- <h1>Synthetic X-ray Generator</h1>
290
- <p><em>Generate synthetic chest X-rays conditioned on pathology</em></p>
291
- </center>
292
- """)
293
-
294
- with gr.Row():
295
- with gr.Column(scale=1):
296
- condition = gr.Dropdown(
297
- ["Pneumonia", "Pneumothorax"],
298
- label="Select Condition",
299
- value="Pneumonia",
300
- interactive=True
301
- )
302
- num_images = gr.Slider(
303
- 1, 10, value=1, step=1,
304
- label="Number of Images",
305
- interactive=True
306
- )
307
-
308
- with gr.Row():
309
- submit_btn = gr.Button("Generate", variant="primary")
310
- cancel_btn = gr.Button("Cancel", variant="stop")
311
-
312
- gr.Markdown("""
313
- <div style="text-align: center; margin-top: 10px;">
314
- <small>Note: Generation may take several seconds per image</small>
315
- </div>
316
- """)
317
-
318
- with gr.Column(scale=2):
319
- with gr.Tabs():
320
- with gr.TabItem("Output", id="output_tab"):
321
- single_image = gr.Image(
322
- label="Generated X-ray",
323
- height=400,
324
- visible=True
325
- )
326
- gallery = gr.Gallery(
327
- label="Generated X-rays",
328
- columns=3,
329
- height="auto",
330
- object_fit="contain",
331
- visible=False
332
- )
333
-
334
- def update_ui_based_on_count(num_images):
335
- if num_images == 1:
336
- return {
337
- single_image: gr.update(visible=True),
338
- gallery: gr.update(visible=False)
339
- }
340
- else:
341
- return {
342
- single_image: gr.update(visible=False),
343
- gallery: gr.update(visible=True)
344
- }
345
-
346
- num_images.change(
347
- fn=update_ui_based_on_count,
348
- inputs=num_images,
349
- outputs=[single_image, gallery]
350
- )
351
-
352
- submit_btn.click(
353
- fn=generate_images,
354
- inputs=[condition, num_images],
355
- outputs=[single_image, gallery]
356
- )
357
-
358
- cancel_btn.click(
359
- fn=cancel_generation,
360
- outputs=None
361
  )
362
 
363
- demo.css = """
364
- .gradio-container {
365
- background: linear-gradient(135deg, #f5f7fa 0%, #e4e8f0 100%);
366
- }
367
- .gallery-container {
368
- background-color: white !important;
369
- }
370
- """
371
 
 
372
  if __name__ == "__main__":
373
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import numpy as np
6
  import math
7
  import os
 
 
8
 
9
# Constants
IMG_SIZE = 128     # height and width passed to the sampler as ``img_size``
TIMESTEPS = 300    # number of diffusion steps handed to DiffusionModel in load_model
NUM_CLASSES = 2    # number of condition labels handed to the UNet

# Device configuration: run on the GPU when CUDA is available, else on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
16
 
17
+ # 1. Sinusoidal Embeddings
18
class SinusoidalPositionEmbeddings(nn.Module):
    """Encode a batch of scalar timesteps as sinusoidal feature vectors.

    A fixed table of ``dim // 2`` geometrically spaced frequencies is stored
    in the ``embeddings`` buffer. ``forward`` multiplies each timestep by
    every frequency and returns the sines followed by the cosines, giving a
    tensor of shape ``(batch, 2 * (dim // 2))``.

    Args:
        dim: Requested embedding width. Odd values yield ``dim - 1`` features
            because the width is split into a sine half and a cosine half.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = dim // 2
        # With a single frequency (dim of 2 or 3) the original denominator
        # ``half_dim - 1`` is zero. That lone frequency is exp(0) == 1 no
        # matter what the step size is, so any non-zero denominator is
        # equivalent; use 1 to avoid the ZeroDivisionError.
        denominator = half_dim - 1 if half_dim != 1 else 1
        log_step = math.log(10000) / denominator
        frequencies = torch.exp(
            torch.arange(half_dim, dtype=torch.float32) * -log_step
        )
        # Buffer name is kept as 'embeddings' so existing checkpoints that
        # contain this key still match.
        self.register_buffer('embeddings', frequencies)

    def forward(self, time):
        """Return the sinusoidal embedding for a 1-D tensor of timesteps.

        Args:
            time: 1-D tensor of timesteps, one per batch element. Any numeric
                dtype is accepted; it is cast to the dtype of the frequency
                table so the result matches the layers that consume it.

        Returns:
            Tensor of shape ``(len(time), 2 * (dim // 2))``: sines first,
            cosines second.
        """
        frequencies = self.embeddings.to(time.device)
        # Cast explicitly: a float64 ``time`` would otherwise promote the
        # product to float64 and fail in a following float32 nn.Linear.
        angles = time.to(frequencies.dtype)[:, None] * frequencies[None, :]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)
32
 
33
+ # 2. UNet Model (matches your original architecture exactly)
34
  class UNet(nn.Module):
35
  def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
36
  super().__init__()
37
  self.num_classes = num_classes
38
  self.label_embedding = nn.Embedding(num_classes, time_dim)
 
39
  self.time_mlp = nn.Sequential(
40
  SinusoidalPositionEmbeddings(time_dim),
41
  nn.Linear(time_dim, time_dim),
 
43
  nn.Linear(time_dim, time_dim)
44
  )
45
 
46
+ # Encoder (matches your original channel sizes)
47
  self.inc = self.double_conv(in_channels, 64)
48
  self.down1 = self.down(64 + time_dim * 2, 128)
49
  self.down2 = self.down(128 + time_dim * 2, 256)
50
  self.down3 = self.down(256 + time_dim * 2, 512)
51
 
52
+ # Bottleneck (matches your original)
53
  self.bottleneck = self.double_conv(512 + time_dim * 2, 1024)
54
 
55
+ # Decoder (matches your original upsampling structure)
56
  self.up1 = nn.ConvTranspose2d(1024, 256, kernel_size=2, stride=2)
57
  self.upconv1 = self.double_conv(256 + 256 + time_dim * 2, 256)
58
 
 
79
  )
80
 
81
  def forward(self, x, labels, time):
82
+ # Matches your original forward pass exactly
83
  label_indices = torch.argmax(labels, dim=1)
84
  label_emb = self.label_embedding(label_indices)
85
  t_emb = self.time_mlp(time)
 
116
  x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
117
  x = self.upconv3(x)
118
 
119
+ return self.outc(x)
 
120
 
121
+ # 3. Diffusion Model (matches your original implementation)
122
  class DiffusionModel(nn.Module):
123
+ def __init__(self, model, timesteps=500, time_dim=256):
124
  super().__init__()
125
  self.model = model
126
  self.timesteps = timesteps
127
+ self.time_dim = time_dim
128
+
129
+ # Linear beta schedule (matches your original)
130
+ scale = 1000 / timesteps
131
+ beta_start = scale * 0.0001
132
+ beta_end = scale * 0.02
133
+ self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
134
  self.alphas = 1. - self.betas
135
+ self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0).float())
136
+
137
+ def forward_diffusion(self, x_0, t, noise):
138
+ x_0 = x_0.float()
139
+ noise = noise.float()
140
+ alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1, 1)
141
+ x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1. - alpha_bar_t) * noise
142
+ return x_t
143
+
144
+ def forward(self, x_0, labels):
145
+ t = torch.randint(0, self.timesteps, (x_0.shape[0],), device=x_0.device).long()
146
+ noise = torch.randn_like(x_0)
147
+ x_t = self.forward_diffusion(x_0, t, noise)
148
+ predicted_noise = self.model(x_t, labels, t.float())
149
+ return predicted_noise, noise, t
150
 
151
  @torch.no_grad()
152
+ def sample(self, num_images, img_size, num_classes, labels, device):
153
+ # Matches your original sampling exactly
154
  x_t = torch.randn(num_images, 3, img_size, img_size).to(device)
155
 
156
  if labels.ndim == 1:
 
161
  labels = labels.to(device)
162
 
163
  for t in reversed(range(self.timesteps)):
164
+ t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
 
 
 
 
165
  predicted_noise = self.model(x_t, labels, t_tensor)
166
 
167
  beta_t = self.betas[t].to(device)
 
177
  noise = torch.zeros_like(x_t)
178
 
179
  x_t = mean + torch.sqrt(variance) * noise
 
 
 
180
 
181
  x_0 = torch.clamp(x_t, -1., 1.)
182
 
183
+ # Normalization matching your original code
184
  mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
185
  std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
186
  x_0 = std * x_0 + mean
 
188
 
189
  return x_0
190
 
191
+ # 4. Model Loading (with improved error handling)
192
def load_model(model_path, device):
    """Build the diffusion model and load UNet weights from ``model_path``.

    The checkpoint may be either a training checkpoint holding the weights
    under ``'model_state_dict'`` or a bare state dict. In both cases the
    weights may have been saved from the ``DiffusionModel`` wrapper (keys
    prefixed with ``'model.'`` plus an ``'alpha_bars'`` buffer) or from the
    UNet directly (no prefix). Only UNet weights are loaded; the saved
    ``alpha_bars`` is ignored so the noise schedule always comes from
    ``TIMESTEPS``.

    Args:
        model_path: Path to the weights file.
        device: Device the model is created on and the checkpoint mapped to.

    Returns:
        The ``DiffusionModel`` in eval mode. If ``model_path`` does not exist
        the model keeps its randomly initialized weights and a notice is
        printed.
    """
    unet_model = UNet(num_classes=NUM_CLASSES).to(device)
    # The wrapper holds a reference to ``unet_model``, so weights loaded into
    # the UNet below are visible through it without rebuilding the wrapper.
    diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)

    if os.path.exists(model_path):
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider weights_only=True where supported).
        checkpoint = torch.load(model_path, map_location=device)

        if 'model_state_dict' in checkpoint:
            raw_state = checkpoint['model_state_dict']
        else:
            raw_state = checkpoint

        prefix = 'model.'
        unet_state = {
            key[len(prefix):]: value
            for key, value in raw_state.items()
            if key.startswith(prefix)
        }
        if not unet_state:
            # No wrapper prefix: treat the keys as bare UNet keys, dropping
            # the diffusion-schedule buffer if it happens to be present.
            unet_state = {
                key: value
                for key, value in raw_state.items()
                if key != 'alpha_bars'
            }

        # strict=False reports mismatched keys instead of raising, so a
        # partially matching checkpoint can still be used.
        missing, unexpected = unet_model.load_state_dict(unet_state, strict=False)
        print(f"Loaded UNet weights. Missing keys: {missing}. Unexpected keys: {unexpected}")

        if len(missing) == len(unet_model.state_dict()):
            # Nothing matched: make the silent fallback to random weights visible.
            print("Warning: no checkpoint keys matched the UNet; weights remain randomly initialized")

        print(f"Model successfully loaded from {model_path}")
    else:
        print(f"Weights file not found at {model_path}")
        print("Using randomly initialized weights")

    diffusion_model.eval()
    return diffusion_model
227
 
228
+ # 5. Gradio Interface (matches your original)
229
def generate_image(label_str):
    """Generate one synthetic X-ray image for the selected condition.

    Args:
        label_str: Condition name; must be ``'Pneumonia'`` or
            ``'Pneumothorax'``.

    Returns:
        A PIL image built from the single sample produced by
        ``loaded_model.sample``.

    Raises:
        gr.Error: If ``label_str`` is not a known condition (this includes
            ``None``, e.g. when nothing was selected).
    """
    label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
    if label_str not in label_map:
        raise gr.Error("Invalid label selected.")

    # One-hot label row for a single image; width follows NUM_CLASSES rather
    # than a hard-coded 2 so it stays in step with the model configuration.
    labels_to_generate = torch.zeros(1, NUM_CLASSES, device=device)
    labels_to_generate[:, label_map[label_str]] = 1

    generated_images_tensor = loaded_model.sample(
        1, IMG_SIZE, NUM_CLASSES, labels_to_generate, device
    )

    # (1, C, H, W) -> (H, W, C) array for PIL.
    img_np = generated_images_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Clip before the uint8 cast: values outside [0, 255] would otherwise
    # wrap around and show up as speckle artifacts.
    img_uint8 = np.clip(img_np * 255, 0, 255).astype(np.uint8)
    # An (H, W, 3) uint8 array is interpreted as RGB without an explicit mode.
    return Image.fromarray(img_uint8)
 
 
 
 
 
245
 
246
+ # Main Execution
247
if __name__ == "__main__":
    # Load the weights once at start-up; generate_image reads the
    # module-level ``loaded_model`` defined here.
    model_path = "model_weights.pth"  # Match your filename
    loaded_model = load_model(model_path, device)

    # Build the UI pieces first, then wire them into a single Interface.
    condition_input = gr.Dropdown(
        ["Pneumonia", "Pneumothorax"], label="Select Condition"
    )
    image_output = gr.Image(type="pil", label="Generated X-ray Image")

    iface = gr.Interface(
        fn=generate_image,
        inputs=condition_input,
        outputs=image_output,
        title="CheXpert X-ray Image Generator",
        description="Generate synthetic chest X-ray images conditioned on selected conditions (Pneumonia or Pneumothorax) using a diffusion model.",
    )

    iface.launch()