Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -9,6 +9,7 @@ from datetime import datetime
|
|
9 |
import sqlite3
|
10 |
import torch.nn as nn
|
11 |
import torchvision.models as models
|
|
|
12 |
|
13 |
app = Flask(__name__)
|
14 |
|
@@ -46,13 +47,11 @@ from densenet_withglam import get_model_with_attention
|
|
46 |
# β
Instantiate the model
|
47 |
model = get_model_with_attention('densenet169', num_classes=3) # Will have GLAM
|
48 |
model.load_state_dict(torch.load('densenet169_seed40_best.pt', map_location='cpu'))
|
49 |
-
# Load your trained weights
|
50 |
model.eval()
|
51 |
|
52 |
# β
Class Names
|
53 |
CLASS_NAMES = ["Advanced", "Early", "Normal"]
|
54 |
|
55 |
-
|
56 |
# β
Transformation for input images
|
57 |
transform = transforms.Compose([
|
58 |
transforms.Resize(256),
|
@@ -61,6 +60,46 @@ transform = transforms.Compose([
|
|
61 |
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
62 |
std=[0.229, 0.224, 0.225]),
|
63 |
])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
@app.route('/')
|
65 |
def home():
|
66 |
"""Check that the API is working."""
|
@@ -75,9 +114,10 @@ def test_file():
|
|
75 |
else:
|
76 |
return "β Model file NOT found."
|
77 |
|
|
|
78 |
@app.route('/predict', methods=['POST'])
|
79 |
def predict():
|
80 |
-
"""Perform prediction
|
81 |
if 'file' not in request.files:
|
82 |
return jsonify({'error': 'No file uploaded'}), 400
|
83 |
|
@@ -96,22 +136,40 @@ def predict():
|
|
96 |
img = Image.open(uploaded_file_path).convert('RGB')
|
97 |
input_tensor = transform(img).unsqueeze(0)
|
98 |
|
99 |
-
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
|
|
102 |
|
103 |
-
|
104 |
class_index = np.argmax(probabilities)
|
105 |
result = CLASS_NAMES[class_index]
|
106 |
confidence = float(probabilities[class_index])
|
107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
# β
Save results to SQLite
|
109 |
conn = sqlite3.connect(DB_PATH)
|
110 |
cursor = conn.cursor()
|
111 |
cursor.execute("""
|
112 |
INSERT INTO results (image_filename, prediction, confidence, gradcam_filename, timestamp)
|
113 |
VALUES (?, ?, ?, ?, ?)
|
114 |
-
""", (uploaded_filename, result, confidence,
|
115 |
conn.commit()
|
116 |
conn.close()
|
117 |
|
@@ -121,13 +179,14 @@ def predict():
|
|
121 |
'normal_probability': float(probabilities[0]),
|
122 |
'early_glaucoma_probability': float(probabilities[1]),
|
123 |
'advanced_glaucoma_probability': float(probabilities[2]),
|
124 |
-
'gradcam_image':
|
125 |
'image_filename': uploaded_filename
|
126 |
})
|
127 |
|
128 |
except Exception as e:
|
129 |
return jsonify({'error': str(e)}), 500
|
130 |
|
|
|
131 |
@app.route('/results', methods=['GET'])
|
132 |
def results():
|
133 |
"""List all results from the SQLite database."""
|
@@ -150,15 +209,17 @@ def results():
|
|
150 |
|
151 |
return jsonify(results_list)
|
152 |
|
|
|
153 |
@app.route('/gradcam/<filename>')
|
154 |
def get_gradcam(filename):
|
155 |
-
"""Serve the Grad-CAM overlay image
|
156 |
filepath = os.path.join(OUTPUT_DIR, filename)
|
157 |
if os.path.exists(filepath):
|
158 |
return send_file(filepath, mimetype='image/png')
|
159 |
else:
|
160 |
return jsonify({'error': 'File not found'}), 404
|
161 |
|
|
|
162 |
@app.route('/image/<filename>')
|
163 |
def get_image(filename):
|
164 |
"""Serve the original uploaded image."""
|
@@ -168,5 +229,6 @@ def get_image(filename):
|
|
168 |
else:
|
169 |
return jsonify({'error': 'File not found'}), 404
|
170 |
|
|
|
171 |
if __name__ == '__main__':
|
172 |
app.run(host='0.0.0.0', port=7860)
|
|
|
9 |
import sqlite3
|
10 |
import torch.nn as nn
|
11 |
import torchvision.models as models
|
12 |
+
import cv2
|
13 |
|
14 |
app = Flask(__name__)
|
15 |
|
|
|
47 |
# β
Instantiate the model
|
48 |
model = get_model_with_attention('densenet169', num_classes=3) # Will have GLAM
|
49 |
model.load_state_dict(torch.load('densenet169_seed40_best.pt', map_location='cpu'))
|
|
|
50 |
model.eval()
|
51 |
|
52 |
# β
Class Names
|
53 |
CLASS_NAMES = ["Advanced", "Early", "Normal"]
|
54 |
|
|
|
55 |
# β
Transformation for input images
|
56 |
transform = transforms.Compose([
|
57 |
transforms.Resize(256),
|
|
|
60 |
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
61 |
std=[0.229, 0.224, 0.225]),
|
62 |
])
|
63 |
+
|
64 |
+
# =========================
|
65 |
+
# GRAD-CAM IMPLEMENTATION
|
66 |
+
# =========================
|
67 |
+
class GradCAM:
|
68 |
+
"""Grad-CAM for the target layer."""
|
69 |
+
def __init__(self, model, target_layer_name):
|
70 |
+
self.model = model
|
71 |
+
self.target_layer_name = target_layer_name
|
72 |
+
self.activations = None
|
73 |
+
self.gradients = None
|
74 |
+
self._register_hooks()
|
75 |
+
|
76 |
+
def _register_hooks(self):
|
77 |
+
"""Register forward and backward hooks."""
|
78 |
+
for name, module in self.model.named_modules():
|
79 |
+
if name == self.target_layer_name:
|
80 |
+
module.register_forward_hook(self._forward_hook)
|
81 |
+
module.register_full_backward_hook(self._backward_hook)
|
82 |
+
|
83 |
+
def _forward_hook(self, module, input, output):
|
84 |
+
"""Save activations."""
|
85 |
+
self.activations = output
|
86 |
+
|
87 |
+
def _backward_hook(self, module, grad_in, grad_out):
|
88 |
+
"""Save gradients."""
|
89 |
+
self.gradients = grad_out[0]
|
90 |
+
|
91 |
+
def generate(self, class_index):
|
92 |
+
"""Generate the Grad-CAM."""
|
93 |
+
if self.activations is None or self.gradients is None:
|
94 |
+
raise ValueError("Must run forward and backward passes first.")
|
95 |
+
weights = self.gradients.mean(dim=(2, 3), keepdim=True)
|
96 |
+
cam = (weights * self.activations).sum(dim=1, keepdim=True)
|
97 |
+
cam = F.relu(cam)
|
98 |
+
cam = cam.squeeze().cpu().numpy()
|
99 |
+
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
|
100 |
+
return cam
|
101 |
+
|
102 |
+
|
103 |
@app.route('/')
|
104 |
def home():
|
105 |
"""Check that the API is working."""
|
|
|
114 |
else:
|
115 |
return "β Model file NOT found."
|
116 |
|
117 |
+
|
118 |
@app.route('/predict', methods=['POST'])
|
119 |
def predict():
|
120 |
+
"""Perform prediction and save results (including Grad-CAM) to the database."""
|
121 |
if 'file' not in request.files:
|
122 |
return jsonify({'error': 'No file uploaded'}), 400
|
123 |
|
|
|
136 |
img = Image.open(uploaded_file_path).convert('RGB')
|
137 |
input_tensor = transform(img).unsqueeze(0)
|
138 |
|
139 |
+
# Grad-CAM setup
|
140 |
+
target_layer_name = "features.2.global_spatial_conv"
|
141 |
+
gradcam = GradCAM(model, target_layer_name)
|
142 |
+
|
143 |
+
# Forward pass
|
144 |
+
input_tensor.requires_grad = True
|
145 |
+
output = model(input_tensor)
|
146 |
|
147 |
+
probabilities = F.softmax(output, dim=1).cpu().numpy()[0]
|
148 |
class_index = np.argmax(probabilities)
|
149 |
result = CLASS_NAMES[class_index]
|
150 |
confidence = float(probabilities[class_index])
|
151 |
|
152 |
+
# Backward pass for Grad-CAM
|
153 |
+
model.zero_grad()
|
154 |
+
output[0, class_index].backward()
|
155 |
+
cam = gradcam.generate(class_index)
|
156 |
+
|
157 |
+
# β
Create overlay
|
158 |
+
original_img = np.asarray(img.resize((224, 224)))
|
159 |
+
heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
|
160 |
+
overlay = cv2.addWeighted(original_img, 0.6, heatmap, 0.4, 0)
|
161 |
+
|
162 |
+
gradcam_filename = f"gradcam_{timestamp}.png"
|
163 |
+
gradcam_file_path = os.path.join(OUTPUT_DIR, gradcam_filename)
|
164 |
+
cv2.imwrite(gradcam_file_path, overlay)
|
165 |
+
|
166 |
# β
Save results to SQLite
|
167 |
conn = sqlite3.connect(DB_PATH)
|
168 |
cursor = conn.cursor()
|
169 |
cursor.execute("""
|
170 |
INSERT INTO results (image_filename, prediction, confidence, gradcam_filename, timestamp)
|
171 |
VALUES (?, ?, ?, ?, ?)
|
172 |
+
""", (uploaded_filename, result, confidence, gradcam_filename, datetime.now().isoformat()))
|
173 |
conn.commit()
|
174 |
conn.close()
|
175 |
|
|
|
179 |
'normal_probability': float(probabilities[0]),
|
180 |
'early_glaucoma_probability': float(probabilities[1]),
|
181 |
'advanced_glaucoma_probability': float(probabilities[2]),
|
182 |
+
'gradcam_image': gradcam_filename,
|
183 |
'image_filename': uploaded_filename
|
184 |
})
|
185 |
|
186 |
except Exception as e:
|
187 |
return jsonify({'error': str(e)}), 500
|
188 |
|
189 |
+
|
190 |
@app.route('/results', methods=['GET'])
|
191 |
def results():
|
192 |
"""List all results from the SQLite database."""
|
|
|
209 |
|
210 |
return jsonify(results_list)
|
211 |
|
212 |
+
|
213 |
@app.route('/gradcam/<filename>')
|
214 |
def get_gradcam(filename):
|
215 |
+
"""Serve the Grad-CAM overlay image."""
|
216 |
filepath = os.path.join(OUTPUT_DIR, filename)
|
217 |
if os.path.exists(filepath):
|
218 |
return send_file(filepath, mimetype='image/png')
|
219 |
else:
|
220 |
return jsonify({'error': 'File not found'}), 404
|
221 |
|
222 |
+
|
223 |
@app.route('/image/<filename>')
|
224 |
def get_image(filename):
|
225 |
"""Serve the original uploaded image."""
|
|
|
229 |
else:
|
230 |
return jsonify({'error': 'File not found'}), 404
|
231 |
|
232 |
+
|
233 |
if __name__ == '__main__':
|
234 |
app.run(host='0.0.0.0', port=7860)
|