# Source: Hugging Face Space file "app.py" by Vedansh-7
# (commit dd9af11, "Update app.py", 14.5 kB; page links: raw, history, blame).
import math
import os
import traceback
from threading import Event

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageFilter
# --- Constants ---
IMG_SIZE = 128  # generated images are IMG_SIZE x IMG_SIZE pixels
TIMESTEPS = 500  # length of the diffusion noise schedule
NUM_CLASSES = 2  # number of conditioning labels

# Set by the Cancel button; the sampler checks it between denoising steps.
cancel_event = Event()

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Model Definitions ---
class SinusoidalPositionEmbeddings(nn.Module):
    """Map a batch of scalar timesteps to fixed sinusoidal feature vectors.

    For ``dim // 2`` geometrically spaced frequencies ``f_i``, the output row
    for time ``t`` is ``[sin(t * f_0), ..., sin(t * f_k), cos(t * f_0), ...,
    cos(t * f_k)]``. The last dimension is therefore ``2 * (dim // 2)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        n_freqs = dim // 2
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, dtype=torch.float32) * -log_step)
        # A buffer follows the module across devices and is saved in the
        # state dict under the name "embeddings".
        self.register_buffer('embeddings', freqs)

    def forward(self, time):
        """Embed a 1-D tensor of timesteps into a (batch, 2 * (dim // 2)) tensor."""
        freqs = self.embeddings.to(time.device)
        angles = time.float().unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class UNet(nn.Module):
    """Time- and class-conditioned U-Net that predicts noise for diffusion.

    Conditioning is built as follows. The timestep goes through a sinusoidal
    embedding and a small MLP. The class index, taken as the argmax of
    ``labels`` along dim 1, goes through a learned embedding. The two vectors
    are concatenated into ``2 * time_dim`` channels. That vector is broadcast
    over every spatial position and appended to the feature map before each
    convolution stage after the first.

    Spatial input sizes must be divisible by 8. There are three 2x poolings,
    and the upsampled maps are concatenated with same-size skip features.
    """

    def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
        super().__init__()
        self.num_classes = num_classes
        cond_ch = time_dim * 2  # channels added by the broadcast [time | label] vector
        self.label_embedding = nn.Embedding(num_classes, time_dim)
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.ReLU(),
            nn.Linear(time_dim, time_dim),
        )
        # Encoder
        self.inc = self.double_conv(in_channels, 64)
        self.down1 = self.down(64 + cond_ch, 128)
        self.down2 = self.down(128 + cond_ch, 256)
        self.down3 = self.down(256 + cond_ch, 512)
        self.bottleneck = self.double_conv(512 + cond_ch, 1024)
        # Decoder: each stage upsamples, then fuses [upsampled | skip | condition].
        self.up1 = nn.ConvTranspose2d(1024, 256, kernel_size=2, stride=2)
        self.upconv1 = self.double_conv(256 + 256 + cond_ch, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upconv2 = self.double_conv(128 + 128 + cond_ch, 128)
        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.upconv3 = self.double_conv(64 + 64 + cond_ch, 64)
        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)

    def double_conv(self, in_channels, out_channels):
        """Two 3x3 same-padding convolutions, each followed by an in-place ReLU."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def down(self, in_channels, out_channels):
        """Halve the spatial size with 2x2 max-pooling, then apply double_conv."""
        pool = nn.MaxPool2d(2)
        return nn.Sequential(pool, self.double_conv(in_channels, out_channels))

    @staticmethod
    def _with_condition(features, cond):
        """Append the (B, C, 1, 1) conditioning vector at every spatial position."""
        height, width = features.shape[-2], features.shape[-1]
        return torch.cat([features, cond.repeat(1, 1, height, width)], dim=1)

    def forward(self, x, labels, time):
        """Predict noise for ``x``, given class scores ``labels`` (B, num_classes) and timesteps ``time`` (B,)."""
        class_ids = labels.argmax(dim=1)
        cond = torch.cat([self.time_mlp(time), self.label_embedding(class_ids)], dim=1)
        cond = cond[:, :, None, None]

        skip1 = self.inc(x)
        skip2 = self.down1(self._with_condition(skip1, cond))
        skip3 = self.down2(self._with_condition(skip2, cond))
        deepest = self.down3(self._with_condition(skip3, cond))
        h = self.bottleneck(self._with_condition(deepest, cond))

        decoder = (
            (self.up1, self.upconv1, skip3),
            (self.up2, self.upconv2, skip2),
            (self.up3, self.upconv3, skip1),
        )
        for upsample, fuse, skip in decoder:
            h = upsample(h)
            h = fuse(self._with_condition(torch.cat([h, skip], dim=1), cond))
        return self.outc(h)
class DiffusionModel(nn.Module):
    """Wrap a noise-prediction network with a fixed noise schedule and a sampler.

    Args:
        model: network called as ``model(x_t, one_hot_labels, t)``. It returns a
            noise estimate with the same shape as ``x_t``.
        timesteps: number of entries in the noise schedule.
        time_dim: unused; kept so existing constructor calls keep working.
    """

    def __init__(self, model, timesteps=TIMESTEPS, time_dim=256):
        super().__init__()
        self.model = model
        self.timesteps = timesteps
        # Linear beta schedule raised to the power 1.5 (keeps early betas small).
        beta_start = 0.0001
        beta_end = 0.02
        self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32) ** 1.5
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod))

    @torch.no_grad()
    def sample(self, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
        """Generate images by iteratively denoising scaled Gaussian noise.

        Args:
            num_images: batch size to generate.
            timesteps: number of denoising steps. It is clamped to
                ``self.timesteps`` because the schedule has no entries beyond that.
            img_size: height and width of the generated images.
            num_classes: number of classes, used when ``labels`` holds indices.
            labels: either one-hot ``(num_images, num_classes)`` or 1-D class indices.
            device: device to run sampling on.
            progress_callback: optional ``f(fraction_done)``, called after each step.

        Returns:
            A channels-last ``(num_images, img_size, img_size, 3)`` tensor with values
            in [0, 1], or ``None`` if ``cancel_event`` was set during sampling.
        """
        # Indexing alphas_cumprod past its end would raise IndexError.
        timesteps = min(int(timesteps), self.timesteps)

        # Start from Gaussian noise with reduced scale.
        x_t = torch.randn((num_images, 3, img_size, img_size), device=device) * 0.7

        labels = labels.to(device)
        if labels.ndim == 1:
            labels_one_hot = torch.zeros(num_images, num_classes, device=device)
            labels_one_hot[torch.arange(num_images, device=device), labels] = 1
            labels = labels_one_hot

        for t in reversed(range(timesteps)):
            if cancel_event.is_set():
                return None
            t_tensor = torch.full((num_images,), t, device=device, dtype=torch.long)
            pred_noise = self.model(x_t, labels, t_tensor.float())

            alpha_bar_t = self.alphas_cumprod[t]
            alpha_bar_t_prev = self.alphas_cumprod[t - 1] if t > 0 else torch.tensor(1.0)

            # DDIM-style update: estimate x0, then re-noise toward step t-1.
            pred_x0 = (x_t - torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_bar_t)
            pred_dir = torch.sqrt(1 - alpha_bar_t_prev) * pred_noise
            x_t = torch.sqrt(alpha_bar_t_prev) * pred_x0 + pred_dir
            if t > 0:
                # Extra stochastic noise that shrinks linearly toward the final step.
                x_t = x_t + torch.randn_like(x_t) * (0.3 * (t / timesteps))

            if progress_callback:
                progress_callback((timesteps - t) / timesteps)

        x_t = torch.clamp(x_t, -1, 1)
        x_t = (x_t + 1) / 2  # [-1, 1] -> [0, 1]
        return self._post_process(x_t)

    def _post_process(self, image_tensor):
        """Increase contrast and apply a mild 3x3 blur to a batch of images.

        Args:
            image_tensor: channels-first ``(N, 3, H, W)`` batch with values in [0, 1].

        Returns:
            Channels-last ``(N, H, W, 3)`` batch clamped to [0, 1].
        """
        # Stretch contrast around each image's own mean. A batch-wide mean would
        # make each image's result depend on the other images in the batch.
        mean_val = image_tensor.mean(dim=(1, 2, 3), keepdim=True)
        image_tensor = (image_tensor - mean_val) * 1.2 + mean_val

        # Depthwise blur: one copy of the (sum = 1) kernel per RGB channel.
        blur_kernel = getattr(self, '_blur_kernel', None)
        if blur_kernel is None:
            blur_kernel = torch.tensor([
                [0.05, 0.1, 0.05],
                [0.1, 0.4, 0.1],
                [0.05, 0.1, 0.05]
            ], dtype=torch.float32).view(1, 1, 3, 3).repeat(3, 1, 1, 1)
        blur_kernel = blur_kernel.to(device=image_tensor.device, dtype=image_tensor.dtype)
        self._blur_kernel = blur_kernel

        # The input is already NCHW, which is the layout conv2d expects.
        image_tensor = torch.nn.functional.conv2d(image_tensor, blur_kernel, padding=1, groups=3)

        # Channels-last so each image converts directly to a PIL RGB image.
        return torch.clamp(image_tensor, 0, 1).permute(0, 2, 3, 1).contiguous()
def load_model(model_path, device):
    """Build the U-Net and diffusion wrapper, then load trained U-Net weights.

    Args:
        model_path: path to a ``torch.save`` checkpoint. Accepted forms are a
            state dict, a dict with a ``'model_state_dict'`` entry, or a pickled
            ``nn.Module``. Keys may carry the ``'model.'`` prefix used when the
            U-Net is saved as part of a ``DiffusionModel``.
        device: torch device to place the model on.

    Returns:
        The ``DiffusionModel`` in eval mode.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
        ValueError: if the checkpoint cannot be loaded or has no weights that
            match the U-Net.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model weights not found at {model_path}")

    unet = UNet(num_classes=NUM_CLASSES).to(device)
    diffusion_model = DiffusionModel(unet).to(device)
    try:
        # NOTE(security): torch.load unpickles arbitrary objects; only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=device)
        if isinstance(checkpoint, nn.Module):
            checkpoint = checkpoint.state_dict()
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        prefix = 'model.'
        if any(k.startswith(prefix) for k in state_dict):
            # A DiffusionModel state dict nests the U-Net under "model." and also
            # holds un-prefixed schedule buffers. Keep only the U-Net entries.
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

        result = unet.load_state_dict(state_dict, strict=False)
        # strict=False would otherwise hide a checkpoint that loads nothing at all.
        if not set(unet.state_dict()) - set(result.missing_keys):
            raise ValueError("checkpoint contains no parameters matching the U-Net")
        if result.missing_keys:
            print(f"Warning: keys missing from checkpoint: {result.missing_keys}")
        if result.unexpected_keys:
            print(f"Warning: unexpected keys in checkpoint: {result.unexpected_keys}")
        print("Model loaded successfully")

        # Smoke-test a forward pass (no autograd graph needed).
        unet.eval()
        with torch.no_grad():
            test_input = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
            test_labels = torch.zeros(1, NUM_CLASSES).to(device)
            test_time = torch.tensor([1]).to(device)
            output = unet(test_input, test_labels, test_time)
        print(f"Model test output shape: {output.shape}")
    except Exception as e:
        traceback.print_exc()
        raise ValueError(f"Error loading model: {str(e)}") from e
    diffusion_model.eval()
    return diffusion_model
MODEL_NAME = "model_weights.pth"
model_path = MODEL_NAME

print("Loading model...")
try:
    loaded_model = load_model(model_path, device)
    print("Model loaded successfully!")
except Exception as e:
    print(f"Failed to load model: {e}")
    # Fall back to randomly initialised weights so the UI still starts. Outputs
    # from this model are not expected to resemble real X-rays.
    print("Creating dummy model for demonstration")
    loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES)).to(device)
    # Match load_model, which always returns a model in eval mode.
    loaded_model.eval()
