Vedansh-7 committed on
Commit 7531575 · 1 Parent(s): 50426d9

Delete app.py

Files changed (1)
  1. app.py +0 -367
app.py DELETED
@@ -1,367 +0,0 @@
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
import numpy as np
import math
import os
from threading import Event
import traceback

# --- Constants ---
IMG_SIZE = 128
TRAINING_TIMESTEPS = 300
INFERENCE_TIMESTEPS = 300
NUM_CLASSES = 2

# --- Global Cancellation Flag ---
cancel_event = Event()

# --- Device Configuration ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Model Definitions ---
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
        self.register_buffer('embeddings', emb)

    def forward(self, time):
        device = time.device
        embeddings = self.embeddings.to(device)
        embeddings = time.float()[:, None] * embeddings[None, :]
        return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)

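# Shape sketch (illustrative, not part of the app's flow): the module maps a
# (B,) tensor of timesteps to a (B, dim) embedding, sine half then cosine half:
#   emb = SinusoidalPositionEmbeddings(256)
#   emb(torch.tensor([0, 150, 299])).shape   # -> torch.Size([3, 256])
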
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
        super().__init__()
        self.num_classes = num_classes
        self.label_embedding = nn.Embedding(num_classes, time_dim)

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.ReLU(),
            nn.Linear(time_dim, time_dim)
        )

        # Each stage consumes `time_dim * 2` extra channels: the time and label
        # embeddings are concatenated onto the feature maps before every block.
        self.inc = self.double_conv(in_channels, 64)
        self.down1 = self.down(64 + time_dim * 2, 128)
        self.down2 = self.down(128 + time_dim * 2, 256)
        self.down3 = self.down(256 + time_dim * 2, 512)

        self.bottleneck = self.double_conv(512 + time_dim * 2, 1024)

        self.up1 = nn.ConvTranspose2d(1024, 256, kernel_size=2, stride=2)
        self.upconv1 = self.double_conv(256 + 256 + time_dim * 2, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upconv2 = self.double_conv(128 + 128 + time_dim * 2, 128)

        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.upconv3 = self.double_conv(64 + 64 + time_dim * 2, 64)

        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)

    def double_conv(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def down(self, in_channels, out_channels):
        return nn.Sequential(
            nn.MaxPool2d(2),
            self.double_conv(in_channels, out_channels)
        )

    def forward(self, x, labels, time):
        # Labels arrive one-hot; argmax recovers the class index for the embedding table
        label_indices = torch.argmax(labels, dim=1)
        label_emb = self.label_embedding(label_indices)
        t_emb = self.time_mlp(time)

        # Broadcast the (B, 2 * time_dim) conditioning vector over all spatial positions
        combined_emb = torch.cat([t_emb, label_emb], dim=1)
        combined_emb = combined_emb.unsqueeze(-1).unsqueeze(-1)

        x1 = self.inc(x)
        x1_cat = torch.cat([x1, combined_emb.repeat(1, 1, x1.shape[-2], x1.shape[-1])], dim=1)

        x2 = self.down1(x1_cat)
        x2_cat = torch.cat([x2, combined_emb.repeat(1, 1, x2.shape[-2], x2.shape[-1])], dim=1)

        x3 = self.down2(x2_cat)
        x3_cat = torch.cat([x3, combined_emb.repeat(1, 1, x3.shape[-2], x3.shape[-1])], dim=1)

        x4 = self.down3(x3_cat)
        x4_cat = torch.cat([x4, combined_emb.repeat(1, 1, x4.shape[-2], x4.shape[-1])], dim=1)

        x5 = self.bottleneck(x4_cat)

        x = self.up1(x5)
        x = torch.cat([x, x3], dim=1)
        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
        x = self.upconv1(x)

        x = self.up2(x)
        x = torch.cat([x, x2], dim=1)
        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
        x = self.upconv2(x)

        x = self.up3(x)
        x = torch.cat([x, x1], dim=1)
        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
        x = self.upconv3(x)

        output = self.outc(x)
        return output

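# Smoke test (illustrative sketch; names match the class above): conditioning
# enters by broadcasting the concatenated time + label embedding across every
# spatial position at each scale, so output shape always matches the input:
#   net = UNet(num_classes=2)
#   x = torch.randn(2, 3, 128, 128)
#   y = torch.eye(2)                  # one-hot labels, one per image
#   t = torch.tensor([10, 200])
#   net(x, y, t).shape                # -> torch.Size([2, 3, 128, 128])
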
class DiffusionModel(nn.Module):
    def __init__(self, model, timesteps=TRAINING_TIMESTEPS, time_dim=256):
        super().__init__()
        self.model = model
        self.timesteps = timesteps
        self.time_dim = time_dim

        # Linear beta schedule (matches original implementation). All schedule
        # tensors are registered as buffers so they follow the model's device;
        # indexing a CPU tensor with a CUDA index tensor would fail otherwise.
        scale = 1000 / timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        self.register_buffer('betas', torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32))
        self.register_buffer('alphas', 1. - self.betas)
        self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))

    @torch.no_grad()
    def p_sample(self, x, t, labels):
        betas_t = self.betas[t].view(-1, 1, 1, 1).to(x.dtype)
        sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1. - self.alpha_bars[t]).view(-1, 1, 1, 1).to(x.dtype)
        sqrt_recip_alphas_t = torch.sqrt(1.0 / (1. - self.betas[t])).view(-1, 1, 1, 1).to(x.dtype)

        # Model prediction
        pred_noise = self.model(x, labels, t.float())

        # Posterior mean of x_{t-1} given the predicted noise
        model_mean = sqrt_recip_alphas_t * (x - betas_t * pred_noise / sqrt_one_minus_alphas_cumprod_t)

        # All entries of t share the same value during sampling; `if t == 0` on a
        # multi-element tensor would raise, so test a single entry instead.
        if t[0].item() == 0:
            return model_mean
        else:
            posterior_variance_t = (self.betas[t] * (1. - self.alpha_bars[t - 1]) / (1. - self.alpha_bars[t])).view(-1, 1, 1, 1)
            noise = torch.randn_like(x)
            return model_mean + torch.sqrt(posterior_variance_t) * noise

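    # For reference, p_sample implements the standard DDPM update (Ho et al., 2020),
    # with eps_theta the model's noise prediction:
    #   mu_t    = (1 / sqrt(alpha_t)) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta)
    #   x_{t-1} = mu_t + sqrt(beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)) * z,  z ~ N(0, I)
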
    @torch.no_grad()
    def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
        x = torch.randn((num_images, 3, img_size, img_size), device=device, dtype=torch.float32)

        # Full reverse chain: cost grows linearly with self.timesteps
        for i in reversed(range(0, self.timesteps)):
            t = torch.full((num_images,), i, device=device, dtype=torch.long)
            x = self.p_sample(x, t, labels)

            if progress_callback:
                progress_callback((self.timesteps - i) / self.timesteps)
            if cancel_event.is_set():
                return None

        # Undo the training-time normalization (ImageNet mean/std), then map to [0, 1]
        x = torch.clamp(x, -1., 1.)
        mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(1, 3, 1, 1).to(device)
        std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(1, 3, 1, 1).to(device)
        x = std * x + mean
        x = torch.clamp(x, 0., 1.)

        return x

def load_model(model_path, device):
    unet = UNet(num_classes=NUM_CLASSES).to(device)

    if os.path.exists(model_path):
        try:
            checkpoint = torch.load(model_path, map_location=device)

            # Accept both a raw state dict and one wrapped under 'model_state_dict'
            if 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint

            # Strip a uniform 'model.' prefix left by wrapper modules
            if all(k.startswith('model.') for k in state_dict.keys()):
                state_dict = {k[len('model.'):]: v for k, v in state_dict.items()}

            # strict=False tolerates partial mismatches but reports them
            missing_keys, unexpected_keys = unet.load_state_dict(state_dict, strict=False)

            if missing_keys:
                print(f"Missing keys in state dict: {missing_keys}")
            if unexpected_keys:
                print(f"Unexpected keys in state dict: {unexpected_keys}")

            print("Model loaded successfully")

        except Exception as e:
            traceback.print_exc()
            raise ValueError(f"Error loading model: {str(e)}")

        diffusion_model = DiffusionModel(unet).to(device)
        try:
            # Note: torch.compile wraps forward(); sample() is reached through
            # attribute delegation and still runs eagerly.
            diffusion_model = torch.compile(diffusion_model)
        except Exception as e:
            print(f"Could not compile model - running uncompiled: {str(e)}")
    else:
        raise FileNotFoundError(f"Model weights not found at {model_path}")

    diffusion_model.eval()
    return diffusion_model

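# The loader above accepts any of these checkpoint layouts (illustrative):
#   torch.save(unet.state_dict(), "model_weights.pth")                        # raw state dict
#   torch.save({"model_state_dict": unet.state_dict()}, "model_weights.pth")  # wrapped
# and a state dict whose keys all carry a uniform "model." prefix is unwrapped too.
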
def cancel_generation():
    cancel_event.set()
    return "Generation cancelled"

def generate_image(label_str, num_images, progress=gr.Progress()):
    global loaded_model
    cancel_event.clear()

    # Input validation
    if num_images < 1 or num_images > 10:
        raise gr.Error("Number of images must be between 1 and 10")

    label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
    if label_str not in label_map:
        raise gr.Error("Invalid condition selected")

    # One-hot labels; UNet.forward recovers class indices via argmax
    labels = torch.zeros(num_images, NUM_CLASSES, device=device, dtype=torch.float32)
    labels[:, label_map[label_str]] = 1

    try:
        def progress_callback(progress_val):
            progress(progress_val, desc="Generating...")
            if cancel_event.is_set():
                raise gr.Error("Generation was cancelled by user")

        # Enable autocast only when CUDA is actually present
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            images = loaded_model.sample(
                num_images=num_images,
                img_size=IMG_SIZE,
                num_classes=NUM_CLASSES,
                labels=labels,
                device=device,
                progress_callback=progress_callback
            )

        if images is None:
            return None

        processed_images = []
        for img in images:
            # Scale the [0, 1] float tensor to uint8 and convert CHW -> HWC
            img_np = img.mul(255).clamp(0, 255).byte().cpu().numpy()
            img_np = img_np.transpose(1, 2, 0)
            pil_img = Image.fromarray(img_np, 'RGB')
            processed_images.append(pil_img)

        return processed_images

    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()
        raise gr.Error("Out of GPU memory - try generating fewer images")
    except Exception as e:
        traceback.print_exc()
        if str(e) != "Generation was cancelled by user":
            raise gr.Error(f"Generation failed: {str(e)}")
        return None
    finally:
        torch.cuda.empty_cache()

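# Design note: cancellation is cooperative. cancel_event is shared process-wide;
# the sampling loop polls it once per denoising step, and progress_callback also
# raises inside the loop, so a cancel takes effect between timesteps rather than
# killing the worker thread.
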
# --- Load Model ---
model_path = "model_weights.pth"
try:
    loaded_model = load_model(model_path, device)
except Exception as e:
    print(f"Failed to load model: {str(e)}")
    raise

# --- Gradio UI ---
custom_css = """
.gradio-container {
    background: linear-gradient(135deg, #f5f7fa 0%, #e4e8f0 100%);
}
.gallery-container {
    background-color: white !important;
}
"""

with gr.Blocks(theme=gr.themes.Soft(
    primary_hue="violet",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Poppins")],
    text_size="md"
), css=custom_css) as demo:
    gr.Markdown("""
    <center>
    <h1>Synthetic X-ray Generator</h1>
    <p><em>Generate synthetic chest X-rays conditioned on pathology</em></p>
    </center>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            condition = gr.Dropdown(
                ["Pneumonia", "Pneumothorax"],
                label="Select Condition",
                value="Pneumonia",
                interactive=True
            )
            num_images = gr.Slider(
                1, 10, value=1, step=1,
                label="Number of Images",
                interactive=True
            )

            with gr.Row():
                submit_btn = gr.Button("Generate", variant="primary")
                cancel_btn = gr.Button("Cancel", variant="stop")

            gr.Markdown("""
            <div style="text-align: center; margin-top: 10px;">
            <small>Note: Generation may take several seconds per image</small>
            </div>
            """)

        with gr.Column(scale=2):
            gallery = gr.Gallery(
                label="Generated X-rays",
                columns=3,
                height="auto",
                object_fit="contain",
                preview=True
            )

    submit_btn.click(
        fn=generate_image,
        inputs=[condition, num_images],
        outputs=gallery,
        api_name="generate"
    )

    cancel_btn.click(
        fn=cancel_generation,
        outputs=None,
        api_name="cancel"
    )
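    # Hypothetical programmatic call (assumes the gradio_client package and a
    # running server; endpoint names come from the api_name values above):
    #   from gradio_client import Client
    #   client = Client("http://localhost:7860/")
    #   client.predict("Pneumonia", 2, api_name="/generate")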
if __name__ == "__main__":
    try:
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False
        )
    except Exception as e:
        print(f"Failed to launch app: {str(e)}")