Update app.py
Browse files
app.py
CHANGED
@@ -9,23 +9,18 @@ from datetime import datetime
|
|
9 |
import sqlite3
|
10 |
import torch.nn as nn
|
11 |
import cv2
|
|
|
12 |
|
13 |
-
#
|
14 |
from pytorch_grad_cam import GradCAMPlusPlus
|
15 |
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
16 |
from pytorch_grad_cam.utils.image import show_cam_on_image
|
17 |
|
18 |
-
# β
Import Hugging Face-style GLAM EfficientNet model
|
19 |
-
from glam_efficientnet_model import GLAMEfficientNetForClassification, GLAMEfficientNetConfig
|
20 |
-
from glam_module import GLAM
|
21 |
-
from swin_module import SwinWindowAttention
|
22 |
-
|
23 |
app = Flask(__name__)
|
24 |
|
25 |
-
# β
Directory and database
|
26 |
OUTPUT_DIR = '/tmp/results'
|
27 |
-
|
28 |
-
os.makedirs(OUTPUT_DIR)
|
29 |
|
30 |
DB_PATH = os.path.join(OUTPUT_DIR, 'results.db')
|
31 |
|
@@ -51,18 +46,30 @@ def init_db():
|
|
51 |
|
52 |
init_db()
|
53 |
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
)
|
59 |
-
|
|
|
|
|
60 |
model.eval()
|
61 |
|
62 |
# β
Class Names
|
63 |
CLASS_NAMES = ["Advanced", "Early", "Normal"]
|
64 |
|
65 |
-
# β
|
66 |
transform = transforms.Compose([
|
67 |
transforms.Resize(256),
|
68 |
transforms.CenterCrop(224),
|
@@ -74,27 +81,26 @@ transform = transforms.Compose([
|
|
74 |
@app.route('/')
|
75 |
def home():
|
76 |
"""Check that the API is working."""
|
77 |
-
return "Glaucoma Detection Flask API (
|
78 |
|
79 |
@app.route("/test_file")
|
80 |
def test_file():
|
81 |
"""Check if the .pt model file is present and readable."""
|
82 |
-
filepath = "
|
83 |
if os.path.exists(filepath):
|
84 |
return f"β
Model file found at: {filepath}"
|
85 |
else:
|
86 |
return "β Model file NOT found."
|
87 |
|
88 |
-
|
89 |
@app.route('/predict', methods=['POST'])
|
90 |
def predict():
|
91 |
"""Perform prediction and save results (including Grad-CAM++) to the database."""
|
92 |
if 'file' not in request.files:
|
93 |
-
return jsonify({'error': 'No file uploaded'}), 400
|
94 |
|
95 |
uploaded_file = request.files['file']
|
96 |
if uploaded_file.filename == '':
|
97 |
-
return jsonify({'error': 'No file selected'}), 400
|
98 |
|
99 |
try:
|
100 |
# β
Save the uploaded image
|
@@ -107,20 +113,20 @@ def predict():
|
|
107 |
img = Image.open(uploaded_file_path).convert('RGB')
|
108 |
input_tensor = transform(img).unsqueeze(0)
|
109 |
|
110 |
-
#
|
111 |
-
|
112 |
-
|
|
|
113 |
class_index = np.argmax(probabilities)
|
114 |
result = CLASS_NAMES[class_index]
|
115 |
confidence = float(probabilities[class_index])
|
116 |
|
117 |
# β
Grad-CAM++ setup
|
118 |
-
|
119 |
-
target_layer = dict(model.features.named_modules())["features.7"] # β
Adjust as needed
|
120 |
cam_model = GradCAMPlusPlus(model=model, target_layers=[target_layer])
|
121 |
|
122 |
-
|
123 |
-
|
124 |
|
125 |
# β
Create RGB overlay
|
126 |
original_img = np.asarray(img.resize((224, 224)), dtype=np.float32) / 255.0
|
@@ -196,7 +202,7 @@ def get_gradcam(filename):
|
|
196 |
if os.path.exists(filepath):
|
197 |
return send_file(filepath, mimetype='image/png')
|
198 |
else:
|
199 |
-
return jsonify({'error': 'File not found'}), 404
|
200 |
|
201 |
|
202 |
@app.route('/image/<filename>')
|
@@ -206,8 +212,9 @@ def get_image(filename):
|
|
206 |
if os.path.exists(filepath):
|
207 |
return send_file(filepath, mimetype='image/png')
|
208 |
else:
|
209 |
-
return jsonify({'error': 'File not found'}), 404
|
210 |
|
211 |
|
212 |
if __name__ == '__main__':
|
213 |
app.run(host='0.0.0.0', port=7860)
|
|
|
|
9 |
import sqlite3
|
10 |
import torch.nn as nn
|
11 |
import cv2
|
12 |
+
import json
|
13 |
|
14 |
+
# Grad-CAM++ imports
|
15 |
from pytorch_grad_cam import GradCAMPlusPlus
|
16 |
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
17 |
from pytorch_grad_cam.utils.image import show_cam_on_image
|
18 |
|
|
|
|
|
|
|
|
|
|
|
19 |
app = Flask(__name__)
|
20 |
|
21 |
+
# β
Directory and database
|
22 |
OUTPUT_DIR = '/tmp/results'
|
23 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
|
24 |
|
25 |
DB_PATH = os.path.join(OUTPUT_DIR, 'results.db')
|
26 |
|
|
|
46 |
|
47 |
init_db()
|
48 |
|
49 |
+
|
50 |
+
# β
Import your EfficientNetB0_TransformerGLAM model
|
51 |
+
from efficientnet_transformer_glam import EfficientNetb0_TransformerGLAM # Ensure this is in the path
|
52 |
+
|
53 |
+
|
54 |
+
# β
Instantiate the model
|
55 |
+
model = EfficientNetb0_TransformerGLAM(
|
56 |
+
num_classes=3,
|
57 |
+
embed_dim=512,
|
58 |
+
num_heads=8,
|
59 |
+
mlp_dim=512,
|
60 |
+
dropout=0.5,
|
61 |
+
window_size=7,
|
62 |
+
reduction_ratio=8
|
63 |
)
|
64 |
+
|
65 |
+
# β
Load the trained checkpoint
|
66 |
+
model.load_state_dict(torch.load('densenet169_seed40_best.pt', map_location='cpu'))
|
67 |
model.eval()
|
68 |
|
69 |
# β
Class Names
|
70 |
CLASS_NAMES = ["Advanced", "Early", "Normal"]
|
71 |
|
72 |
+
# β
Transforms
|
73 |
transform = transforms.Compose([
|
74 |
transforms.Resize(256),
|
75 |
transforms.CenterCrop(224),
|
|
|
81 |
@app.route('/')
|
82 |
def home():
|
83 |
"""Check that the API is working."""
|
84 |
+
return "Glaucoma Detection Flask API (EfficientNetB0_TransformerGLAM Model) is running!"
|
85 |
|
86 |
@app.route("/test_file")
|
87 |
def test_file():
|
88 |
"""Check if the .pt model file is present and readable."""
|
89 |
+
filepath = "densenet169_seed40_best.pt"
|
90 |
if os.path.exists(filepath):
|
91 |
return f"β
Model file found at: {filepath}"
|
92 |
else:
|
93 |
return "β Model file NOT found."
|
94 |
|
|
|
95 |
@app.route('/predict', methods=['POST'])
|
96 |
def predict():
|
97 |
"""Perform prediction and save results (including Grad-CAM++) to the database."""
|
98 |
if 'file' not in request.files:
|
99 |
+
return jsonify({'error': 'No file uploaded.'}), 400
|
100 |
|
101 |
uploaded_file = request.files['file']
|
102 |
if uploaded_file.filename == '':
|
103 |
+
return jsonify({'error': 'No file selected.'}), 400
|
104 |
|
105 |
try:
|
106 |
# β
Save the uploaded image
|
|
|
113 |
img = Image.open(uploaded_file_path).convert('RGB')
|
114 |
input_tensor = transform(img).unsqueeze(0)
|
115 |
|
116 |
+
# Model Inference
|
117 |
+
with torch.no_grad():
|
118 |
+
output = model(input_tensor)
|
119 |
+
probabilities = F.softmax(output, dim=1).cpu().numpy()[0]
|
120 |
class_index = np.argmax(probabilities)
|
121 |
result = CLASS_NAMES[class_index]
|
122 |
confidence = float(probabilities[class_index])
|
123 |
|
124 |
# β
Grad-CAM++ setup
|
125 |
+
target_layer = model.feature_extractor[-1] # Final block of EfficientNet feature extractor
|
|
|
126 |
cam_model = GradCAMPlusPlus(model=model, target_layers=[target_layer])
|
127 |
|
128 |
+
cam_output = cam_model(input_tensor=input_tensor,
|
129 |
+
targets=[ClassifierOutputTarget(class_index)])[0]
|
130 |
|
131 |
# β
Create RGB overlay
|
132 |
original_img = np.asarray(img.resize((224, 224)), dtype=np.float32) / 255.0
|
|
|
202 |
if os.path.exists(filepath):
|
203 |
return send_file(filepath, mimetype='image/png')
|
204 |
else:
|
205 |
+
return jsonify({'error': 'File not found.'}), 404
|
206 |
|
207 |
|
208 |
@app.route('/image/<filename>')
|
|
|
212 |
if os.path.exists(filepath):
|
213 |
return send_file(filepath, mimetype='image/png')
|
214 |
else:
|
215 |
+
return jsonify({'error': 'File not found.'}), 404
|
216 |
|
217 |
|
218 |
if __name__ == '__main__':
|
219 |
app.run(host='0.0.0.0', port=7860)
|
220 |
+
|