Vedansh-7 commited on
Commit
d5e5728
·
1 Parent(s): 7531575

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import gradio as gr
4
+ from PIL import Image
5
+ import numpy as np
6
+ import math
7
+ import os
8
+
9
+ # Constants (update these to match your training config)
10
+ IMG_SIZE = 128
11
+ TIMESTEPS = 300
12
+ NUM_CLASSES = 2
13
+
14
+ # Define the device
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ # SinusoidalPositionEmbeddings and UNet classes remain the same as your original code
18
+ # DiffusionModel class remains the same as your original code
19
+
20
+ # Load the trained model with improved error handling
21
+ def load_model(model_path, device):
22
+ unet_model = UNet(num_classes=NUM_CLASSES).to(device)
23
+ diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
24
+
25
+ try:
26
+ checkpoint = torch.load(model_path, map_location=device)
27
+ # Handle both full model and state_dict loading
28
+ if 'model_state_dict' in checkpoint:
29
+ diffusion_model.model.load_state_dict(checkpoint['model_state_dict'])
30
+ else:
31
+ diffusion_model.model.load_state_dict(checkpoint)
32
+ print(f"Successfully loaded model from {model_path}")
33
+ except Exception as e:
34
+ print(f"Error loading model: {e}")
35
+ print("Using randomly initialized weights")
36
+
37
+ diffusion_model.eval()
38
+ return diffusion_model
39
+
40
+ # Improved image generation function
41
+ def generate_image(label_str):
42
+ label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
43
+ try:
44
+ label_index = label_map[label_str]
45
+ except KeyError:
46
+ raise gr.Error(f"Invalid label '{label_str}'. Please select either 'Pneumonia' or 'Pneumothorax'.")
47
+
48
+ # Create one-hot encoded label
49
+ labels = torch.zeros(1, NUM_CLASSES, device=device)
50
+ labels[0, label_index] = 1
51
+
52
+ # Generate image
53
+ with torch.no_grad():
54
+ generated_image = sample(
55
+ model=loaded_model,
56
+ num_images=1,
57
+ timesteps=TIMESTEPS,
58
+ img_size=IMG_SIZE,
59
+ num_classes=NUM_CLASSES,
60
+ labels=labels,
61
+ device=device
62
+ )
63
+
64
+ # Convert to PIL Image
65
+ img_np = generated_image.squeeze(0).cpu().permute(1, 2, 0).numpy()
66
+ img_np = np.clip(img_np, 0, 1) # Ensure proper range
67
+ img_pil = Image.fromarray((img_np * 255).astype(np.uint8))
68
+
69
+ return img_pil
70
+
71
+ # Model paths (update these for your deployment)
72
+ MODEL_DIR = "models"
73
+ MODEL_NAME = "diffusion_unet_xray.pth" # Update with your actual filename
74
+ model_path = os.path.join(MODEL_DIR, MODEL_NAME)
75
+
76
+ # Load model
77
+ print("Loading model...")
78
+ loaded_model = load_model(model_path, device)
79
+ print("Model loaded successfully!")
80
+
81
+ # Gradio interface
82
+ iface = gr.Interface(
83
+ fn=generate_image,
84
+ inputs=gr.Dropdown(
85
+ choices=["Pneumonia", "Pneumothorax"],
86
+ label="Select Condition",
87
+ value="Pneumonia" # Default value
88
+ ),
89
+ outputs=gr.Image(
90
+ type="pil",
91
+ label="Generated X-ray Image"
92
+ ),
93
+ title="Medical X-ray Image Generator",
94
+ description="Generate synthetic chest X-ray images using a diffusion model. Select a condition to generate.",
95
+ examples=[["Pneumonia"], ["Pneumothorax"]]
96
+ )
97
+
98
+ if __name__ == "__main__":
99
+ iface.launch(server_name="0.0.0.0", server_port=7860)