# NOTE: "Spaces: Running" appears to be page-header text captured along with this file; it is not part of the program.
import torch | |
import torch.nn as nn | |
import gradio as gr | |
from PIL import Image | |
import numpy as np | |
import math | |
import os | |
# Constants (update these to match your training config)
IMG_SIZE = 128  # forwarded to sample() as img_size
TIMESTEPS = 300  # forwarded to DiffusionModel(...) and sample() as timesteps
NUM_CLASSES = 2  # width of the one-hot label vector; also forwarded to UNet(...) and sample()

# Define the device: use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
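# The SinusoidalPositionEmbeddings, UNet and DiffusionModel classes remain the same as in your
# original code and are not repeated here. They must be defined (or imported) at this point,
# together with the sample() function that generate_image() calls, before load_model() runs.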
# Load the trained model with improved error handling
def load_model(model_path, device):
    """Build the diffusion model and try to load trained weights into its UNet.

    Args:
        model_path: Path to a checkpoint file readable by ``torch.load``.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The ``DiffusionModel`` switched to eval mode. Loading is best-effort:
        if anything goes wrong while reading or applying the checkpoint, the
        error is printed and the model is returned with its randomly
        initialized weights so the app can still start.
    """
    unet_model = UNet(num_classes=NUM_CLASSES).to(device)
    diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
    try:
        # NOTE(security): torch.load is pickle-based; depending on the torch
        # version's `weights_only` default it can execute code embedded in the
        # file. Only point this at checkpoints you trust.
        checkpoint = torch.load(model_path, map_location=device)
        # Accept three checkpoint layouts.
        if isinstance(checkpoint, torch.nn.Module):
            # A whole pickled module: `in` would raise TypeError on it, so
            # take its weights explicitly.
            # NOTE(review): assumes the saved module is the UNet — confirm.
            state_dict = checkpoint.state_dict()
        elif 'model_state_dict' in checkpoint:
            # A training checkpoint dict that wraps the weights.
            state_dict = checkpoint['model_state_dict']
        else:
            # A bare state_dict.
            state_dict = checkpoint
        diffusion_model.model.load_state_dict(state_dict)
        print(f"Successfully loaded model from {model_path}")
    except Exception as e:
        # Deliberately broad: any failure degrades to untrained weights
        # instead of crashing the app at import time.
        print(f"Error loading model: {e}")
        print("Using randomly initialized weights")
    diffusion_model.eval()
    return diffusion_model
# Improved image generation function
def generate_image(label_str):
    """Generate one synthetic X-ray image for the selected condition.

    Args:
        label_str: Condition name; must be 'Pneumonia' or 'Pneumothorax'.

    Returns:
        A ``PIL.Image.Image`` built from the sampled tensor (8-bit).

    Raises:
        gr.Error: If ``label_str`` is not one of the known conditions.
    """
    label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
    try:
        label_index = label_map[label_str]
    except KeyError:
        # `from None`: the KeyError adds nothing to the user-facing message.
        raise gr.Error(
            f"Invalid label '{label_str}'. Please select either 'Pneumonia' or 'Pneumothorax'."
        ) from None

    # Create one-hot encoded label
    labels = torch.zeros(1, NUM_CLASSES, device=device)
    labels[0, label_index] = 1

    # Generate image (no gradients needed for sampling)
    with torch.no_grad():
        generated_image = sample(
            model=loaded_model,
            num_images=1,
            timesteps=TIMESTEPS,
            img_size=IMG_SIZE,
            num_classes=NUM_CLASSES,
            labels=labels,
            device=device
        )

    # Convert to PIL Image
    # assumes sample() returns a (1, C, H, W) tensor — TODO confirm
    img_np = generated_image.squeeze(0).cpu().permute(1, 2, 0).numpy()
    if img_np.shape[-1] == 1:
        # Image.fromarray rejects (H, W, 1) arrays; drop the singleton
        # channel axis so a grayscale output becomes a plain (H, W) image.
        img_np = img_np[..., 0]
    # NOTE(review): clipping assumes sample() already returns values scaled to
    # [0, 1]; if it returns [-1, 1], rescale before this line — confirm.
    img_np = np.clip(img_np, 0, 1)  # Ensure proper range
    img_pil = Image.fromarray((img_np * 255).astype(np.uint8))
    return img_pil
# Model paths (update these for your deployment)
MODEL_DIR = "models"
MODEL_NAME = "diffusion_unet_xray.pth"  # Update with your actual filename
model_path = os.path.join(MODEL_DIR, MODEL_NAME)

# Load model
print("Loading model...")
loaded_model = load_model(model_path, device)
# load_model() returns a model even when the checkpoint could not be read (it
# prints the error and keeps random weights), so do not claim success here.
print("Model ready (see messages above for checkpoint load status).")
# Gradio interface
# One list feeds the dropdown choices, its default and the examples so they
# cannot drift apart. Keep it in sync with the labels generate_image() accepts.
CONDITIONS = ["Pneumonia", "Pneumothorax"]

iface = gr.Interface(
    fn=generate_image,
    inputs=gr.Dropdown(
        choices=CONDITIONS,
        label="Select Condition",
        value=CONDITIONS[0],  # Default value
    ),
    outputs=gr.Image(
        type="pil",
        label="Generated X-ray Image",
    ),
    title="Medical X-ray Image Generator",
    description="Generate synthetic chest X-ray images using a diffusion model. Select a condition to generate.",
    examples=[[condition] for condition in CONDITIONS],
)

if __name__ == "__main__":
    # Listen on all network interfaces on port 7860.
    iface.launch(server_name="0.0.0.0", server_port=7860)