Vedansh-7 commited on
Commit
50426d9
·
1 Parent(s): 4a59445

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -22
app.py CHANGED
@@ -27,13 +27,13 @@ class SinusoidalPositionEmbeddings(nn.Module):
27
  self.dim = dim
28
  half_dim = dim // 2
29
  emb = math.log(10000) / (half_dim - 1)
30
- emb = torch.exp(torch.arange(half_dim) * -emb)
31
  self.register_buffer('embeddings', emb)
32
 
33
  def forward(self, time):
34
  device = time.device
35
  embeddings = self.embeddings.to(device)
36
- embeddings = time[:, None] * embeddings[None, :]
37
  return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
38
 
39
  class UNet(nn.Module):
@@ -133,15 +133,15 @@ class DiffusionModel(nn.Module):
133
  scale = 1000 / timesteps
134
  beta_start = scale * 0.0001
135
  beta_end = scale * 0.02
136
- self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
137
  self.alphas = 1. - self.betas
138
- self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0).float())
139
 
140
  @torch.no_grad()
141
  def p_sample(self, x, t, labels):
142
- betas_t = self.betas[t].view(-1, 1, 1, 1)
143
- sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1. - self.alpha_bars[t]).view(-1, 1, 1, 1)
144
- sqrt_recip_alphas_t = torch.sqrt(1.0 / (1. - self.betas[t])).view(-1, 1, 1, 1)
145
 
146
  # Model prediction
147
  pred_noise = self.model(x, labels, t.float())
@@ -154,11 +154,11 @@ class DiffusionModel(nn.Module):
154
  else:
155
  posterior_variance_t = self.betas[t] * (1. - self.alpha_bars[t-1]) / (1. - self.alpha_bars[t])
156
  noise = torch.randn_like(x)
157
- return model_mean + torch.sqrt(posterior_variance_t) * noise
158
 
159
  @torch.no_grad()
160
  def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
161
- x = torch.randn((num_images, 3, img_size, img_size), device=device)
162
 
163
  for i in reversed(range(0, self.timesteps)):
164
  t = torch.full((num_images,), i, device=device, dtype=torch.long)
@@ -170,8 +170,8 @@ class DiffusionModel(nn.Module):
170
  return None
171
 
172
  x = torch.clamp(x, -1., 1.)
173
- mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
174
- std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
175
  x = std * x + mean
176
  x = torch.clamp(x, 0., 1.)
177
 
@@ -205,13 +205,14 @@ def load_model(model_path, device):
205
  print("Model loaded successfully")
206
 
207
  except Exception as e:
 
208
  raise ValueError(f"Error loading model: {str(e)}")
209
 
210
  diffusion_model = DiffusionModel(unet).to(device)
211
  try:
212
  diffusion_model = torch.compile(diffusion_model)
213
- except:
214
- print("Could not compile model - running uncompiled")
215
  else:
216
  raise FileNotFoundError(f"Model weights not found at {model_path}")
217
 
@@ -234,7 +235,7 @@ def generate_image(label_str, num_images, progress=gr.Progress()):
234
  if label_str not in label_map:
235
  raise gr.Error("Invalid condition selected")
236
 
237
- labels = torch.zeros(num_images, NUM_CLASSES, device=device)
238
  labels[:, label_map[label_str]] = 1
239
 
240
  try:
@@ -271,7 +272,7 @@ def generate_image(label_str, num_images, progress=gr.Progress()):
271
  torch.cuda.empty_cache()
272
  raise gr.Error("Out of GPU memory - try generating fewer images")
273
  except Exception as e:
274
- print(f"Full error: {traceback.format_exc()}")
275
  if str(e) != "Generation was cancelled by user":
276
  raise gr.Error(f"Generation failed: {str(e)}")
277
  return None
@@ -280,7 +281,11 @@ def generate_image(label_str, num_images, progress=gr.Progress()):
280
 
281
  # --- Load Model ---
282
  model_path = "model_weights.pth"
283
- loaded_model = load_model(model_path, device)
 
 
 
 
284
 
285
  # --- Gradio UI ---
