# File size: 3,162 Bytes
# d5e5728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# (Line-number gutter left over from the web file viewer removed; the module starts below.)
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
import numpy as np
import math
import os

# Hyperparameters -- update these so they match the configuration the
# checkpoint was trained with.
IMG_SIZE = 128    # forwarded to sample() as img_size
TIMESTEPS = 300   # forwarded to DiffusionModel(...) and sample() as timesteps
NUM_CLASSES = 2   # forwarded to UNet(...) and sample(); width of the one-hot label

# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

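# NOTE: SinusoidalPositionEmbeddings, UNet, DiffusionModel and sample() are not
# defined in this part of the file. Paste them here unchanged from your original
# (training) code, or import them: load_model() and generate_image() below call
# UNet, DiffusionModel and sample() by name.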

# Build the diffusion model and restore its trained weights (best effort)
def load_model(model_path, device):
    """Construct the diffusion model and try to load weights from ``model_path``.

    A ``UNet`` is created with ``NUM_CLASSES`` classes and wrapped in a
    ``DiffusionModel`` with ``TIMESTEPS`` steps; both are moved to ``device``.
    (Neither class is defined in the visible part of this file.)

    The checkpoint may be either a dict holding the weights under the
    ``'model_state_dict'`` key or a bare state dict. In both cases the weights
    are loaded into ``diffusion_model.model``.

    Loading is best effort: any exception raised while reading or applying the
    checkpoint is printed and the model keeps its freshly initialized weights.

    Args:
        model_path: Path of the checkpoint file handed to ``torch.load``.
        device: Device the model lives on; also used as ``map_location``.

    Returns:
        The ``DiffusionModel`` instance after ``eval()`` has been called on it.
    """
    net = UNet(num_classes=NUM_CLASSES).to(device)
    diffusion_model = DiffusionModel(net, timesteps=TIMESTEPS).to(device)

    try:
        # NOTE(review): torch.load unpickles the file -- only point this at
        # checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=device)
        # Unwrap the weights when they are stored under 'model_state_dict';
        # otherwise treat the loaded object itself as the state dict.
        state_dict = (
            checkpoint['model_state_dict']
            if 'model_state_dict' in checkpoint
            else checkpoint
        )
        diffusion_model.model.load_state_dict(state_dict)
    except Exception as e:
        # Deliberate fallback: report the problem and continue untrained.
        print(f"Error loading model: {e}")
        print("Using randomly initialized weights")
    else:
        print(f"Successfully loaded model from {model_path}")

    diffusion_model.eval()
    return diffusion_model

# Image generation for the Gradio callback
def _to_pil_image(image_tensor):
    """Convert one generated image tensor into a PIL image.

    Assumes ``image_tensor`` has shape (1, C, H, W) with values meant to lie
    in [0, 1] -- TODO confirm against ``sample()``, which is not visible here.
    Values outside [0, 1] are clipped before scaling to 8-bit.
    """
    # Drop the batch axis and move channels last: (1, C, H, W) -> (H, W, C).
    img_np = image_tensor.squeeze(0).cpu().permute(1, 2, 0).numpy()
    img_np = np.clip(img_np, 0, 1)  # Ensure proper range
    img_u8 = (img_np * 255).astype(np.uint8)

    # Image.fromarray cannot interpret an (H, W, 1) array, so a single-channel
    # (grayscale) result has its channel axis removed to give a 2-D array.
    if img_u8.ndim == 3 and img_u8.shape[-1] == 1:
        img_u8 = img_u8[..., 0]

    return Image.fromarray(img_u8)


def generate_image(label_str):
    """Generate one synthetic X-ray image for the selected condition.

    Args:
        label_str: Condition name; must be ``'Pneumonia'`` or
            ``'Pneumothorax'``.

    Returns:
        A ``PIL.Image.Image`` built from the output of ``sample()``.

    Raises:
        gr.Error: If ``label_str`` is not one of the known condition names.
    """
    label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
    label_index = label_map.get(label_str)
    if label_index is None:
        raise gr.Error(f"Invalid label '{label_str}'. Please select either 'Pneumonia' or 'Pneumothorax'.")

    # One-hot encoded class label for a batch of one image.
    labels = torch.zeros(1, NUM_CLASSES, device=device)
    labels[0, label_index] = 1

    # Inference only: no gradients are needed while sampling.
    with torch.no_grad():
        generated_image = sample(
            model=loaded_model,
            num_images=1,
            timesteps=TIMESTEPS,
            img_size=IMG_SIZE,
            num_classes=NUM_CLASSES,
            labels=labels,
            device=device,
        )

    return _to_pil_image(generated_image)

# Model paths (update these for your deployment). The path is relative, so it
# is resolved against the process's current working directory.
MODEL_DIR = "models"
MODEL_NAME = "diffusion_unet_xray.pth"  # Update with your actual filename
model_path = os.path.join(MODEL_DIR, MODEL_NAME)

# Build the model once at import time; generate_image() reads `loaded_model`.
print("Loading model...")
loaded_model = load_model(model_path, device)
# load_model() prints whether the checkpoint was restored or whether it fell
# back to randomly initialized weights, so this line only marks the end of
# start-up instead of unconditionally claiming that loading succeeded.
print("Model ready.")

# Gradio interface: one dropdown in, one generated image out.
_condition_input = gr.Dropdown(
    choices=["Pneumonia", "Pneumothorax"],
    label="Select Condition",
    value="Pneumonia",  # Default value
)
_xray_output = gr.Image(type="pil", label="Generated X-ray Image")

iface = gr.Interface(
    fn=generate_image,  # called with the selected dropdown value
    inputs=_condition_input,
    outputs=_xray_output,
    title="Medical X-ray Image Generator",
    description="Generate synthetic chest X-ray images using a diffusion model. Select a condition to generate.",
    examples=[["Pneumonia"], ["Pneumothorax"]],
)

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    # 0.0.0.0 listens on every network interface, not only localhost.
    iface.launch(server_name="0.0.0.0", server_port=7860)