# Source: Hugging Face Space (page status when captured: Sleeping).
import math
import os
import traceback
from threading import Event

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from PIL import ImageFilter
# --- Configuration ---
IMG_SIZE = 128     # side length passed to the sampler as img_size
TIMESTEPS = 500    # default length of the diffusion noise schedule
NUM_CLASSES = 2    # number of condition labels the network embeds

# Shared flag: set by the Cancel button handler, polled by the sampling loop.
cancel_event = Event()

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# --- Model Definitions --- | |
class SinusoidalPositionEmbeddings(nn.Module):
    """Map a batch of scalar timesteps to fixed sine/cosine features.

    The frequency table is computed once in the constructor and stored as the
    buffer ``embeddings``; ``forward`` multiplies each timestep by every
    frequency and concatenates the sines followed by the cosines.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        n_freqs = dim // 2
        # Geometric progression of frequencies from 1 down to 1/10000.
        log_step = math.log(10000) / (n_freqs - 1)
        positions = torch.arange(n_freqs, dtype=torch.float32)
        self.register_buffer('embeddings', torch.exp(positions * -log_step))

    def forward(self, time):
        """Return a ``(len(time), 2 * (dim // 2))`` tensor of sin/cos features."""
        freqs = self.embeddings.to(time.device)
        angles = time.float().unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)
class UNet(nn.Module):
    """Encoder/decoder network that predicts noise from an image, a label and a timestep.

    The timestep embedding and the label embedding are concatenated into one
    conditioning vector, which is broadcast over the spatial grid and appended
    as extra channels before every stage after the first convolution block.
    """

    def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
        super().__init__()
        self.num_classes = num_classes
        # Conditioning channels added per stage: time embedding + label embedding.
        cond_dim = time_dim * 2

        self.label_embedding = nn.Embedding(num_classes, time_dim)
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.ReLU(),
            nn.Linear(time_dim, time_dim),
        )

        # Encoder
        self.inc = self.double_conv(in_channels, 64)
        self.down1 = self.down(64 + cond_dim, 128)
        self.down2 = self.down(128 + cond_dim, 256)
        self.down3 = self.down(256 + cond_dim, 512)
        self.bottleneck = self.double_conv(512 + cond_dim, 1024)

        # Decoder: each stage upsamples, joins the matching skip, then refines.
        self.up1 = nn.ConvTranspose2d(1024, 256, kernel_size=2, stride=2)
        self.upconv1 = self.double_conv(256 + 256 + cond_dim, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upconv2 = self.double_conv(128 + 128 + cond_dim, 128)
        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.upconv3 = self.double_conv(64 + 64 + cond_dim, 64)
        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)

    def double_conv(self, in_channels, out_channels):
        """Build two 3x3 same-padding convolutions, each followed by an in-place ReLU."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def down(self, in_channels, out_channels):
        """Build a 2x max-pool followed by a double convolution."""
        pool = nn.MaxPool2d(2)
        convs = self.double_conv(in_channels, out_channels)
        return nn.Sequential(pool, convs)

    @staticmethod
    def _append_condition(features, cond):
        """Broadcast ``cond`` (N, C, 1, 1) over the grid of ``features`` and concatenate on channels."""
        height, width = features.shape[-2], features.shape[-1]
        return torch.cat([features, cond.expand(-1, -1, height, width)], dim=1)

    def forward(self, x, labels, time):
        """Return the network output for images ``x`` given one-hot ``labels`` and ``time``."""
        # One-hot (or score) rows are reduced to class indices for the embedding lookup.
        class_ids = torch.argmax(labels, dim=1)
        cond = torch.cat([self.time_mlp(time), self.label_embedding(class_ids)], dim=1)
        cond = cond[:, :, None, None]

        # Encoder path; the un-conditioned activations are kept as skip connections.
        skip1 = self.inc(x)
        skip2 = self.down1(self._append_condition(skip1, cond))
        skip3 = self.down2(self._append_condition(skip2, cond))
        deepest = self.down3(self._append_condition(skip3, cond))
        hidden = self.bottleneck(self._append_condition(deepest, cond))

        # Decoder path.
        decoder_stages = (
            (self.up1, skip3, self.upconv1),
            (self.up2, skip2, self.upconv2),
            (self.up3, skip1, self.upconv3),
        )
        for upsample, skip, refine in decoder_stages:
            hidden = torch.cat([upsample(hidden), skip], dim=1)
            hidden = refine(self._append_condition(hidden, cond))

        return self.outc(hidden)
class DiffusionModel(nn.Module):
    """Noise schedule and reverse-process sampler around a noise-prediction network.

    Args:
        model: Callable invoked as ``model(x_t, labels, t)`` that returns the
            predicted noise for the batch ``x_t``.
        timesteps: Number of entries in the noise schedule.
        time_dim: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, model, timesteps=TIMESTEPS, time_dim=256):
        super().__init__()
        self.model = model
        self.timesteps = timesteps
        # Linear ramp raised to the power 1.5; since every value is below 1,
        # this lowers all betas relative to the plain linear schedule.
        beta_start = 0.0001
        beta_end = 0.02
        self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32) ** 1.5
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod))

    def sample(self, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
        """Generate images by iterating the reverse process from noise.

        Args:
            num_images: Batch size to generate.
            timesteps: Requested number of reverse steps. Values larger than
                the schedule length (``self.timesteps``) are clamped to it,
                because the schedule tensors have no entries beyond that.
            img_size: Height and width of the square images.
            num_classes: Number of classes, used when ``labels`` holds indices.
            labels: Either a 1-D tensor of class indices or a
                ``(num_images, num_classes)`` one-hot tensor.
            device: Device on which to allocate the working tensors.
            progress_callback: Optional callable receiving the completed
                fraction after every step.

        Returns:
            A ``(num_images, img_size, img_size, 3)`` tensor with values in
            ``[0, 1]``, or ``None`` if ``cancel_event`` was set during sampling.
        """
        # Indexing the schedule with t >= self.timesteps would raise IndexError.
        timesteps = min(int(timesteps), self.timesteps)

        # Start from noise with a reduced scale.
        x_t = torch.randn((num_images, 3, img_size, img_size), device=device) * 0.7

        # Accept class indices as well as one-hot rows.
        if labels.ndim == 1:
            labels_one_hot = torch.zeros(num_images, num_classes, device=device)
            labels_one_hot[torch.arange(num_images), labels] = 1
            labels = labels_one_hot

        for t in reversed(range(timesteps)):
            if cancel_event.is_set():
                return None

            t_tensor = torch.full((num_images,), t, device=device, dtype=torch.long)
            pred_noise = self.model(x_t, labels, t_tensor.float())

            alpha_bar_t = self.alphas_cumprod[t]
            # At the final step there is no previous entry; use 1.0 so the
            # update returns the predicted clean image unchanged.
            alpha_bar_t_prev = self.alphas_cumprod[t - 1] if t > 0 else torch.tensor(1.0)

            # Estimate of the clean image implied by the predicted noise.
            pred_x0 = (x_t - torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_bar_t)
            # Re-noise component for the previous step, built from the same prediction.
            pred_dir = torch.sqrt(1 - alpha_bar_t_prev) * pred_noise

            # Extra random noise that shrinks linearly as t approaches 0.
            if t > 0:
                noise_scale = 0.3 * (t / timesteps)
                noise = torch.randn_like(x_t) * noise_scale
            else:
                noise = torch.zeros_like(x_t)

            x_t = torch.sqrt(alpha_bar_t_prev) * pred_x0 + pred_dir + noise

            if progress_callback:
                progress_callback((timesteps - t) / timesteps)

        # Map from [-1, 1] to [0, 1] before post-processing.
        x_t = torch.clamp(x_t, -1, 1)
        x_t = (x_t + 1) / 2
        return self._post_process(x_t)

    def _post_process(self, image_tensor):
        """Apply a contrast stretch and a mild 3x3 blur, returning NHWC in ``[0, 1]``.

        Args:
            image_tensor: Batch of 3-channel images, channels-first
                ``(N, 3, H, W)`` as produced by ``sample``. A channels-last
                ``(N, H, W, 3)`` batch is also accepted and converted first.

        Returns:
            Tensor of shape ``(N, H, W, 3)`` clamped to ``[0, 1]``.
        """
        # conv2d needs channels-first; convert only when the input is clearly NHWC.
        if image_tensor.shape[1] != 3 and image_tensor.shape[-1] == 3:
            image_tensor = image_tensor.permute(0, 3, 1, 2)

        # Contrast stretch around the mean of the whole batch.
        mean_val = image_tensor.mean()
        image_tensor = (image_tensor - mean_val) * 1.2 + mean_val

        # Build the blur kernel once: one 3x3 filter per channel (weights sum to 1).
        blur_kernel = getattr(self, '_blur_kernel', None)
        if blur_kernel is None:
            blur_kernel = torch.tensor([
                [0.05, 0.1, 0.05],
                [0.1, 0.4, 0.1],
                [0.05, 0.1, 0.05]
            ], dtype=torch.float32).view(1, 1, 3, 3).repeat(3, 1, 1, 1)
            self._blur_kernel = blur_kernel
        # The cached kernel may live on another device than the images.
        blur_kernel = blur_kernel.to(device=image_tensor.device, dtype=image_tensor.dtype)

        # groups=3 applies the filter to each channel independently.
        blurred = torch.nn.functional.conv2d(
            image_tensor,
            blur_kernel,
            padding=1,
            groups=3
        )
        # Hand back channels-last so each image can be turned into an H x W x 3 array.
        return torch.clamp(blurred.permute(0, 2, 3, 1), 0, 1)
def load_model(model_path, device):
    """Build the diffusion model and load U-Net weights from ``model_path``.

    Args:
        model_path: Path to a checkpoint that is either a state dict or a dict
            holding one under the key ``'model_state_dict'``.
        device: Device on which to place the model and map the checkpoint.

    Returns:
        The ``DiffusionModel`` in eval mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        ValueError: If reading the checkpoint, loading the weights, or the
            verification forward pass fails; the original error is chained.
    """
    # Check first so a missing file does not cost a full model construction.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model weights not found at {model_path}")

    unet = UNet(num_classes=NUM_CLASSES).to(device)
    diffusion_model = DiffusionModel(unet).to(device)

    try:
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider weights_only=True where supported).
        checkpoint = torch.load(model_path, map_location=device)

        # Accept a training checkpoint dict as well as a bare state dict.
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        # Strip the 'model.' prefix when every key carries it.
        prefix = 'model.'
        if all(k.startswith(prefix) for k in state_dict):
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}

        # strict=False tolerates mismatches, so report them instead of hiding them.
        load_result = unet.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys:
            print(f"Warning: {len(load_result.missing_keys)} missing keys: {load_result.missing_keys}")
        if load_result.unexpected_keys:
            print(f"Warning: {len(load_result.unexpected_keys)} unexpected keys: {load_result.unexpected_keys}")
        print("Model loaded successfully")

        # Verification forward pass; no_grad avoids building an autograd graph.
        test_input = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
        test_labels = torch.zeros(1, NUM_CLASSES).to(device)
        test_time = torch.tensor([1]).to(device)
        with torch.no_grad():
            output = unet(test_input, test_labels, test_time)
        print(f"Model test output shape: {output.shape}")
    except Exception as e:
        traceback.print_exc()
        raise ValueError(f"Error loading model: {str(e)}") from e

    diffusion_model.eval()
    return diffusion_model
