import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
import numpy as np
import math
import os
from threading import Event
import traceback

# Constants
IMG_SIZE = 128
TIMESTEPS = 300
NUM_CLASSES = 2

# Global Cancellation Flag
cancel_event = Event()

# Device Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# --- Model Definitions ---
class SinusoidalPositionEmbeddings(nn.Module):
    """Standard transformer-style sinusoidal timestep embeddings."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # Note: this buffer holds the precomputed frequency terms, not the
        # final embeddings; those are built per-timestep in forward().
        self.register_buffer('embeddings', self._precompute_embeddings(dim))

    def _precompute_embeddings(self, dim):
        half_dim = dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim) * -emb)
        return emb

    def forward(self, time):
        # Buffers follow the module's device, but move explicitly in case
        # `time` lives elsewhere.
        embeddings = self.embeddings.to(time.device)
        embeddings = time[:, None] * embeddings[None, :]
        return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
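
# Shape sanity check for the embedding, kept as a sketch (this helper is
# hypothetical and never called by the app):
def _check_time_embedding():
    emb = SinusoidalPositionEmbeddings(256)
    t = torch.tensor([0.0, 150.0, 299.0])
    out = emb(t)
    # One sin half and one cos half per timestep -> (batch, dim)
    assert out.shape == (3, 256)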


class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
        super().__init__()
        self.num_classes = num_classes
        self.label_embedding = nn.Embedding(num_classes, time_dim)
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.ReLU(),
            nn.Linear(time_dim, time_dim)
        )
        # The time and label embeddings are concatenated and broadcast over
        # the spatial dims, so every stage consumes time_dim * 2 extra channels.
        self.inc = self.double_conv(in_channels, 64)
        self.down1 = self.down(64 + time_dim * 2, 128)
        self.down2 = self.down(128 + time_dim * 2, 256)
        self.down3 = self.down(256 + time_dim * 2, 512)
        self.bottleneck = self.double_conv(512 + time_dim * 2, 1024)
        self.up1 = nn.ConvTranspose2d(1024, 256, kernel_size=2, stride=2)
        self.upconv1 = self.double_conv(256 + 256 + time_dim * 2, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upconv2 = self.double_conv(128 + 128 + time_dim * 2, 128)
        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.upconv3 = self.double_conv(64 + 64 + time_dim * 2, 64)
        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)

    def double_conv(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def down(self, in_channels, out_channels):
        return nn.Sequential(
            nn.MaxPool2d(2),
            self.double_conv(in_channels, out_channels)
        )
    def forward(self, x, labels, time):
        # Labels arrive one-hot; recover indices for the embedding lookup.
        label_indices = torch.argmax(labels, dim=1)
        label_emb = self.label_embedding(label_indices)
        t_emb = self.time_mlp(time)
        combined_emb = torch.cat([t_emb, label_emb], dim=1)
        combined_emb = combined_emb.unsqueeze(-1).unsqueeze(-1)

        # Encoder: tile the conditioning embedding onto each feature map.
        x1 = self.inc(x)
        x1_cat = torch.cat([x1, combined_emb.repeat(1, 1, x1.shape[-2], x1.shape[-1])], dim=1)
        x2 = self.down1(x1_cat)
        x2_cat = torch.cat([x2, combined_emb.repeat(1, 1, x2.shape[-2], x2.shape[-1])], dim=1)
        x3 = self.down2(x2_cat)
        x3_cat = torch.cat([x3, combined_emb.repeat(1, 1, x3.shape[-2], x3.shape[-1])], dim=1)
        x4 = self.down3(x3_cat)
        x4_cat = torch.cat([x4, combined_emb.repeat(1, 1, x4.shape[-2], x4.shape[-1])], dim=1)
        x5 = self.bottleneck(x4_cat)

        # Decoder with skip connections, re-injecting the conditioning at each scale.
        x = self.up1(x5)
        x = torch.cat([x, x3], dim=1)
        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
        x = self.upconv1(x)
        x = self.up2(x)
        x = torch.cat([x, x2], dim=1)
        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
        x = self.upconv2(x)
        x = self.up3(x)
        x = torch.cat([x, x1], dim=1)
        x = torch.cat([x, combined_emb.repeat(1, 1, x.shape[-2], x.shape[-1])], dim=1)
        x = self.upconv3(x)
        return self.outc(x)
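
# Illustrative shape check for the conditional UNet: it consumes a noisy batch
# plus one-hot labels and float timesteps, and predicts noise of the same shape
# as its input. A sketch only; the helper is hypothetical and never called:
def _check_unet_shapes():
    net = UNet(num_classes=NUM_CLASSES)
    x = torch.randn(2, 3, IMG_SIZE, IMG_SIZE)
    labels = torch.eye(NUM_CLASSES)  # one-hot rows for classes 0 and 1
    t = torch.tensor([10.0, 250.0])
    out = net(x, labels, t)
    assert out.shape == x.shape  # predicted noise matches the input shape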


class DiffusionModel(nn.Module):
    def __init__(self, model, timesteps=500, time_dim=256):
        super().__init__()
        self.model = model
        self.timesteps = timesteps
        self.time_dim = time_dim
        # betas/alphas are plain tensors and stay on CPU; sample() moves the
        # entries it needs each step. alpha_bars is a buffer, so it follows
        # the module's device.
        self.betas = self.linear_schedule(timesteps)
        self.alphas = 1. - self.betas
        self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0).float())

    def linear_schedule(self, timesteps):
        # Linear beta schedule, rescaled so the endpoints match the
        # 1000-step DDPM defaults regardless of `timesteps`.
        scale = 1000 / timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)

    def forward_diffusion(self, x_0, t, noise):
        # Closed-form forward process:
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        x_0 = x_0.float()
        noise = noise.float()
        alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1, 1)
        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1. - alpha_bar_t) * noise
        return x_t

    def forward(self, x_0, labels):
        # Training objective: predict the noise added at a random timestep.
        t = torch.randint(0, self.timesteps, (x_0.shape[0],), device=x_0.device).long()
        noise = torch.randn_like(x_0)
        x_t = self.forward_diffusion(x_0, t, noise)
        predicted_noise = self.model(x_t, labels, t.float())
        return predicted_noise, noise, t
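
# For reference, a minimal training-step sketch built on the forward pass
# above (hypothetical helper, not used by this inference-only app; assumes
# inputs normalized the same way as at training time):
def _train_step_sketch(diffusion, optimizer, x_0, labels):
    predicted_noise, noise, _ = diffusion(x_0, labels)
    loss = nn.functional.mse_loss(predicted_noise, noise)  # simple DDPM loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()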


def sample(model, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
    # Ancestral DDPM sampling, starting from pure Gaussian noise.
    x_t = torch.randn(num_images, 3, img_size, img_size).to(device)
    if labels.ndim == 1:
        # Convert class indices to one-hot vectors.
        labels_one_hot = torch.zeros(num_images, num_classes).to(device)
        labels_one_hot[torch.arange(num_images), labels] = 1
        labels = labels_one_hot
    else:
        labels = labels.to(device)
    for t in reversed(range(timesteps)):
        if cancel_event.is_set():
            return None
        t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
        predicted_noise = model.model(x_t, labels, t_tensor)
        beta_t = model.betas[t].to(device)
        alpha_t = model.alphas[t].to(device)
        alpha_bar_t = model.alpha_bars[t].to(device)
        # Posterior mean; variance is fixed to beta_t (the simpler DDPM choice).
        mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise)
        variance = beta_t
        if t > 0:
            noise = torch.randn_like(x_t)
        else:
            # No noise is added on the final step.
            noise = torch.zeros_like(x_t)
        x_t = mean + torch.sqrt(variance) * noise
        if progress_callback:
            progress_callback((timesteps - t) / timesteps)
    x_0 = torch.clamp(x_t, -1., 1.)
    # Undo the ImageNet-style normalization used at training time.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
    x_0 = std * x_0 + mean
    x_0 = torch.clamp(x_0, 0., 1.)
    return x_0
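
# Standalone usage sketch for sample() (hypothetical helper; the app instead
# calls sample() from the Gradio handlers below):
def _sample_sketch(diffusion_model):
    class_indices = torch.tensor([0, 1], device=device)  # Pneumonia, Pneumothorax
    with torch.no_grad():
        imgs = sample(diffusion_model, 2, TIMESTEPS, IMG_SIZE,
                      NUM_CLASSES, class_indices, device)
    # imgs: (2, 3, IMG_SIZE, IMG_SIZE) tensor in [0, 1], or None if cancelled
    return imgs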


def load_model(model_path, device):
    unet_model = UNet(num_classes=NUM_CLASSES).to(device)
    diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
    try:
        checkpoint = torch.load(model_path, map_location=device)
        # Accept either a training checkpoint dict or a bare state_dict.
        if 'model_state_dict' in checkpoint:
            diffusion_model.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            diffusion_model.model.load_state_dict(checkpoint)
        print(f"Successfully loaded model from {model_path}")
    except Exception as e:
        print(f"Error loading model: {e}")
        print("Using randomly initialized weights")
    diffusion_model.eval()
    return diffusion_model


def cancel_generation():
    cancel_event.set()
    return "Generation cancelled"


def generate_single_image(label_str):
    # Reset any stale cancellation from a previous run.
    cancel_event.clear()
    label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
    try:
        label_index = label_map[label_str]
    except KeyError:
        raise gr.Error(f"Invalid label '{label_str}'. Please select either 'Pneumonia' or 'Pneumothorax'.")
    labels = torch.zeros(1, NUM_CLASSES, device=device)
    labels[0, label_index] = 1
    with torch.no_grad():
        generated_image = sample(
            model=loaded_model,
            num_images=1,
            timesteps=TIMESTEPS,
            img_size=IMG_SIZE,
            num_classes=NUM_CLASSES,
            labels=labels,
            device=device
        )
    if generated_image is None:
        # Sampling was cancelled mid-run.
        return None
    img_np = generated_image.squeeze(0).cpu().permute(1, 2, 0).numpy()
    img_np = np.clip(img_np, 0, 1)
    img_pil = Image.fromarray((img_np * 255).astype(np.uint8))
    return img_pil


def generate_batch_images(label_str, num_images, progress=gr.Progress()):
    global loaded_model
    cancel_event.clear()
    if num_images < 1 or num_images > 10:
        raise gr.Error("Number of images must be between 1 and 10")
    label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
    if label_str not in label_map:
        raise gr.Error("Invalid condition selected")
    labels = torch.zeros(num_images, NUM_CLASSES, device=device)
    labels[:, label_map[label_str]] = 1
    try:
        def progress_callback(progress_val):
            progress(progress_val, desc="Generating...")
            if cancel_event.is_set():
                raise gr.Error("Generation was cancelled by user")

        with torch.no_grad():
            images = sample(
                model=loaded_model,
                num_images=num_images,
                timesteps=TIMESTEPS,
                img_size=IMG_SIZE,
                num_classes=NUM_CLASSES,
                labels=labels,
                device=device,
                progress_callback=progress_callback
            )
        if images is None:
            return None
        # Convert each (3, H, W) tensor to a PIL image for the gallery.
        processed_images = []
        for img in images:
            img_np = img.cpu().permute(1, 2, 0).numpy()
            img_np = np.clip(img_np, 0, 1)
            pil_img = Image.fromarray((img_np * 255).astype(np.uint8))
            processed_images.append(pil_img)
        return processed_images
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()
        raise gr.Error("Out of GPU memory - try generating fewer images")
    except Exception as e:
        traceback.print_exc()
        if str(e) != "Generation was cancelled by user":
            raise gr.Error(f"Generation failed: {str(e)}")
        return None
    finally:
        torch.cuda.empty_cache()


# Load model (falls back to random weights if the checkpoint is missing)
MODEL_DIR = "models"
MODEL_NAME = "diffusion_unet_xray.pth"
model_path = os.path.join(MODEL_DIR, MODEL_NAME)
print("Loading model...")
loaded_model = load_model(model_path, device)
print("Model ready.")


# --- Gradio UI ---
CUSTOM_CSS = """
.gradio-container {
    background: linear-gradient(135deg, #f5f7fa 0%, #e4e8f0 100%);
}
.gallery-container {
    background-color: white !important;
}
"""

with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft(
    primary_hue="violet",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Poppins")],
    text_size="md"
)) as demo:
    gr.Markdown("""
    <center>
    <h1>Synthetic X-ray Generator</h1>
    <p><em>Generate synthetic chest X-rays conditioned on pathology</em></p>
    </center>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            condition = gr.Dropdown(
                ["Pneumonia", "Pneumothorax"],
                label="Select Condition",
                value="Pneumonia",
                interactive=True
            )
            num_images = gr.Slider(
                1, 10, value=1, step=1,
                label="Number of Images",
                interactive=True
            )
            with gr.Row():
                submit_btn = gr.Button("Generate", variant="primary")
                cancel_btn = gr.Button("Cancel", variant="stop")
            gr.Markdown("""
            <div style="text-align: center; margin-top: 10px;">
            <small>Note: Generation may take several seconds per image</small>
            </div>
            """)
        with gr.Column(scale=2):
            with gr.Tab("Single Image"):
                single_image = gr.Image(
                    type="pil",
                    label="Generated X-ray",
                    height=400
                )
            with gr.Tab("Batch Images"):
                gallery = gr.Gallery(
                    label="Generated X-rays",
                    columns=3,
                    height="auto",
                    object_fit="contain"
                )

    # Single image generation (triggered when the dropdown value changes)
    condition.change(
        fn=generate_single_image,
        inputs=condition,
        outputs=single_image
    )
    # Batch image generation
    submit_btn.click(
        fn=generate_batch_images,
        inputs=[condition, num_images],
        outputs=gallery
    )
    cancel_btn.click(
        fn=cancel_generation,
        outputs=None
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)