import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
import numpy as np
import math
import os
from threading import Event
import traceback
import cv2  # used for bilateral filtering in post-processing

# Constants
IMG_SIZE = 128
TIMESTEPS = 300  # number of diffusion steps
NUM_CLASSES = 2

# Global cancellation flag shared between the UI and the sampling loop
cancel_event = Event()

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Model Definitions ---
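# Sinusoidal timestep embeddings (same construction as Transformer positional
# encodings): for dimension index i,
#   PE(t, 2i)   = sin(t / 10000^(2i / dim))
#   PE(t, 2i+1) = cos(t / 10000^(2i / dim))
# so every timestep t maps to a smooth, unique vector the network can condition on.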
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim) * -emb)
        self.register_buffer('embeddings', emb)

    def forward(self, time):
        embeddings = self.embeddings.to(time.device)
        embeddings = time[:, None] * embeddings[None, :]
        return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
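
# Conditional U-Net: a learned label embedding and the sinusoidal time embedding
# (time_dim channels each) are broadcast spatially and concatenated with the
# feature maps at every resolution, so each conv block sees 2 * time_dim extra
# input channels.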
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
        super().__init__()
        self.num_classes = num_classes
        self.label_embedding = nn.Embedding(num_classes, time_dim)
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.ReLU(),
            nn.Linear(time_dim, time_dim)
        )
        # Encoder
        self.inc = self.double_conv(in_channels, 64)
        self.down1 = self.down(64 + time_dim * 2, 128)
        self.down2 = self.down(128 + time_dim * 2, 256)
        self.down3 = self.down(256 + time_dim * 2, 512)
        # Bottleneck
        self.bottleneck = self.double_conv(512 + time_dim * 2, 1024)
        # Decoder
        self.up1 = nn.ConvTranspose2d(1024, 256, kernel_size=2, stride=2)
        self.upconv1 = self.double_conv(256 + 256 + time_dim * 2, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upconv2 = self.double_conv(128 + 128 + time_dim * 2, 128)
        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.upconv3 = self.double_conv(64 + 64 + time_dim * 2, 64)
        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)
    def double_conv(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def down(self, in_channels, out_channels):
        return nn.Sequential(
            nn.MaxPool2d(2),
            self.double_conv(in_channels, out_channels)
        )
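
    # Each down block halves spatial resolution via MaxPool2d(2), so at
    # IMG_SIZE = 128 the feature maps shrink 128 -> 64 -> 32 -> 16, and the
    # three ConvTranspose2d layers mirror that back up to 128.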
    def forward(self, x, labels, time):
        # Build the conditioning vector from the one-hot labels and the timestep
        label_indices = torch.argmax(labels, dim=1)
        label_emb = self.label_embedding(label_indices)
        t_emb = self.time_mlp(time)
        combined_emb = torch.cat([t_emb, label_emb], dim=1)
        combined_emb = combined_emb.unsqueeze(-1).unsqueeze(-1)
        # Encoder: concatenate the spatially broadcast conditioning at each scale
        x1 = self.inc(x)
        x1_cat = torch.cat([x1, combined_emb.repeat(1, 1, x1.shape[-2], x1.shape[-1])], dim=1)
        x2 = self.down1(x1_cat)
        x2_cat = torch.cat([x2, combined_emb.repeat(1, 1, x2.shape[-2], x2.shape[-1])], dim=1)
        x3 = self.down2(x2_cat)
        x3_cat = torch.cat([x3, combined_emb.repeat(1, 1, x3.shape[-2], x3.shape[-1])], dim=1)
        x4 = self.down3(x3_cat)
        x4_cat = torch.cat([x4, combined_emb.repeat(1, 1, x4.shape[-2], x4.shape[-1])], dim=1)
        x5 = self.bottleneck(x4_cat)
        # Decoder: upsample, concatenate the skip connection, then the conditioning
        x = self.up1(x5)
        x = torch.cat([x, x3], dim=1)
        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
        x = self.upconv1(x)
        x = self.up2(x)
        x = torch.cat([x, x2], dim=1)
        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
        x = self.upconv2(x)
        x = self.up3(x)
        x = torch.cat([x, x1], dim=1)
        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
        x = self.upconv3(x)
        return self.outc(x)
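
# Denoising diffusion (DDPM, Ho et al. 2020). The forward (noising) process is
#   alpha_t = 1 - beta_t,   alpha_bar_t = prod_{s<=t} alpha_s
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,   eps ~ N(0, I)
# and the U-Net is trained to predict eps from (x_t, label, t).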
class DiffusionModel(nn.Module):
    def __init__(self, model, timesteps=TIMESTEPS, time_dim=256):
        super().__init__()
        self.model = model
        self.timesteps = timesteps
        self.time_dim = time_dim
        # Linear beta schedule, rescaled so its endpoints match the standard
        # 1000-step schedule (beta in [1e-4, 0.02]) regardless of `timesteps`
        scale = 1000 / timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
        self.alphas = 1. - self.betas
        self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0).float())
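
    # With timesteps = 300 the scale is 1000 / 300 ≈ 3.33, so betas run from
    # roughly 3.3e-4 to 6.7e-2. Note that alpha_bars is the only tensor
    # registered as a buffer, so it moves with the module; betas and alphas
    # stay on the CPU, and the sampling loop moves entries to the device.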
    def forward_diffusion(self, x_0, t, noise):
        x_0 = x_0.float()
        noise = noise.float()
        alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1, 1)
        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1. - alpha_bar_t) * noise
        return x_t
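
    # Example (hypothetical tensors and a hypothetical `diffusion` instance):
    # noising a batch of two images, each at its own timestep, in one call:
    #   x0 = torch.randn(2, 3, IMG_SIZE, IMG_SIZE)
    #   t = torch.tensor([10, 250])
    #   xt = diffusion.forward_diffusion(x0, t, torch.randn_like(x0))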
    def forward(self, x_0, labels):
        t = torch.randint(0, self.timesteps, (x_0.shape[0],), device=x_0.device).long()
        noise = torch.randn_like(x_0)
        x_t = self.forward_diffusion(x_0, t, noise)
        predicted_noise = self.model(x_t, labels, t.float())
        return predicted_noise, noise, t
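
    # A training step (not included in this app) would typically minimize the
    # MSE between the predicted and true noise, e.g.:
    #   pred, eps, _ = diffusion(x0, labels)
    #   loss = torch.nn.functional.mse_loss(pred, eps)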
    def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
        # Start from pure Gaussian noise
        x_t = torch.randn(num_images, 3, img_size, img_size).to(device)
        # Convert integer labels to one-hot if needed
        if labels.ndim == 1:
            labels_one_hot = torch.zeros(num_images, num_classes).to(device)
            labels_one_hot[torch.arange(num_images), labels] = 1
            labels = labels_one_hot
        else:
            labels = labels.to(device)
        # Reverse diffusion (DDPM) sampling loop with reduced noise injection
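        # Each step computes the DDPM posterior mean
        #   mu_t = (1 / sqrt(alpha_t)) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta)
        # and sets x_{t-1} = mu_t + sqrt(beta_t) * z with z ~ N(0, I) for t > 0
        # (here z is additionally scaled by 0.8, a deliberate deviation).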
        for t in reversed(range(self.timesteps)):
            if cancel_event.is_set():
                return None
            t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
            predicted_noise = self.model(x_t, labels, t_tensor)
            # Schedule coefficients for this step
            beta_t = self.betas[t].to(device)
            alpha_t = self.alphas[t].to(device)
            alpha_bar_t = self.alpha_bars[t].to(device)
            mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise)
            variance = beta_t
            # Inject slightly less noise than standard DDPM (scaled by 0.8)
            if t > 0:
                noise = torch.randn_like(x_t) * 0.8
            else:
                noise = torch.zeros_like(x_t)
            x_t = mean + torch.sqrt(variance) * noise
            if progress_callback:
                progress_callback((self.timesteps - t) / self.timesteps)
        # Clamp and denormalize (ImageNet mean/std)
        x_0 = torch.clamp(x_t, -1., 1.)
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
        x_0 = std * x_0 + mean
        x_0 = torch.clamp(x_0, 0., 1.)
        # Post-processing: sharpen the samples.
        # First apply mild bilateral filtering, which smooths noise while
        # preserving edges
        x_np = x_0.cpu().permute(0, 2, 3, 1).numpy()
        filtered = []
        for img in x_np:
            img = (img * 255).astype(np.uint8)
            filtered_img = cv2.bilateralFilter(img, d=5, sigmaColor=15, sigmaSpace=15)
            filtered.append(filtered_img / 255.0)
        x_0 = torch.tensor(np.array(filtered), device=device).float().permute(0, 3, 1, 2)
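        # Unsharp masking boosts detail by subtracting a blurred copy:
        #   sharpened = x + k * (x - blur(x));  with k = 0.5 this is
        #   1.5 * x - 0.5 * blur(x), matching the expression below.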
        # Then apply unsharp masking with a normalized 5x5 box blur
        # (each per-channel kernel sums to 1 so the blur preserves brightness)
        kernel = torch.ones(3, 1, 5, 5, device=device) / 25
        kernel = kernel.to(x_0.dtype)
        blurred = torch.nn.functional.conv2d(
            x_0,
            kernel,
            padding=2,
            groups=3
        )
        x_0 = torch.clamp(1.5 * x_0 - 0.5 * blurred, 0., 1.)
        return x_0
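
# Checkpoints saved from the full DiffusionModel prefix the U-Net parameters
# with "model." (e.g. "model.inc.0.weight"); load_model strips that prefix
# when it detects the training-checkpoint format.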
def load_model(model_path, device):
    unet_model = UNet(num_classes=NUM_CLASSES).to(device)
    diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
    if os.path.exists(model_path):
        checkpoint = torch.load(model_path, map_location=device)
        if 'model_state_dict' in checkpoint:
            # Training-checkpoint format: strip the "model." prefix to recover
            # the bare U-Net weights
            state_dict = {
                k[6:]: v for k, v in checkpoint['model_state_dict'].items()
                if k.startswith('model.')
            }
            unet_model.load_state_dict(state_dict, strict=False)
            # Rebuild the diffusion wrapper around the loaded U-Net
            diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
            print(f"Loaded UNet weights from {model_path}")
        else:
            # Direct state-dict format
            try:
                # First try loading as a full DiffusionModel
                diffusion_model.load_state_dict(checkpoint)
                print(f"Loaded full DiffusionModel from {model_path}")
            except RuntimeError:
                # If that fails, load just the UNet weights
                unet_model.load_state_dict(checkpoint, strict=False)
                diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
                print(f"Loaded UNet weights only from {model_path}")
    else:
        print(f"Weights file not found at {model_path}")
        print("Using randomly initialized weights")
    diffusion_model.eval()
    return diffusion_model
def cancel_generation():
    cancel_event.set()
    return "Generation cancelled"
def generate_images(label_str, num_images, progress=gr.Progress()):
    cancel_event.clear()
    if num_images < 1 or num_images > 10:
        raise gr.Error("Number of images must be between 1 and 10")
    label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
    if label_str not in label_map:
        raise gr.Error("Invalid condition selected")
    labels = torch.zeros(num_images, NUM_CLASSES)
    labels[:, label_map[label_str]] = 1
    try:
        def progress_callback(progress_val):
            progress(progress_val, desc="Generating...")
            if cancel_event.is_set():
                raise gr.Error("Generation was cancelled by user")

        with torch.no_grad():
            images = loaded_model.sample(
                num_images=num_images,
                img_size=IMG_SIZE,
                num_classes=NUM_CLASSES,
                labels=labels,
                device=device,
                progress_callback=progress_callback
            )
        if images is None:
            return None, None
        processed_images = []
        for img in images:
            img_np = img.cpu().permute(1, 2, 0).numpy()
            img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
            processed_images.append(Image.fromarray(img_np))
        # A single image goes to the Image component; multiple go to the Gallery
        if num_images == 1:
            return processed_images[0], processed_images
        return None, processed_images
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Generation failed: {str(e)}")
    finally:
        torch.cuda.empty_cache()
# Load model weights at startup
MODEL_NAME = "model_weights.pth"
model_path = MODEL_NAME
print("Loading model...")
try:
    loaded_model = load_model(model_path, device)
    print("Model loaded successfully!")
except Exception as e:
    print(f"Failed to load model: {e}")
    print("Creating dummy model for demonstration")
    loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES), timesteps=TIMESTEPS).to(device)
# Gradio UI
custom_css = """
.gradio-container {
    background: linear-gradient(135deg, #f5f7fa 0%, #e4e8f0 100%);
}
.gallery-container {
    background-color: white !important;
}
"""

with gr.Blocks(theme=gr.themes.Soft(
    primary_hue="violet",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Poppins")],
    text_size="md"
), css=custom_css) as demo:
    gr.Markdown("""
    <center>
    <h1>Synthetic X-ray Generator</h1>
    <p><em>Generate synthetic chest X-rays conditioned on pathology</em></p>
    </center>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            condition = gr.Dropdown(
                ["Pneumonia", "Pneumothorax"],
                label="Select Condition",
                value="Pneumonia",
                interactive=True
            )
            num_images = gr.Slider(
                1, 10, value=1, step=1,
                label="Number of Images",
                interactive=True
            )
            with gr.Row():
                submit_btn = gr.Button("Generate", variant="primary")
                cancel_btn = gr.Button("Cancel", variant="stop")
            gr.Markdown("""
            <div style="text-align: center; margin-top: 10px;">
            <small>Note: Generation may take several seconds per image</small>
            </div>
            """)
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Output", id="output_tab"):
                    single_image = gr.Image(
                        label="Generated X-ray",
                        height=400,
                        visible=True
                    )
                    gallery = gr.Gallery(
                        label="Generated X-rays",
                        columns=3,
                        height="auto",
                        object_fit="contain",
                        visible=False
                    )
    def update_ui_based_on_count(num_images):
        if num_images == 1:
            return {
                single_image: gr.update(visible=True),
                gallery: gr.update(visible=False)
            }
        else:
            return {
                single_image: gr.update(visible=False),
                gallery: gr.update(visible=True)
            }

    num_images.change(
        fn=update_ui_based_on_count,
        inputs=num_images,
        outputs=[single_image, gallery]
    )
    submit_btn.click(
        fn=generate_images,
        inputs=[condition, num_images],
        outputs=[single_image, gallery]
    )
    cancel_btn.click(
        fn=cancel_generation,
        outputs=None
    )
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)