def cancel_generation():
    """Ask any running generation to stop, and return a status message."""
    # DiffusionModel.sample checks this flag at the start of every denoising step.
    cancel_event.set()
    status = "Generation cancelled"
    return status
def generate_images(label_str, num_images, progress=gr.Progress()):
    """Generate synthetic chest X-rays for the selected condition.

    Args:
        label_str: condition name; either "Pneumonia" or "Pneumothorax".
        num_images: number of images to generate (1-10). Converted to int
            because slider values may arrive as floats.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        ``(single_image, gallery_images)``. For one image this is the PIL image
        and a one-element list. Otherwise it is ``(None, list_of_images)``. It is
        ``(None, None)`` if the sampler stopped because of cancellation.

    Raises:
        gr.Error: on invalid input, on user cancellation, or if generation fails.
    """
    cancel_event.clear()
    try:
        num_images = int(num_images)
    except (TypeError, ValueError):
        raise gr.Error("Number of images must be between 1 and 10") from None
    if not 1 <= num_images <= 10:
        raise gr.Error("Number of images must be between 1 and 10")

    label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
    if label_str not in label_map:
        raise gr.Error("Invalid condition selected")
    labels = torch.zeros(num_images, NUM_CLASSES, device=device)
    labels[:, label_map[label_str]] = 1

    def progress_callback(progress_val):
        progress(progress_val, desc="Generating...")
        if cancel_event.is_set():
            raise gr.Error("Generation was cancelled by user")

    try:
        with torch.no_grad():
            images = loaded_model.sample(
                num_images=num_images,
                # Must not exceed the model's noise-schedule length (TIMESTEPS);
                # larger values index past the end of the schedule.
                timesteps=TIMESTEPS,
                img_size=IMG_SIZE,
                num_classes=NUM_CLASSES,
                labels=labels,
                device=device,
                progress_callback=progress_callback
            )
        if images is None:
            return None, None

        processed_images = []
        for img in images:
            img_np = img.detach().cpu().numpy()
            # PIL expects channels-last (H, W, 3); transpose if given (3, H, W).
            if img_np.ndim == 3 and img_np.shape[0] == 3 and img_np.shape[-1] != 3:
                img_np = img_np.transpose(1, 2, 0)
            img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
            pil_img = Image.fromarray(img_np)
            # Additional PIL-side smoothing.
            pil_img = pil_img.filter(ImageFilter.SMOOTH_MORE)
            processed_images.append(pil_img)

        if num_images == 1:
            return processed_images[0], processed_images
        return None, processed_images
    except gr.Error:
        # Already user-facing (e.g. cancellation); do not wrap it as a failure.
        raise
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Generation failed: {str(e)}") from e
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Page styling, passed to gr.Blocks through its documented `css` parameter.
_APP_CSS = """
.gradio-container {
background: linear-gradient(135deg, #f5f7fa 0%, #e4e8f0 100%);
}
.gallery-container {
background-color: white !important;
}
"""

