# Source: Hugging Face Space (status: Running) — file size 3,162 bytes, commit d5e5728.
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
import numpy as np
import math
import os
# Configuration constants — these must mirror the values used at training time.
IMG_SIZE = 128      # image size handed to the sampler
TIMESTEPS = 300     # diffusion timesteps used by the model and the sampler
NUM_CLASSES = 2     # number of conditioning classes (length of the one-hot label)
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# NOTE: The SinusoidalPositionEmbeddings, UNet and DiffusionModel classes and the
# sample() function must be defined here, copied unchanged from the training code;
# load_model() and generate_image() below depend on them.
# Load the trained model with best-effort error handling
def load_model(model_path, device):
    """Build the conditional diffusion model and load trained weights into it.

    The checkpoint at ``model_path`` may be any of:

    * a dict holding the weights under ``'model_state_dict'``,
    * a bare state dict, or
    * a pickled ``nn.Module`` (its ``state_dict()`` is used).

    The weights are loaded into ``diffusion_model.model`` (presumably the UNet
    passed to ``DiffusionModel`` — confirm against that class).

    Loading is best-effort. On any failure the error is printed and the model
    keeps its randomly initialized weights, so the app can still start.

    Args:
        model_path: Path to the checkpoint file.
        device: ``torch.device`` the model is moved to and the checkpoint is
            mapped onto.

    Returns:
        The ``DiffusionModel`` instance, switched to eval mode.
    """
    unet_model = UNet(num_classes=NUM_CLASSES).to(device)
    diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
    try:
        # SECURITY NOTE: torch.load unpickles the file; only load trusted
        # checkpoints. PyTorch >= 2.6 defaults to weights_only=True, which
        # rejects pickled nn.Module objects, so the Module branch below only
        # applies when full unpickling is permitted.
        checkpoint = torch.load(model_path, map_location=device)
        if isinstance(checkpoint, nn.Module):
            state_dict = checkpoint.state_dict()
        elif isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        diffusion_model.model.load_state_dict(state_dict)
        print(f"Successfully loaded model from {model_path}")
    except Exception as e:
        # Deliberate fallback: keep the app usable with random weights.
        print(f"Error loading model: {e}")
        print("Using randomly initialized weights")
    diffusion_model.eval()
    return diffusion_model
def generate_image(label_str):
    """Sample one synthetic chest X-ray conditioned on the selected condition.

    Uses the module-level ``loaded_model``, ``sample``, ``device`` and the
    configuration constants.

    Args:
        label_str: Condition name from the UI, either ``'Pneumonia'`` or
            ``'Pneumothorax'``.

    Returns:
        PIL.Image.Image: The generated image. Single-channel output becomes a
        2-D (mode ``'L'``) image; otherwise the image is built from the
        channels-last array.

    Raises:
        gr.Error: If ``label_str`` is not one of the supported conditions.
    """
    label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
    try:
        label_index = label_map[label_str]
    except KeyError:
        # `from None` keeps the internal KeyError out of the reported traceback.
        raise gr.Error(f"Invalid label '{label_str}'. Please select either 'Pneumonia' or 'Pneumothorax'.") from None

    # One-hot conditioning vector for a batch of one image.
    labels = torch.zeros(1, NUM_CLASSES, device=device)
    labels[0, label_index] = 1

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        generated_image = sample(
            model=loaded_model,
            num_images=1,
            timesteps=TIMESTEPS,
            img_size=IMG_SIZE,
            num_classes=NUM_CLASSES,
            labels=labels,
            device=device,
        )

    # Drop the batch dim and go channels-last for NumPy/PIL.
    # NOTE(review): assumes sample() returns a (1, C, H, W) tensor — confirm.
    # .float() keeps .numpy() working for half/bfloat16 outputs (NumPy has no bfloat16).
    img_np = generated_image.squeeze(0).float().cpu().permute(1, 2, 0).numpy()
    # PIL cannot build an image from an (H, W, 1) array; collapse a single
    # channel to (H, W) so grayscale output becomes an 'L' image.
    if img_np.shape[-1] == 1:
        img_np = img_np[..., 0]
    # NOTE(review): assumes sample() already scales output to [0, 1];
    # clipping only guards against stray out-of-range values.
    img_np = np.clip(img_np, 0, 1)
    return Image.fromarray((img_np * 255).astype(np.uint8))
# Model checkpoint location (update these for your deployment).
MODEL_DIR = "models"
MODEL_NAME = "diffusion_unet_xray.pth"  # Update with your actual filename
model_path = os.path.join(MODEL_DIR, MODEL_NAME)

# Build the model once at import time; every request reuses it.
print("Loading model...")
loaded_model = load_model(model_path, device)
# load_model() prints its own success/failure message and may fall back to
# random weights, so do not claim success unconditionally here.
print("Model initialization finished.")
# Gradio interface
# Single source for the selectable conditions; these must match the labels
# accepted by generate_image().
CONDITION_CHOICES = ["Pneumonia", "Pneumothorax"]

iface = gr.Interface(
    fn=generate_image,
    inputs=gr.Dropdown(
        choices=CONDITION_CHOICES,
        label="Select Condition",
        value=CONDITION_CHOICES[0],  # Default value
    ),
    outputs=gr.Image(
        type="pil",
        label="Generated X-ray Image",
    ),
    title="Medical X-ray Image Generator",
    description="Generate synthetic chest X-ray images using a diffusion model. Select a condition to generate.",
    examples=[[choice] for choice in CONDITION_CHOICES],
)

if __name__ == "__main__":
    # Bind on all interfaces so the app is reachable from other hosts, not only localhost.
    iface.launch(server_name="0.0.0.0", server_port=7860)