import torch import torch.nn as nn import gradio as gr from PIL import Image import numpy as np import math import os from threading import Event import traceback # Constants IMG_SIZE = 128 TIMESTEPS = 500 NUM_CLASSES = 2 # Global Cancellation Flag cancel_event = Event() # Device Configuration device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # --- Model Definitions --- class SinusoidalPositionEmbeddings(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim half_dim = dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) self.register_buffer('embeddings', emb) def forward(self, time): embeddings = self.embeddings.to(time.device) embeddings = time.float()[:, None] * embeddings[None, :] return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1) class UNet(nn.Module): def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256): super().__init__() self.num_classes = num_classes self.label_embedding = nn.Embedding(num_classes, time_dim) self.time_mlp = nn.Sequential( SinusoidalPositionEmbeddings(time_dim), nn.Linear(time_dim, time_dim), nn.ReLU(), nn.Linear(time_dim, time_dim) ) self.inc = self.double_conv(in_channels, 64) self.down1 = self.down(64 + time_dim * 2, 128) self.down2 = self.down(128 + time_dim * 2, 256) self.down3 = self.down(256 + time_dim * 2, 512) self.bottleneck = self.double_conv(512 + time_dim * 2, 1024) self.up1 = nn.ConvTranspose2d(1024, 256, kernel_size=2, stride=2) self.upconv1 = self.double_conv(256 + 256 + time_dim * 2, 256) self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2) self.upconv2 = self.double_conv(128 + 128 + time_dim * 2, 128) self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2) self.upconv3 = self.double_conv(64 + 64 + time_dim * 2, 64) self.outc = nn.Conv2d(64, out_channels, kernel_size=1) def double_conv(self, in_channels, out_channels): return nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(inplace=True) ) def down(self, in_channels, out_channels): return nn.Sequential( nn.MaxPool2d(2), self.double_conv(in_channels, out_channels) ) def forward(self, x, labels, time): label_indices = torch.argmax(labels, dim=1) label_emb = self.label_embedding(label_indices) t_emb = self.time_mlp(time) combined_emb = torch.cat([t_emb, label_emb], dim=1) combined_emb = combined_emb.unsqueeze(-1).unsqueeze(-1) x1 = self.inc(x) x1_cat = torch.cat([x1, combined_emb.repeat(1, 1, x1.shape[-2], x1.shape[-1])], dim=1) x2 = self.down1(x1_cat) x2_cat = torch.cat([x2, combined_emb.repeat(1, 1, x2.shape[-2], x2.shape[-1])], dim=1) x3 = self.down2(x2_cat) x3_cat = torch.cat([x3, combined_emb.repeat(1, 1, x3.shape[-2], x3.shape[-1])], dim=1) x4 = self.down3(x3_cat) x4_cat = torch.cat([x4, combined_emb.repeat(1, 1, x4.shape[-2], x4.shape[-1])], dim=1) x5 = self.bottleneck(x4_cat) x = self.up1(x5) x = torch.cat([x, x3], dim=1) x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1) x = self.upconv1(x) x = self.up2(x) x = torch.cat([x, x2], dim=1) x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1) x = self.upconv2(x) x = self.up3(x) x = torch.cat([x, x1], dim=1) x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1) x = self.upconv3(x) output = self.outc(x) return output class DiffusionModel(nn.Module): def __init__(self, model, timesteps=TIMESTEPS, time_dim=256): super().__init__() self.model = model self.timesteps = timesteps # More conservative noise schedule scale = 1000 / timesteps beta_start = 0.0001 beta_end = 0.02 self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)**1.5 self.alphas = 1. - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod)) self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod)) @torch.no_grad() def sample(self, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None): # Initialize with reduced noise scale x_t = torch.randn((num_images, 3, img_size, img_size), device=device) * 0.7 # Convert labels if needed if labels.ndim == 1: labels_one_hot = torch.zeros(num_images, num_classes, device=device) labels_one_hot[torch.arange(num_images), labels] = 1 labels = labels_one_hot for t in reversed(range(timesteps)): if cancel_event.is_set(): return None t_tensor = torch.full((num_images,), t, device=device, dtype=torch.long) # Predict noise with model pred_noise = self.model(x_t, labels, t_tensor.float()) # Get current alpha values alpha_t = self.alphas[t] alpha_bar_t = self.alphas_cumprod[t] alpha_bar_t_prev = self.alphas_cumprod[t-1] if t > 0 else torch.tensor(1.0) # Calculate predicted x0 with more stable equations pred_x0 = (x_t - torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_bar_t) # Direction pointing to x_t with reduced noise impact pred_dir = torch.sqrt(1 - alpha_bar_t_prev) * pred_noise # Dynamic noise scaling based on timestep if t > 0: noise_scale = 0.3 * (t / timesteps) # Reduce noise as we get closer to final image noise = torch.randn_like(x_t) * noise_scale else: noise = torch.zeros_like(x_t) # Update x_t with more stable combination x_t = torch.sqrt(alpha_bar_t_prev) * pred_x0 + pred_dir + noise # Progress callback if progress_callback: progress_callback((timesteps - t) / timesteps) # Enhanced normalization with contrast adjustment x_t = torch.clamp(x_t, -1, 1) x_t = (x_t + 1) / 2 # Scale to [0,1] # Post-processing directly in the tensor x_t = self._post_process(x_t) return x_t def _post_process(self, image_tensor): """Apply simple post-processing to reduce noise""" # Contrast adjustment mean_val = image_tensor.mean() image_tensor = (image_tensor - mean_val) * 1.2 + mean_val # Mild Gaussian blur (implemented as depthwise convolution) if hasattr(self, '_blur_kernel'): blur_kernel = self._blur_kernel.to(image_tensor.device) else: blur_kernel = torch.tensor([ [0.05, 0.1, 0.05], [0.1, 0.4, 0.1], [0.05, 0.1, 0.05] ], dtype=torch.float32).view(1, 1, 3, 3).repeat(3, 1, 1, 1) self._blur_kernel = blur_kernel # Apply blur to each channel padding = (1, 1, 1, 1) image_tensor = torch.nn.functional.conv2d( image_tensor.permute(0, 3, 1, 2), # NHWC to NCHW blur_kernel, padding=1, groups=3 ).permute(0, 2, 3, 1) # Back to NHWC return torch.clamp(image_tensor, 0, 1) def load_model(model_path, device): unet = UNet(num_classes=NUM_CLASSES).to(device) diffusion_model = DiffusionModel(unet).to(device) if os.path.exists(model_path): try: checkpoint = torch.load(model_path, map_location=device) # Handle both full model and state_dict loading if 'model_state_dict' in checkpoint: state_dict = checkpoint['model_state_dict'] else: state_dict = checkpoint # Handle both prefixed and non-prefixed state dicts if all(k.startswith('model.') for k in state_dict.keys()): state_dict = {k[6:]: v for k, v in state_dict.items()} unet.load_state_dict(state_dict, strict=False) print("Model loaded successfully") # Verify model loading test_input = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device) test_labels = torch.zeros(1, NUM_CLASSES).to(device) test_time = torch.tensor([1]).to(device) output = unet(test_input, test_labels, test_time) print(f"Model test output shape: {output.shape}") except Exception as e: traceback.print_exc() raise ValueError(f"Error loading model: {str(e)}") else: raise FileNotFoundError(f"Model weights not found at {model_path}") diffusion_model.eval() return diffusion_model MODEL_NAME = "model_weights.pth" model_path = MODEL_NAME print("Loading model...") try: loaded_model = load_model(model_path, device) print("Model loaded successfully!") except Exception as e: print(f"Failed to load model: {e}") # Create a dummy model if loading fails print("Creating dummy model for demonstration") loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES)).to(device) def cancel_generation(): cancel_event.set() return "Generation cancelled" def generate_images(label_str, num_images, progress=gr.Progress()): global loaded_model cancel_event.clear() if num_images < 1 or num_images > 10: raise gr.Error("Number of images must be between 1 and 10") label_map = {'Pneumonia': 0, 'Pneumothorax': 1} if label_str not in label_map: raise gr.Error("Invalid condition selected") labels = torch.zeros(num_images, NUM_CLASSES, device=device) labels[:, label_map[label_str]] = 1 try: def progress_callback(progress_val): progress(progress_val, desc="Generating...") if cancel_event.is_set(): raise gr.Error("Generation was cancelled by user") with torch.no_grad(): images = loaded_model.sample( num_images=num_images, timesteps=int(TIMESTEPS * 1.5), # More timesteps for cleaner images img_size=IMG_SIZE, num_classes=NUM_CLASSES, labels=labels, device=device, progress_callback=progress_callback ) if images is None: return None, None processed_images = [] for img in images: img_np = img.cpu().numpy() # Convert to PIL with enhanced contrast img_np = (img_np * 255).clip(0, 255).astype(np.uint8) pil_img = Image.fromarray(img_np) # Apply additional PIL-based enhancements pil_img = pil_img.filter(ImageFilter.SMOOTH_MORE) processed_images.append(pil_img) if num_images == 1: return processed_images[0], processed_images else: return None, processed_images except Exception as e: traceback.print_exc() raise gr.Error(f"Generation failed: {str(e)}") finally: torch.cuda.empty_cache() # Gradio UI with gr.Blocks(theme=gr.themes.Soft( primary_hue="violet", neutral_hue="slate", font=[gr.themes.GoogleFont("Poppins")], text_size="md" )) as demo: gr.Markdown("""
Generate synthetic chest X-rays conditioned on pathology