# Gradio UI
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="violet",
        neutral_hue="slate",
        font=[gr.themes.GoogleFont("Poppins")],
        text_size="md"
    ),
    css=_APP_CSS,
) as demo:
    gr.Markdown("""
<center>
<h1>Synthetic X-ray Generator</h1>
<p><em>Generate synthetic chest X-rays conditioned on pathology</em></p>
</center>
""")
    with gr.Row():
        # Left column: generation controls.
        with gr.Column(scale=1):
            condition = gr.Dropdown(
                ["Pneumonia", "Pneumothorax"],
                label="Select Condition",
                value="Pneumonia",
                interactive=True
            )
            num_images = gr.Slider(
                1, 10, value=1, step=1,
                label="Number of Images",
                interactive=True
            )
            with gr.Row():
                submit_btn = gr.Button("Generate", variant="primary")
                cancel_btn = gr.Button("Cancel", variant="stop")
            gr.Markdown("""
<div style="text-align: center; margin-top: 10px;">
<small>Note: Generation may take several seconds per image</small>
</div>
""")
        # Right column: a single-image view or a gallery, depending on the count.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Output", id="output_tab"):
                    single_image = gr.Image(
                        label="Generated X-ray",
                        height=400,
                        visible=True
                    )
                    gallery = gr.Gallery(
                        label="Generated X-rays",
                        columns=3,
                        height="auto",
                        object_fit="contain",
                        visible=False
                    )

    def update_ui_based_on_count(num_images):
        """Show the single-image view for one image and the gallery otherwise."""
        if num_images == 1:
            return {
                single_image: gr.update(visible=True),
                gallery: gr.update(visible=False)
            }
        return {
            single_image: gr.update(visible=False),
            gallery: gr.update(visible=True)
        }

    num_images.change(
        fn=update_ui_based_on_count,
        inputs=num_images,
        outputs=[single_image, gallery]
    )
    submit_btn.click(
        fn=generate_images,
        inputs=[condition, num_images],
        outputs=[single_image, gallery]
    )
    cancel_btn.click(
        fn=cancel_generation,
        outputs=None
    )
if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")