Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -27,13 +27,13 @@ class SinusoidalPositionEmbeddings(nn.Module):
|
|
27 |
self.dim = dim
|
28 |
half_dim = dim // 2
|
29 |
emb = math.log(10000) / (half_dim - 1)
|
30 |
-
emb = torch.exp(torch.arange(half_dim) * -emb)
|
31 |
self.register_buffer('embeddings', emb)
|
32 |
|
33 |
def forward(self, time):
|
34 |
device = time.device
|
35 |
embeddings = self.embeddings.to(device)
|
36 |
-
embeddings = time[:, None] * embeddings[None, :]
|
37 |
return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
|
38 |
|
39 |
class UNet(nn.Module):
|
@@ -133,15 +133,15 @@ class DiffusionModel(nn.Module):
|
|
133 |
scale = 1000 / timesteps
|
134 |
beta_start = scale * 0.0001
|
135 |
beta_end = scale * 0.02
|
136 |
-
self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.
|
137 |
self.alphas = 1. - self.betas
|
138 |
-
self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0)
|
139 |
|
140 |
@torch.no_grad()
|
141 |
def p_sample(self, x, t, labels):
|
142 |
-
betas_t = self.betas[t].view(-1, 1, 1, 1)
|
143 |
-
sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1. - self.alpha_bars[t]).view(-1, 1, 1, 1)
|
144 |
-
sqrt_recip_alphas_t = torch.sqrt(1.0 / (1. - self.betas[t])).view(-1, 1, 1, 1)
|
145 |
|
146 |
# Model prediction
|
147 |
pred_noise = self.model(x, labels, t.float())
|
@@ -154,11 +154,11 @@ class DiffusionModel(nn.Module):
|
|
154 |
else:
|
155 |
posterior_variance_t = self.betas[t] * (1. - self.alpha_bars[t-1]) / (1. - self.alpha_bars[t])
|
156 |
noise = torch.randn_like(x)
|
157 |
-
return model_mean + torch.sqrt(posterior_variance_t) * noise
|
158 |
|
159 |
@torch.no_grad()
|
160 |
def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
|
161 |
-
x = torch.randn((num_images, 3, img_size, img_size), device=device)
|
162 |
|
163 |
for i in reversed(range(0, self.timesteps)):
|
164 |
t = torch.full((num_images,), i, device=device, dtype=torch.long)
|
@@ -170,8 +170,8 @@ class DiffusionModel(nn.Module):
|
|
170 |
return None
|
171 |
|
172 |
x = torch.clamp(x, -1., 1.)
|
173 |
-
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
|
174 |
-
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
|
175 |
x = std * x + mean
|
176 |
x = torch.clamp(x, 0., 1.)
|
177 |
|
@@ -205,13 +205,14 @@ def load_model(model_path, device):
|
|
205 |
print("Model loaded successfully")
|
206 |
|
207 |
except Exception as e:
|
|
|
208 |
raise ValueError(f"Error loading model: {str(e)}")
|
209 |
|
210 |
diffusion_model = DiffusionModel(unet).to(device)
|
211 |
try:
|
212 |
diffusion_model = torch.compile(diffusion_model)
|
213 |
-
except:
|
214 |
-
print("Could not compile model - running uncompiled")
|
215 |
else:
|
216 |
raise FileNotFoundError(f"Model weights not found at {model_path}")
|
217 |
|
@@ -234,7 +235,7 @@ def generate_image(label_str, num_images, progress=gr.Progress()):
|
|
234 |
if label_str not in label_map:
|
235 |
raise gr.Error("Invalid condition selected")
|
236 |
|
237 |
-
labels = torch.zeros(num_images, NUM_CLASSES, device=device)
|
238 |
labels[:, label_map[label_str]] = 1
|
239 |
|
240 |
try:
|
@@ -271,7 +272,7 @@ def generate_image(label_str, num_images, progress=gr.Progress()):
|
|
271 |
torch.cuda.empty_cache()
|
272 |
raise gr.Error("Out of GPU memory - try generating fewer images")
|
273 |
except Exception as e:
|
274 |
-
|
275 |
if str(e) != "Generation was cancelled by user":
|
276 |
raise gr.Error(f"Generation failed: {str(e)}")
|
277 |
return None
|
@@ -280,7 +281,11 @@ def generate_image(label_str, num_images, progress=gr.Progress()):
|
|
280 |
|
281 |
# --- Load Model ---
|
282 |
model_path = "model_weights.pth"
|
283 |
-
|
|
|
|
|
|
|
|
|
284 |
|
285 |
# --- Gradio UI ---
|
286 |
with gr.Blocks(theme=gr.themes.Soft(
|
@@ -291,7 +296,7 @@ with gr.Blocks(theme=gr.themes.Soft(
|
|
291 |
)) as demo:
|
292 |
gr.Markdown("""
|
293 |
<center>
|
294 |
-
<h1>
|
295 |
<p><em>Generate synthetic chest X-rays conditioned on pathology</em></p>
|
296 |
</center>
|
297 |
""")
|
@@ -352,8 +357,11 @@ with gr.Blocks(theme=gr.themes.Soft(
|
|
352 |
"""
|
353 |
|
354 |
if __name__ == "__main__":
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
|
|
|
|
|
|
|
27 |
self.dim = dim
|
28 |
half_dim = dim // 2
|
29 |
emb = math.log(10000) / (half_dim - 1)
|
30 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
31 |
self.register_buffer('embeddings', emb)
|
32 |
|
33 |
def forward(self, time):
|
34 |
device = time.device
|
35 |
embeddings = self.embeddings.to(device)
|
36 |
+
embeddings = time.float()[:, None] * embeddings[None, :]
|
37 |
return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
|
38 |
|
39 |
class UNet(nn.Module):
|
|
|
133 |
scale = 1000 / timesteps
|
134 |
beta_start = scale * 0.0001
|
135 |
beta_end = scale * 0.02
|
136 |
+
self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
|
137 |
self.alphas = 1. - self.betas
|
138 |
+
self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))
|
139 |
|
140 |
@torch.no_grad()
|
141 |
def p_sample(self, x, t, labels):
|
142 |
+
betas_t = self.betas[t].view(-1, 1, 1, 1).to(x.dtype).to(x.device)
|
143 |
+
sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1. - self.alpha_bars[t]).view(-1, 1, 1, 1).to(x.dtype).to(x.device)
|
144 |
+
sqrt_recip_alphas_t = torch.sqrt(1.0 / (1. - self.betas[t])).view(-1, 1, 1, 1).to(x.dtype).to(x.device)
|
145 |
|
146 |
# Model prediction
|
147 |
pred_noise = self.model(x, labels, t.float())
|
|
|
154 |
else:
|
155 |
posterior_variance_t = self.betas[t] * (1. - self.alpha_bars[t-1]) / (1. - self.alpha_bars[t])
|
156 |
noise = torch.randn_like(x)
|
157 |
+
return model_mean + torch.sqrt(posterior_variance_t).to(x.device) * noise
|
158 |
|
159 |
@torch.no_grad()
|
160 |
def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
|
161 |
+
x = torch.randn((num_images, 3, img_size, img_size), device=device, dtype=torch.float32)
|
162 |
|
163 |
for i in reversed(range(0, self.timesteps)):
|
164 |
t = torch.full((num_images,), i, device=device, dtype=torch.long)
|
|
|
170 |
return None
|
171 |
|
172 |
x = torch.clamp(x, -1., 1.)
|
173 |
+
mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(1, 3, 1, 1).to(device)
|
174 |
+
std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(1, 3, 1, 1).to(device)
|
175 |
x = std * x + mean
|
176 |
x = torch.clamp(x, 0., 1.)
|
177 |
|
|
|
205 |
print("Model loaded successfully")
|
206 |
|
207 |
except Exception as e:
|
208 |
+
traceback.print_exc()
|
209 |
raise ValueError(f"Error loading model: {str(e)}")
|
210 |
|
211 |
diffusion_model = DiffusionModel(unet).to(device)
|
212 |
try:
|
213 |
diffusion_model = torch.compile(diffusion_model)
|
214 |
+
except Exception as e:
|
215 |
+
print(f"Could not compile model - running uncompiled: {str(e)}")
|
216 |
else:
|
217 |
raise FileNotFoundError(f"Model weights not found at {model_path}")
|
218 |
|
|
|
235 |
if label_str not in label_map:
|
236 |
raise gr.Error("Invalid condition selected")
|
237 |
|
238 |
+
labels = torch.zeros(num_images, NUM_CLASSES, device=device, dtype=torch.float32)
|
239 |
labels[:, label_map[label_str]] = 1
|
240 |
|
241 |
try:
|
|
|
272 |
torch.cuda.empty_cache()
|
273 |
raise gr.Error("Out of GPU memory - try generating fewer images")
|
274 |
except Exception as e:
|
275 |
+
traceback.print_exc()
|
276 |
if str(e) != "Generation was cancelled by user":
|
277 |
raise gr.Error(f"Generation failed: {str(e)}")
|
278 |
return None
|
|
|
281 |
|
282 |
# --- Load Model ---
|
283 |
model_path = "model_weights.pth"
|
284 |
+
try:
|
285 |
+
loaded_model = load_model(model_path, device)
|
286 |
+
except Exception as e:
|
287 |
+
print(f"Failed to load model: {str(e)}")
|
288 |
+
raise
|
289 |
|
290 |
# --- Gradio UI ---
|
291 |
with gr.Blocks(theme=gr.themes.Soft(
|
|
|
296 |
)) as demo:
|
297 |
gr.Markdown("""
|
298 |
<center>
|
299 |
+
<h1>Synthetic X-ray Generator</h1>
|
300 |
<p><em>Generate synthetic chest X-rays conditioned on pathology</em></p>
|
301 |
</center>
|
302 |
""")
|
|
|
357 |
"""
|
358 |
|
359 |
if __name__ == "__main__":
|
360 |
+
try:
|
361 |
+
demo.launch(
|
362 |
+
server_name="0.0.0.0",
|
363 |
+
server_port=7860,
|
364 |
+
share=False
|
365 |
+
)
|
366 |
+
except Exception as e:
|
367 |
+
print(f"Failed to launch app: {str(e)}")
|