286
  with gr.Blocks(theme=gr.themes.Soft(
@@ -291,7 +296,7 @@ with gr.Blocks(theme=gr.themes.Soft(
291
  )) as demo:
292
  gr.Markdown("""
293
  <center>
294
- <h1>CheXpert X-ray Image Generator</h1>
295
  <p><em>Generate synthetic chest X-rays conditioned on pathology</em></p>
296
  </center>
297
  """)
@@ -352,8 +357,11 @@ with gr.Blocks(theme=gr.themes.Soft(
352
  """
353
 
354
  if __name__ == "__main__":
355
- demo.launch(
356
- server_name="0.0.0.0",
357
- server_port=7860,
358
- share=False
359
- )
 
 
 
 
27
  self.dim = dim
28
  half_dim = dim // 2
29
  emb = math.log(10000) / (half_dim - 1)
30
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
31
  self.register_buffer('embeddings', emb)
32
 
33
  def forward(self, time):
34
  device = time.device
35
  embeddings = self.embeddings.to(device)
36
+ embeddings = time.float()[:, None] * embeddings[None, :]
37
  return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
38
 
39
  class UNet(nn.Module):
 
133
  scale = 1000 / timesteps
134
  beta_start = scale * 0.0001
135
  beta_end = scale * 0.02
136
+ self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
137
  self.alphas = 1. - self.betas
138
+ self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))
139
 
140
  @torch.no_grad()
141
  def p_sample(self, x, t, labels):
142
+ betas_t = self.betas[t].view(-1, 1, 1, 1).to(x.dtype).to(x.device)
143
+ sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1. - self.alpha_bars[t]).view(-1, 1, 1, 1).to(x.dtype).to(x.device)
144
+ sqrt_recip_alphas_t = torch.sqrt(1.0 / (1. - self.betas[t])).view(-1, 1, 1, 1).to(x.dtype).to(x.device)
145
 
146
  # Model prediction
147
  pred_noise = self.model(x, labels, t.float())
 
154
  else:
155
  posterior_variance_t = self.betas[t] * (1. - self.alpha_bars[t-1]) / (1. - self.alpha_bars[t])
156
  noise = torch.randn_like(x)
157
+ return model_mean + torch.sqrt(posterior_variance_t).to(x.device) * noise
158
 
159
  @torch.no_grad()
160
  def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
161
+ x = torch.randn((num_images, 3, img_size, img_size), device=device, dtype=torch.float32)
162
 
163
  for i in reversed(range(0, self.timesteps)):
164
  t = torch.full((num_images,), i, device=device, dtype=torch.long)
 
170
  return None
171
 
172
  x = torch.clamp(x, -1., 1.)
173
+ mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(1, 3, 1, 1).to(device)
174
+ std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(1, 3, 1, 1).to(device)
175
  x = std * x + mean
176
  x = torch.clamp(x, 0., 1.)
177
 
 
205
  print("Model loaded successfully")
206
 
207
  except Exception as e:
208
+ traceback.print_exc()
209
  raise ValueError(f"Error loading model: {str(e)}")
210
 
211
  diffusion_model = DiffusionModel(unet).to(device)
212
  try:
213
  diffusion_model = torch.compile(diffusion_model)
214
+ except Exception as e:
215
+ print(f"Could not compile model - running uncompiled: {str(e)}")
216
  else:
217
  raise FileNotFoundError(f"Model weights not found at {model_path}")
218
 
 
235
  if label_str not in label_map:
236
  raise gr.Error("Invalid condition selected")
237
 
238
+ labels = torch.zeros(num_images, NUM_CLASSES, device=device, dtype=torch.float32)
239
  labels[:, label_map[label_str]] = 1
240
 
241
  try:
 
272
  torch.cuda.empty_cache()
273
  raise gr.Error("Out of GPU memory - try generating fewer images")
274
  except Exception as e:
275
+ traceback.print_exc()
276
  if str(e) != "Generation was cancelled by user":
277
  raise gr.Error(f"Generation failed: {str(e)}")
278
  return None
 
281
 
282
  # --- Load Model ---
283
  model_path = "model_weights.pth"
284
+ try:
285
+ loaded_model = load_model(model_path, device)
286
+ except Exception as e:
287
+ print(f"Failed to load model: {str(e)}")
288
+ raise
289
 
290
  # --- Gradio UI ---
291
  with gr.Blocks(theme=gr.themes.Soft(
 
296
  )) as demo:
297
  gr.Markdown("""
298
  <center>
299
+ <h1>Synthetic X-ray Generator</h1>
300
  <p><em>Generate synthetic chest X-rays conditioned on pathology</em></p>
301
  </center>
302
  """)
 
357
  """
358
 
359
  if __name__ == "__main__":
360
+ try:
361
+ demo.launch(
362
+ server_name="0.0.0.0",
363
+ server_port=7860,
364
+ share=False
365
+ )
366
+ except Exception as e:
367
+ print(f"Failed to launch app: {str(e)}")