MODEL_NAME = "model_weights.pth"
model_path = MODEL_NAME

# Load the trained weights at import time; on any failure fall back to an
# untrained network so the interface can still start.
print("Loading model...")
try:
    loaded_model = load_model(model_path, device)
except Exception as load_error:
    print(f"Failed to load model: {load_error}")
    print("Creating dummy model for demonstration")
    fallback_unet = UNet(num_classes=NUM_CLASSES)
    loaded_model = DiffusionModel(fallback_unet).to(device)
else:
    print("Model loaded successfully!")
def cancel_generation():
    """Set the shared cancellation flag and return a status message."""
    status = "Generation cancelled"
    cancel_event.set()
    return status
def _tensor_to_pil(img):
    """Convert one image tensor with values in [0, 1] into a smoothed PIL image."""
    img_np = img.detach().cpu().numpy()
    # Image.fromarray needs H x W x C; transpose if the tensor is channels-first.
    if img_np.ndim == 3 and img_np.shape[0] == 3 and img_np.shape[-1] != 3:
        img_np = np.transpose(img_np, (1, 2, 0))
    img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(img_np).filter(ImageFilter.SMOOTH_MORE)


def generate_images(label_str, num_images, progress=gr.Progress()):
    """Generate synthetic X-rays for the selected condition.

    Args:
        label_str: Condition name; must be ``'Pneumonia'`` or ``'Pneumothorax'``.
        num_images: Number of images to generate, between 1 and 10.
        progress: Gradio progress tracker, updated after every sampling step.

    Returns:
        ``(single_image, gallery_images)``: the first element is the image
        when exactly one was requested and ``None`` otherwise; the second is
        the list of all generated images. ``(None, None)`` if the sampler
        reports cancellation.

    Raises:
        gr.Error: On invalid input, user cancellation, or any sampling failure.
    """
    cancel_event.clear()

    # The slider may deliver a float; tensor sizes must be integers.
    num_images = int(num_images)
    if num_images < 1 or num_images > 10:
        raise gr.Error("Number of images must be between 1 and 10")

    label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
    if label_str not in label_map:
        raise gr.Error("Invalid condition selected")

    labels = torch.zeros(num_images, NUM_CLASSES, device=device)
    labels[:, label_map[label_str]] = 1

    def progress_callback(progress_val):
        progress(progress_val, desc="Generating...")
        if cancel_event.is_set():
            raise gr.Error("Generation was cancelled by user")

    try:
        with torch.no_grad():
            images = loaded_model.sample(
                num_images=num_images,
                # The noise schedule has exactly loaded_model.timesteps
                # entries; asking for more steps indexes past its end.
                timesteps=loaded_model.timesteps,
                img_size=IMG_SIZE,
                num_classes=NUM_CLASSES,
                labels=labels,
                device=device,
                progress_callback=progress_callback
            )

        if images is None:
            return None, None

        processed_images = [_tensor_to_pil(img) for img in images]

        if num_images == 1:
            return processed_images[0], processed_images
        return None, processed_images
    except gr.Error:
        # Already user-facing (e.g. cancellation); do not re-wrap as a failure.
        raise
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Generation failed: {str(e)}") from e
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Gradio UI

