Vedansh-7's picture
Upload app.py
d5e5728
raw
history blame
3.16 kB
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
import numpy as np
import math
import os
# Hyper-parameters shared with the training run. They must agree with the
# configuration the checkpoint was trained with.
IMG_SIZE = 128  # passed to sample() as img_size
TIMESTEPS = 300  # diffusion steps used by DiffusionModel and sample()
NUM_CLASSES = 2  # width of the one-hot condition vector

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# NOTE: The SinusoidalPositionEmbeddings, UNet and DiffusionModel classes and the
# sample() function must be defined here, identical to the training code.
# load_model() needs UNet and DiffusionModel, and generate_image() needs sample().
def _extract_state_dict(checkpoint):
    """Return the UNet state_dict contained in an object returned by ``torch.load``.

    Accepted layouts:
      * a pickled ``DiffusionModel`` (its UNet is taken from ``.model``),
      * any pickled ``nn.Module`` (its ``state_dict()`` is used),
      * a dict with a ``'model_state_dict'`` entry,
      * anything else, which is assumed to already be a bare state_dict.
    """
    if isinstance(checkpoint, DiffusionModel):
        # load_model() loads weights into ``diffusion_model.model`` (the UNet),
        # so unwrap a pickled wrapper the same way.
        checkpoint = checkpoint.model
    if isinstance(checkpoint, nn.Module):
        # Previously a pickled module reached ``'model_state_dict' in checkpoint``,
        # raised TypeError, and silently fell back to random weights.
        return checkpoint.state_dict()
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        return checkpoint['model_state_dict']
    return checkpoint


def load_model(model_path, device, allow_random_init=True):
    """Build the UNet-backed DiffusionModel and load trained weights into it.

    Args:
        model_path: Path of a checkpoint file written with ``torch.save``. See
            ``_extract_state_dict`` for the accepted checkpoint layouts.
        device: ``torch.device`` the model is built on and the checkpoint
            tensors are mapped to.
        allow_random_init: If True (the default, preserving the original
            behaviour), a checkpoint that cannot be loaded is reported on stdout
            and the model keeps its random initial weights. If False, the
            loading error is re-raised instead.

    Returns:
        The DiffusionModel, switched to eval mode.

    Raises:
        Exception: Whatever checkpoint loading raised, only when
            ``allow_random_init`` is False.
    """
    unet_model = UNet(num_classes=NUM_CLASSES).to(device)
    diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
    try:
        # torch.load unpickles the file: only point this at trusted checkpoints.
        # NOTE: on PyTorch >= 2.6 torch.load defaults to weights_only=True, which
        # accepts plain state_dicts but refuses fully pickled modules.
        checkpoint = torch.load(model_path, map_location=device)
        diffusion_model.model.load_state_dict(_extract_state_dict(checkpoint))
    except Exception as e:
        if not allow_random_init:
            raise
        # Deliberate best-effort fallback so the UI still starts.
        print(f"Error loading model: {e}")
        print("Using randomly initialized weights")
    else:
        print(f"Successfully loaded model from {model_path}")
    diffusion_model.eval()
    return diffusion_model
def generate_image(label_str):
    """Sample one synthetic X-ray image conditioned on the selected label.

    Args:
        label_str: Condition name, either ``'Pneumonia'`` or ``'Pneumothorax'``.

    Returns:
        A ``PIL.Image.Image`` built from the sampled tensor. A single-channel
        sample becomes a grayscale ("L") image, a 3-channel sample an RGB one.

    Raises:
        gr.Error: If ``label_str`` is not a known condition. Gradio shows this
            message in the UI.
    """
    label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
    try:
        label_index = label_map[label_str]
    except KeyError:
        raise gr.Error(f"Invalid label '{label_str}'. Please select either 'Pneumonia' or 'Pneumothorax'.") from None

    # One-hot condition vector of shape (1, NUM_CLASSES).
    labels = torch.zeros(1, NUM_CLASSES, device=device)
    labels[0, label_index] = 1

    with torch.no_grad():
        generated_image = sample(
            model=loaded_model,
            num_images=1,
            timesteps=TIMESTEPS,
            img_size=IMG_SIZE,
            num_classes=NUM_CLASSES,
            labels=labels,
            device=device
        )
    return _tensor_to_pil(generated_image)


def _tensor_to_pil(image_tensor):
    """Convert a (1, C, H, W) or (C, H, W) float tensor to a PIL image.

    Values are clipped to [0, 1] before scaling to 0-255.
    NOTE(review): this assumes sample() returns values in [0, 1]; anything
    outside that range is clipped. Confirm against sample().
    """
    # .float() lets half/bfloat16 outputs convert to NumPy (bfloat16 cannot).
    img_np = image_tensor.squeeze(0).detach().cpu().float().permute(1, 2, 0).numpy()
    # Casting NaN to uint8 is undefined; map NaN to 0 and +/-inf to finite values.
    img_np = np.nan_to_num(img_np, nan=0.0)
    img_np = np.clip(img_np, 0, 1)
    # Pillow cannot build an image from an (H, W, 1) array, which is what a
    # single-channel sample yields after the permute. Drop that axis.
    if img_np.shape[-1] == 1:
        img_np = img_np[..., 0]
    # Round rather than truncate so the 0-255 quantisation is unbiased.
    return Image.fromarray(np.round(img_np * 255).astype(np.uint8))
# Checkpoint location (update these for your deployment).
MODEL_DIR = "models"
MODEL_NAME = "diffusion_unet_xray.pth"  # Update with your actual filename
model_path = os.path.join(MODEL_DIR, MODEL_NAME)

# Build the model once at import time; generate_image() reads ``loaded_model``.
print("Loading model...")
loaded_model = load_model(model_path, device)
# load_model() may have fallen back to random weights (it prints why), so this
# message must not claim the checkpoint was loaded.
print("Model initialization finished.")
# Web UI: one condition dropdown in, one generated PIL image out.
_condition_input = gr.Dropdown(
    label="Select Condition",
    choices=["Pneumonia", "Pneumothorax"],
    value="Pneumonia",  # option selected when the page opens
)
_image_output = gr.Image(
    label="Generated X-ray Image",
    type="pil",
)
_example_rows = [[condition] for condition in ("Pneumonia", "Pneumothorax")]

iface = gr.Interface(
    fn=generate_image,
    inputs=_condition_input,
    outputs=_image_output,
    title="Medical X-ray Image Generator",
    description="Generate synthetic chest X-ray images using a diffusion model. Select a condition to generate.",
    examples=_example_rows,
)

if __name__ == "__main__":
    # Bind to every network interface on port 7860.
    iface.launch(server_name="0.0.0.0", server_port=7860)