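# app.py
# Last commit: a1d7fd6 (verified) by Vedansh-7 - "Update app.py"
# File size: 13.7 kB
# (Hosting-page link labels "raw" / "history" / "blame" were part of the page chrome.)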
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
import numpy as np
import math
import os
from threading import Event
import traceback
# Constants
IMG_SIZE = 128  # side length (pixels) of the square images requested from sample()
TIMESTEPS = 300  # diffusion steps used both to build DiffusionModel and to run sample()
NUM_CLASSES = 2  # width of the one-hot label vector / size of the UNet label embedding
# Global Cancellation Flag
# Set by cancel_generation(); checked by sample() at the start of every
# denoising step and cleared by generate_images() before a new run.
cancel_event = Event()
# Device Configuration
# Use CUDA when PyTorch reports it as available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- Model Definitions (from second file) ---
class SinusoidalPositionEmbeddings(nn.Module):
    """Encode scalar timesteps as fixed sinusoidal feature vectors.

    For an input of shape ``(batch,)`` the output has shape
    ``(batch, 2 * (dim // 2))``: the sines of ``time * freq`` followed by the
    cosines, where the ``dim // 2`` frequencies fall geometrically from 1 to
    1/10000.

    Args:
        dim: Requested embedding width. With an odd ``dim`` the output is one
            column narrower, because sine and cosine halves are equally wide.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # Persistent buffer: its name is part of the state_dict, so it must
        # remain ``embeddings`` for previously saved checkpoints to load.
        self.register_buffer('embeddings', self._precompute_embeddings(dim))

    def _precompute_embeddings(self, dim):
        """Return the ``dim // 2`` frequencies used by :meth:`forward`."""
        half_dim = dim // 2
        # With a single frequency (dim 2 or 3) ``half_dim - 1`` is zero; the
        # guard avoids a ZeroDivisionError and yields the lone frequency 1.0.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        return torch.exp(torch.arange(half_dim) * -log_step)

    def forward(self, time):
        """Return ``cat(sin(time * freq), cos(time * freq))`` along the last axis.

        Args:
            time: 1-D tensor of timesteps, one per batch element.
        """
        # The buffer normally follows the module's device; the explicit move
        # keeps the call working when ``time`` lives elsewhere.
        freqs = self.embeddings.to(time.device)
        angles = time[:, None] * freqs[None, :]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)
class UNet(nn.Module):
    """Three-level U-Net conditioned on a class label and a diffusion timestep.

    A learned label embedding and an MLP-processed timestep embedding (each
    ``time_dim`` wide) are concatenated into one conditioning vector. That
    vector is broadcast over the spatial grid and appended as extra channels
    to the input of every encoder, bottleneck and decoder stage after ``inc``.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the predicted output.
        num_classes: Number of distinct labels in the embedding table.
        time_dim: Width of the timestep embedding and of the label embedding.

    Note:
        The encoder halves the resolution three times and the decoder doubles
        it three times, so input height and width must be divisible by 8 for
        the skip connections to line up.
    """

    def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
        super().__init__()
        self.num_classes = num_classes
        self.label_embedding = nn.Embedding(num_classes, time_dim)
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.ReLU(),
            nn.Linear(time_dim, time_dim)
        )
        # Every conditioned stage receives time + label embeddings as extra channels.
        cond_dim = time_dim * 2
        # Attribute names and registration order define the state_dict layout;
        # keep them stable so existing checkpoints continue to load.
        self.inc = self.double_conv(in_channels, 64)
        self.down1 = self.down(64 + cond_dim, 128)
        self.down2 = self.down(128 + cond_dim, 256)
        self.down3 = self.down(256 + cond_dim, 512)
        self.bottleneck = self.double_conv(512 + cond_dim, 1024)
        self.up1 = nn.ConvTranspose2d(1024, 256, kernel_size=2, stride=2)
        self.upconv1 = self.double_conv(256 + 256 + cond_dim, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upconv2 = self.double_conv(128 + 128 + cond_dim, 128)
        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.upconv3 = self.double_conv(64 + 64 + cond_dim, 64)
        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)

    def double_conv(self, in_channels, out_channels):
        """Return two 3x3 same-padding convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def down(self, in_channels, out_channels):
        """Return a 2x max-pool followed by :meth:`double_conv`."""
        return nn.Sequential(
            nn.MaxPool2d(2),
            self.double_conv(in_channels, out_channels)
        )

    @staticmethod
    def _cat_with_cond(cond, *features):
        """Concatenate feature maps and the broadcast conditioning along channels.

        ``cond`` has shape ``(batch, channels, 1, 1)``. ``expand`` creates a
        view at the spatial size of the first feature map, so the only
        allocation is the single ``torch.cat`` result (``repeat`` would
        materialize a full-size copy first).
        """
        height, width = features[0].shape[-2:]
        return torch.cat([*features, cond.expand(-1, -1, height, width)], dim=1)

    def forward(self, x, labels, time):
        """Predict an output image batch from ``x`` under the given condition.

        Args:
            x: Image batch of shape ``(batch, in_channels, H, W)``.
            labels: 2-D tensor of per-class scores (e.g. one-hot); the argmax
                along dim 1 selects the label embedding.
            time: Timestep tensor fed to ``time_mlp``, one value per batch element.

        Returns:
            Tensor of shape ``(batch, out_channels, H, W)``.
        """
        label_indices = torch.argmax(labels, dim=1)
        label_emb = self.label_embedding(label_indices)
        t_emb = self.time_mlp(time)
        # (batch, 2 * time_dim, 1, 1) so it can be broadcast over any grid size.
        cond = torch.cat([t_emb, label_emb], dim=1)[:, :, None, None]

        # Encoder
        x1 = self.inc(x)
        x2 = self.down1(self._cat_with_cond(cond, x1))
        x3 = self.down2(self._cat_with_cond(cond, x2))
        x4 = self.down3(self._cat_with_cond(cond, x3))
        x5 = self.bottleneck(self._cat_with_cond(cond, x4))

        # Decoder: channel order is [upsampled, skip, conditioning].
        x = self.upconv1(self._cat_with_cond(cond, self.up1(x5), x3))
        x = self.upconv2(self._cat_with_cond(cond, self.up2(x), x2))
        x = self.upconv3(self._cat_with_cond(cond, self.up3(x), x1))
        return self.outc(x)
class DiffusionModel(nn.Module):
    """Wrap a noise-prediction network with a linear-beta forward diffusion process.

    Args:
        model: Network called as ``model(x_t, labels, t)`` that returns the
            predicted noise.
        timesteps: Number of diffusion steps in the schedule.
        time_dim: Stored on the instance; not otherwise used by this class.

    Attributes (buffers):
        betas: float64 noise schedule, shape ``(timesteps,)``.
        alphas: ``1 - betas``.
        alpha_bars: float32 cumulative product of ``alphas``.
    """

    def __init__(self, model, timesteps=500, time_dim=256):
        super().__init__()
        self.model = model
        self.timesteps = timesteps
        self.time_dim = time_dim
        betas = self.linear_schedule(timesteps)
        alphas = 1. - betas
        # Registered as buffers so ``.to(device)`` moves them with the module.
        # ``persistent=False`` keeps them out of the state_dict, so the saved
        # key set (``alpha_bars`` plus ``model.*``) is unchanged.
        self.register_buffer('betas', betas, persistent=False)
        self.register_buffer('alphas', alphas, persistent=False)
        self.register_buffer('alpha_bars', torch.cumprod(alphas, dim=0).float())

    def linear_schedule(self, timesteps):
        """Return linearly spaced float64 betas, rescaled for ``timesteps`` steps.

        The end points 1e-4 and 0.02 are multiplied by ``1000 / timesteps``.
        """
        scale = 1000 / timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)

    def forward_diffusion(self, x_0, t, noise):
        """Return ``sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.

        Args:
            x_0: Clean image batch.
            t: Integer timestep indices, one per batch element.
            noise: Noise tensor with the same shape as ``x_0``.
        """
        x_0 = x_0.float()
        noise = noise.float()
        # Reshape to (batch, 1, 1, 1) so the per-sample factor broadcasts over C, H, W.
        alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1, 1)
        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1. - alpha_bar_t) * noise
        return x_t

    def forward(self, x_0, labels):
        """Noise ``x_0`` at random timesteps and predict that noise.

        Returns:
            Tuple ``(predicted_noise, noise, t)`` where ``t`` holds the sampled
            integer timesteps.
        """
        t = torch.randint(0, self.timesteps, (x_0.shape[0],), device=x_0.device).long()
        noise = torch.randn_like(x_0)
        x_t = self.forward_diffusion(x_0, t, noise)
        predicted_noise = self.model(x_t, labels, t.float())
        return predicted_noise, noise, t
@torch.no_grad()
def sample(model, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
    """Generate images by running the reverse diffusion loop from pure noise.

    Args:
        model: Object exposing ``model`` (the noise predictor, called as
            ``model.model(x_t, labels, t)``) and the schedule tensors
            ``betas``, ``alphas`` and ``alpha_bars``.
        num_images: Batch size to generate.
        timesteps: Number of reverse steps; indexes the schedule from
            ``timesteps - 1`` down to 0.
        img_size: Height and width of the generated square images.
        num_classes: Width of the one-hot label vector.
        labels: Either a 1-D tensor of class indices (converted to one-hot
            here) or an already one-hot 2-D tensor.
        device: Device on which sampling runs.
        progress_callback: Optional callable invoked after every step with
            the completed fraction in ``(0, 1]``.

    Returns:
        Tensor of shape ``(num_images, 3, img_size, img_size)`` with values in
        ``[0, 1]``, or ``None`` if ``cancel_event`` was set before a step.
    """
    # Noise is drawn on the CPU and then moved, so a CPU seed reproduces the start.
    x_t = torch.randn(num_images, 3, img_size, img_size).to(device)

    if labels.ndim == 1:
        labels_one_hot = torch.zeros(num_images, num_classes, device=device)
        labels_one_hot[torch.arange(num_images), labels] = 1
        labels = labels_one_hot
    else:
        labels = labels.to(device)

    # Move the schedule once instead of once per step; a no-op when it is
    # already on ``device``.
    betas = model.betas.to(device)
    alphas = model.alphas.to(device)
    alpha_bars = model.alpha_bars.to(device)

    for t in reversed(range(timesteps)):
        if cancel_event.is_set():
            return None

        t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
        predicted_noise = model.model(x_t, labels, t_tensor)

        beta_t = betas[t]
        alpha_t = alphas[t]
        alpha_bar_t = alpha_bars[t]

        posterior_mean = (1 / torch.sqrt(alpha_t)) * (
            x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise
        )
        x_t = posterior_mean
        if t > 0:
            # Variance is beta_t; the final step (t == 0) adds no noise.
            x_t = x_t + torch.sqrt(beta_t) * torch.randn_like(x_t)

        if progress_callback:
            progress_callback((timesteps - t) / timesteps)

    x_0 = torch.clamp(x_t, -1., 1.)
    # NOTE(review): the result is clamped to [-1, 1] and then de-normalised
    # with ImageNet mean/std - confirm this matches the training transform.
    norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
    norm_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
    x_0 = norm_std * x_0 + norm_mean
    x_0 = torch.clamp(x_0, 0., 1.)
    return x_0
def load_model(model_path, device):
    """Build the diffusion model and try to restore the UNet weights from disk.

    The checkpoint may be either a bare state dict or a dict that stores the
    state dict under ``'model_state_dict'``. Any failure while reading or
    applying the checkpoint is reported on stdout and the model keeps its
    randomly initialised weights.

    Args:
        model_path: Path of the checkpoint file.
        device: Device the model is placed on and the checkpoint is mapped to.

    Returns:
        The ``DiffusionModel`` in eval mode.
    """
    diffusion_model = DiffusionModel(
        UNet(num_classes=NUM_CLASSES).to(device),
        timesteps=TIMESTEPS,
    ).to(device)

    try:
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(model_path, map_location=device)
        wrapped = 'model_state_dict' in checkpoint
        state_dict = checkpoint['model_state_dict'] if wrapped else checkpoint
        diffusion_model.model.load_state_dict(state_dict)
        print(f"Successfully loaded model from {model_path}")
    except Exception as e:
        # Deliberate best-effort: the app still starts without a checkpoint.
        print(f"Error loading model: {e}")
        print("Using randomly initialized weights")

    diffusion_model.eval()
    return diffusion_model
def cancel_generation():
    """Raise the global cancellation flag and return a status message."""
    status = "Generation cancelled"
    cancel_event.set()
    return status
def generate_single_image(label_str):
    """Generate one synthetic image for the given condition.

    Args:
        label_str: ``'Pneumonia'`` or ``'Pneumothorax'``.

    Returns:
        A PIL image, or ``None`` if generation was cancelled while running.

    Raises:
        gr.Error: If ``label_str`` is not a known condition.
    """
    label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
    try:
        label_index = label_map[label_str]
    except KeyError:
        raise gr.Error(f"Invalid label '{label_str}'. Please select either 'Pneumonia' or 'Pneumothorax'.") from None

    # Start from a clean flag (as generate_images does); a stale cancel from
    # an earlier run would otherwise make sample() return None immediately.
    cancel_event.clear()

    labels = torch.zeros(1, NUM_CLASSES, device=device)
    labels[0, label_index] = 1

    # sample() is already wrapped in torch.no_grad().
    generated_image = sample(
        model=loaded_model,
        num_images=1,
        timesteps=TIMESTEPS,
        img_size=IMG_SIZE,
        num_classes=NUM_CLASSES,
        labels=labels,
        device=device
    )
    if generated_image is None:
        # Cancelled mid-run: nothing to convert.
        return None

    # (1, C, H, W) in [0, 1] -> (H, W, C) uint8 for PIL.
    img_np = generated_image.squeeze(0).cpu().permute(1, 2, 0).numpy()
    img_np = np.clip(img_np, 0, 1)
    img_pil = Image.fromarray((img_np * 255).astype(np.uint8))
    return img_pil
def generate_images(label_str, num_images, progress=gr.Progress()):
    """Gradio handler: generate ``num_images`` images for the chosen condition.

    Args:
        label_str: ``'Pneumonia'`` or ``'Pneumothorax'``.
        num_images: Requested image count, 1 to 10 (the slider may deliver it
            as a float, so it is coerced to ``int``).
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        ``(single_image, gallery_images)``: for one image, that image and a
        one-element list; for several, ``None`` and the list; ``(None, None)``
        if the run was cancelled.

    Raises:
        gr.Error: On invalid input, GPU out-of-memory, or any other failure
            during generation.
    """
    cancel_event.clear()

    # Input validation
    if num_images < 1 or num_images > 10:
        raise gr.Error("Number of images must be between 1 and 10")
    num_images = int(num_images)  # torch.zeros() rejects float sizes

    label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
    if label_str not in label_map:
        raise gr.Error("Invalid condition selected")

    labels = torch.zeros(num_images, NUM_CLASSES, device=device)
    labels[:, label_map[label_str]] = 1

    def progress_callback(progress_val):
        progress(progress_val, desc="Generating...")

    def to_pil(img):
        # (C, H, W) in [0, 1] -> (H, W, C) uint8 for PIL.
        img_np = np.clip(img.cpu().permute(1, 2, 0).numpy(), 0, 1)
        return Image.fromarray((img_np * 255).astype(np.uint8))

    try:
        # sample() is already wrapped in torch.no_grad() and polls
        # cancel_event before every step, returning None when it is set.
        images = sample(
            model=loaded_model,
            num_images=num_images,
            timesteps=TIMESTEPS,
            img_size=IMG_SIZE,
            num_classes=NUM_CLASSES,
            labels=labels,
            device=device,
            progress_callback=progress_callback
        )
        # The second check covers a cancel that arrives during the final step.
        if images is None or cancel_event.is_set():
            return None, None

        # Process all generated images
        processed_images = [to_pil(img) for img in images]

        # Return both single image and gallery based on count
        if num_images == 1:
            return processed_images[0], processed_images
        return None, processed_images
    except torch.cuda.OutOfMemoryError as e:
        torch.cuda.empty_cache()
        raise gr.Error("Out of GPU memory - try generating fewer images") from e
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Generation failed: {str(e)}") from e
    finally:
        torch.cuda.empty_cache()
# Load model
# Runs at import time so the Gradio handlers can use ``loaded_model`` directly.
MODEL_DIR = "models"
MODEL_NAME = "diffusion_unet_xray.pth"
model_path = os.path.join(MODEL_DIR, MODEL_NAME)
print("Loading model...")
loaded_model = load_model(model_path, device)
# NOTE: load_model() catches checkpoint errors and falls back to random
# weights, so this message is printed even when the checkpoint was not loaded.
print("Model loaded successfully!")
# Unified Gradio UI
# Layout: a controls column (condition, image count, buttons) next to an
# output column that shows either a single image or a gallery.
with gr.Blocks(theme=gr.themes.Soft(
    primary_hue="violet",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Poppins")],
    text_size="md"
)) as demo:
    gr.Markdown("""
<center>
<h1>Synthetic X-ray Generator</h1>
<p><em>Generate synthetic chest X-rays conditioned on pathology</em></p>
</center>
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Choices must match the keys of label_map in generate_images().
            condition = gr.Dropdown(
                ["Pneumonia", "Pneumothorax"],
                label="Select Condition",
                value="Pneumonia",
                interactive=True
            )
            # Range matches the 1..10 validation in generate_images().
            num_images = gr.Slider(
                1, 10, value=1, step=1,
                label="Number of Images",
                interactive=True
            )
            with gr.Row():
                submit_btn = gr.Button("Generate", variant="primary")
                cancel_btn = gr.Button("Cancel", variant="stop")
            gr.Markdown("""
<div style="text-align: center; margin-top: 10px;">
<small>Note: Generation may take several seconds per image</small>
</div>
""")
        with gr.Column(scale=2):
            # Unified output display that adapts to single/batch
            with gr.Tabs():
                with gr.TabItem("Output", id="output_tab"):
                    single_image = gr.Image(
                        label="Generated X-ray",
                        height=400,
                        visible=True
                    )
                    # Hidden initially because the slider starts at 1 image.
                    gallery = gr.Gallery(
                        label="Generated X-rays",
                        columns=3,
                        height="auto",
                        object_fit="contain",
                        visible=False
                    )

    def update_ui_based_on_count(num_images):
        # Show the single-image view for exactly one image, the gallery otherwise.
        if num_images == 1:
            return {
                single_image: gr.update(visible=True),
                gallery: gr.update(visible=False)
            }
        else:
            return {
                single_image: gr.update(visible=False),
                gallery: gr.update(visible=True)
            }

    # Toggle the visible output widget whenever the slider value changes.
    num_images.change(
        fn=update_ui_based_on_count,
        inputs=num_images,
        outputs=[single_image, gallery]
    )
    # generate_images() returns (single_image_value, gallery_value).
    submit_btn.click(
        fn=generate_images,
        inputs=[condition, num_images],
        outputs=[single_image, gallery]
    )
    # cancel_generation() sets the global cancel flag; its return value is
    # not bound to any output component.
    cancel_btn.click(
        fn=cancel_generation,
        outputs=None
    )
# Custom page styling, attached to the Blocks app after construction.
demo.css = """
.gradio-container {
background: linear-gradient(135deg, #f5f7fa 0%, #e4e8f0 100%);
}
.gallery-container {
background-color: white !important;
}
"""
# Start the server only when run as a script; listens on all interfaces, port 7860.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)