# Page styling, handed to gr.Blocks through its css argument.
CUSTOM_CSS = """
.gradio-container {
background: linear-gradient(135deg, #f5f7fa 0%, #e4e8f0 100%);
}
.gallery-container {
background-color: white !important;
}
"""

with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="violet",
        neutral_hue="slate",
        font=[gr.themes.GoogleFont("Poppins")],
        text_size="md"
    ),
    css=CUSTOM_CSS,
) as demo:
    gr.Markdown("""
<center>
<h1>Synthetic X-ray Generator</h1>
<p><em>Generate synthetic chest X-rays conditioned on pathology</em></p>
</center>
""")

    with gr.Row():
        # Left column: inputs and action buttons.
        with gr.Column(scale=1):
            condition = gr.Dropdown(
                ["Pneumonia", "Pneumothorax"],
                label="Select Condition",
                value="Pneumonia",
                interactive=True
            )
            num_images = gr.Slider(
                1, 10, value=1, step=1,
                label="Number of Images",
                interactive=True
            )
            with gr.Row():
                submit_btn = gr.Button("Generate", variant="primary")
                cancel_btn = gr.Button("Cancel", variant="stop")
            gr.Markdown("""
<div style="text-align: center; margin-top: 10px;">
<small>Note: Generation may take several seconds per image</small>
</div>
""")

        # Right column: a single-image view or a gallery, never both.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Output", id="output_tab"):
                    single_image = gr.Image(
                        label="Generated X-ray",
                        height=400,
                        visible=True
                    )
                    gallery = gr.Gallery(
                        label="Generated X-rays",
                        columns=3,
                        height="auto",
                        object_fit="contain",
                        visible=False
                    )

    def update_ui_based_on_count(num_images):
        """Show the single-image view for one image and the gallery otherwise."""
        show_single = num_images == 1
        return {
            single_image: gr.update(visible=show_single),
            gallery: gr.update(visible=not show_single)
        }

    num_images.change(
        fn=update_ui_based_on_count,
        inputs=num_images,
        outputs=[single_image, gallery]
    )
    submit_btn.click(
        fn=generate_images,
        inputs=[condition, num_images],
        outputs=[single_image, gallery]
    )
    # queue=False lets the cancel request bypass the queue, so it is not held
    # behind the generation it is meant to stop.
    cancel_btn.click(
        fn=cancel_generation,
        outputs=None,
        queue=False
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)