dhruv2842 commited on
Commit
3c3aa1b
Β·
verified Β·
1 Parent(s): 3fd2a9f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +213 -212
app.py CHANGED
@@ -1,212 +1,213 @@
1
- from flask import Flask, request, jsonify, send_file
2
- from PIL import Image
3
- import torch
4
- import torch.nn.functional as F
5
- from torchvision import transforms
6
- import os
7
- import numpy as np
8
- from datetime import datetime
9
- import sqlite3
10
- import torch.nn as nn
11
- import cv2
12
-
13
- # βœ… New Grad-CAM++ imports
14
- from pytorch_grad_cam import GradCAMPlusPlus
15
- from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
16
- from pytorch_grad_cam.utils.image import show_cam_on_image
17
-
18
- # βœ… Import Hugging Face-style GLAM EfficientNet model
19
- from glam_efficientnet_model import GLAMEfficientNetForClassification, GLAMEfficientNetConfig
20
-
21
- app = Flask(__name__)
22
-
23
- # βœ… Directory and database path
24
- OUTPUT_DIR = '/tmp/results'
25
- if not os.path.exists(OUTPUT_DIR):
26
- os.makedirs(OUTPUT_DIR)
27
-
28
- DB_PATH = os.path.join(OUTPUT_DIR, 'results.db')
29
-
30
-
31
- def init_db():
32
- """Initialize SQLite database for storing results."""
33
- conn = sqlite3.connect(DB_PATH)
34
- cursor = conn.cursor()
35
- cursor.execute("""
36
- CREATE TABLE IF NOT EXISTS results (
37
- id INTEGER PRIMARY KEY AUTOINCREMENT,
38
- image_filename TEXT,
39
- prediction TEXT,
40
- confidence REAL,
41
- gradcam_filename TEXT,
42
- gradcam_gray_filename TEXT,
43
- timestamp TEXT
44
- )
45
- """)
46
- conn.commit()
47
- conn.close()
48
-
49
-
50
- init_db()
51
-
52
- # βœ… Load GLAM EfficientNet Model
53
- config = GLAMEfficientNetConfig()
54
- model = GLAMEfficientNetForClassification(config)
55
- model.load_state_dict(torch.load('efficientnet_glam_best.pt', map_location='cpu'))
56
- model.eval()
57
-
58
- # βœ… Class Names
59
- CLASS_NAMES = ["Advanced", "Early", "Normal"]
60
-
61
- # βœ… Transformation for input images
62
- transform = transforms.Compose([
63
- transforms.Resize(256),
64
- transforms.CenterCrop(224),
65
- transforms.ToTensor(),
66
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
67
- std=[0.229, 0.224, 0.225]),
68
- ])
69
-
70
-
71
- @app.route('/')
72
- def home():
73
- """Check that the API is working."""
74
- return "Glaucoma Detection Flask API (EfficientNet + GLAM) is running!"
75
-
76
-
77
- @app.route("/test_file")
78
- def test_file():
79
- """Check if the .pt model file is present and readable."""
80
- filepath = "efficientnet_glam_best.pt"
81
- if os.path.exists(filepath):
82
- return f"βœ… Model file found at: {filepath}"
83
- else:
84
- return "❌ Model file NOT found."
85
-
86
-
87
- @app.route('/predict', methods=['POST'])
88
- def predict():
89
- """Perform prediction and save results (including Grad-CAM++) to the database."""
90
- if 'file' not in request.files:
91
- return jsonify({'error': 'No file uploaded'}), 400
92
-
93
- uploaded_file = request.files['file']
94
- if uploaded_file.filename == '':
95
- return jsonify({'error': 'No file selected'}), 400
96
-
97
- try:
98
- # βœ… Save the uploaded image
99
- timestamp = int(datetime.now().timestamp())
100
- uploaded_filename = f"uploaded_{timestamp}.png"
101
- uploaded_file_path = os.path.join(OUTPUT_DIR, uploaded_filename)
102
- uploaded_file.save(uploaded_file_path)
103
-
104
- # βœ… Perform prediction
105
- img = Image.open(uploaded_file_path).convert('RGB')
106
- input_tensor = transform(img).unsqueeze(0)
107
-
108
- # βœ… Get prediction
109
- output = model(input_tensor) # Dict with "logits"
110
- probabilities = F.softmax(output["logits"], dim=1).cpu().detach().numpy()[0]
111
- class_index = np.argmax(probabilities)
112
- result = CLASS_NAMES[class_index]
113
- confidence = float(probabilities[class_index])
114
-
115
- # βœ… Grad-CAM++ setup
116
- # IMPORTANT: Choose the target layer from the GLAM EfficientNet model.
117
- # For example, use the final convolutional block:
118
- target_layer = model.features[-1]
119
- cam_model = GradCAMPlusPlus(model=model, target_layers=[target_layer])
120
-
121
- # βœ… Get Grad-CAM++ map
122
- cam_output = cam_model(input_tensor=input_tensor, targets=[ClassifierOutputTarget(class_index)])[0]
123
-
124
- # βœ… Create RGB overlay
125
- original_img = np.asarray(img.resize((224, 224)), dtype=np.float32) / 255.0
126
- overlay = show_cam_on_image(original_img, cam_output, use_rgb=True)
127
-
128
- # βœ… Create grayscale version
129
- cam_normalized = np.uint8(255 * cam_output)
130
-
131
- # βœ… Save overlay
132
- gradcam_filename = f"gradcam_{timestamp}.png"
133
- gradcam_file_path = os.path.join(OUTPUT_DIR, gradcam_filename)
134
- cv2.imwrite(gradcam_file_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
135
-
136
- # βœ… Save grayscale
137
- gray_filename = f"gradcam_gray_{timestamp}.png"
138
- gray_file_path = os.path.join(OUTPUT_DIR, gray_filename)
139
- cv2.imwrite(gray_file_path, cam_normalized)
140
-
141
- # βœ… Save results to database
142
- conn = sqlite3.connect(DB_PATH)
143
- cursor = conn.cursor()
144
- cursor.execute("""
145
- INSERT INTO results (image_filename, prediction, confidence, gradcam_filename, gradcam_gray_filename, timestamp)
146
- VALUES (?, ?, ?, ?, ?, ?)
147
- """, (uploaded_filename, result, confidence, gradcam_filename, gray_filename, datetime.now().isoformat()))
148
- conn.commit()
149
- conn.close()
150
-
151
- # βœ… Return results
152
- return jsonify({
153
- 'prediction': result,
154
- 'confidence': confidence,
155
- 'normal_probability': float(probabilities[0]),
156
- 'early_glaucoma_probability': float(probabilities[1]),
157
- 'advanced_glaucoma_probability': float(probabilities[2]),
158
- 'gradcam_image': gradcam_filename,
159
- 'gradcam_gray_image': gray_filename,
160
- 'image_filename': uploaded_filename
161
- })
162
-
163
- except Exception as e:
164
- return jsonify({'error': str(e)}), 500
165
-
166
-
167
- @app.route('/results', methods=['GET'])
168
- def results():
169
- """List all results from the SQLite database."""
170
- conn = sqlite3.connect(DB_PATH)
171
- cursor = conn.cursor()
172
- cursor.execute("SELECT * FROM results ORDER BY timestamp DESC")
173
- results_data = cursor.fetchall()
174
- conn.close()
175
-
176
- results_list = []
177
- for record in results_data:
178
- results_list.append({
179
- 'id': record[0],
180
- 'image_filename': record[1],
181
- 'prediction': record[2],
182
- 'confidence': record[3],
183
- 'gradcam_filename': record[4],
184
- 'gradcam_gray_filename': record[5],
185
- 'timestamp': record[6]
186
- })
187
-
188
- return jsonify(results_list)
189
-
190
-
191
- @app.route('/gradcam/<filename>')
192
- def get_gradcam(filename):
193
- """Serve the Grad-CAM overlay image."""
194
- filepath = os.path.join(OUTPUT_DIR, filename)
195
- if os.path.exists(filepath):
196
- return send_file(filepath, mimetype='image/png')
197
- else:
198
- return jsonify({'error': 'File not found'}), 404
199
-
200
-
201
- @app.route('/image/<filename>')
202
- def get_image(filename):
203
- """Serve the original uploaded image."""
204
- filepath = os.path.join(OUTPUT_DIR, filename)
205
- if os.path.exists(filepath):
206
- return send_file(filepath, mimetype='image/png')
207
- else:
208
- return jsonify({'error': 'File not found'}), 404
209
-
210
-
211
- if __name__ == '__main__':
212
- app.run(host='0.0.0.0', port=7860)
 
 
1
+ from flask import Flask, request, jsonify, send_file
2
+ from PIL import Image
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torchvision import transforms
6
+ import os
7
+ import numpy as np
8
+ from datetime import datetime
9
+ import sqlite3
10
+ import torch.nn as nn
11
+ import cv2
12
+
13
+ # βœ… New Grad-CAM++ imports
14
+ from pytorch_grad_cam import GradCAMPlusPlus
15
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
16
+ from pytorch_grad_cam.utils.image import show_cam_on_image
17
+
18
+ # βœ… Import Hugging Face-style GLAM EfficientNet model
19
+ from glam_efficientnet_model import GLAMEfficientNetForClassification, GLAMEfficientNetConfig
20
+ from glam_module import GLAM
21
+ from swin_module import SwinWindowAttention
22
+
23
+ app = Flask(__name__)
24
+
25
+ # βœ… Directory and database path
26
+ OUTPUT_DIR = '/tmp/results'
27
+ if not os.path.exists(OUTPUT_DIR):
28
+ os.makedirs(OUTPUT_DIR)
29
+
30
+ DB_PATH = os.path.join(OUTPUT_DIR, 'results.db')
31
+
32
+
33
+ def init_db():
34
+ """Initialize SQLite database for storing results."""
35
+ conn = sqlite3.connect(DB_PATH)
36
+ cursor = conn.cursor()
37
+ cursor.execute("""
38
+ CREATE TABLE IF NOT EXISTS results (
39
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
40
+ image_filename TEXT,
41
+ prediction TEXT,
42
+ confidence REAL,
43
+ gradcam_filename TEXT,
44
+ gradcam_gray_filename TEXT,
45
+ timestamp TEXT
46
+ )
47
+ """)
48
+ conn.commit()
49
+ conn.close()
50
+
51
+
52
+ init_db()
53
+
54
+ # βœ… Load GLAM EfficientNet Model
55
+ config = GLAMEfficientNetConfig()
56
+ model = GLAMEfficientNetForClassification(
57
+ config, glam_module_cls=GLAM, swin_module_cls=SwinWindowAttention
58
+ )
59
+ model.load_state_dict(torch.load('efficientnet_glam_best.pt', map_location='cpu'))
60
+ model.eval()
61
+
62
+ # βœ… Class Names
63
+ CLASS_NAMES = ["Advanced", "Early", "Normal"]
64
+
65
+ # βœ… Transformation for input images
66
+ transform = transforms.Compose([
67
+ transforms.Resize(256),
68
+ transforms.CenterCrop(224),
69
+ transforms.ToTensor(),
70
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
71
+ std=[0.229, 0.224, 0.225]),
72
+ ])
73
+
74
+ @app.route('/')
75
+ def home():
76
+ """Check that the API is working."""
77
+ return "Glaucoma Detection Flask API (GLAM EfficientNet) is running!"
78
+
79
+ @app.route("/test_file")
80
+ def test_file():
81
+ """Check if the .pt model file is present and readable."""
82
+ filepath = "efficientnet_glam_best.pt"
83
+ if os.path.exists(filepath):
84
+ return f"βœ… Model file found at: {filepath}"
85
+ else:
86
+ return "❌ Model file NOT found."
87
+
88
+
89
+ @app.route('/predict', methods=['POST'])
90
+ def predict():
91
+ """Perform prediction and save results (including Grad-CAM++) to the database."""
92
+ if 'file' not in request.files:
93
+ return jsonify({'error': 'No file uploaded'}), 400
94
+
95
+ uploaded_file = request.files['file']
96
+ if uploaded_file.filename == '':
97
+ return jsonify({'error': 'No file selected'}), 400
98
+
99
+ try:
100
+ # βœ… Save the uploaded image
101
+ timestamp = int(datetime.now().timestamp())
102
+ uploaded_filename = f"uploaded_{timestamp}.png"
103
+ uploaded_file_path = os.path.join(OUTPUT_DIR, uploaded_filename)
104
+ uploaded_file.save(uploaded_file_path)
105
+
106
+ # βœ… Perform prediction
107
+ img = Image.open(uploaded_file_path).convert('RGB')
108
+ input_tensor = transform(img).unsqueeze(0)
109
+
110
+ # βœ… Get prediction
111
+ output = model(input_tensor) # Dict with "logits"
112
+ probabilities = F.softmax(output["logits"], dim=1).cpu().detach().numpy()[0]
113
+ class_index = np.argmax(probabilities)
114
+ result = CLASS_NAMES[class_index]
115
+ confidence = float(probabilities[class_index])
116
+
117
+ # βœ… Grad-CAM++ setup
118
+ # Target the final convolutional output. In GLAM EfficientNet, this is `model.features`
119
+ target_layer = dict(model.features.named_modules())["features.7"] # βœ… Adjust as needed
120
+ cam_model = GradCAMPlusPlus(model=model, target_layers=[target_layer])
121
+
122
+ # βœ… Get Grad-CAM++ map
123
+ cam_output = cam_model(input_tensor=input_tensor, targets=[ClassifierOutputTarget(class_index)])[0]
124
+
125
+ # βœ… Create RGB overlay
126
+ original_img = np.asarray(img.resize((224, 224)), dtype=np.float32) / 255.0
127
+ overlay = show_cam_on_image(original_img, cam_output, use_rgb=True)
128
+
129
+ # βœ… Create grayscale version
130
+ cam_normalized = np.uint8(255 * cam_output)
131
+
132
+ # βœ… Save overlay
133
+ gradcam_filename = f"gradcam_{timestamp}.png"
134
+ gradcam_file_path = os.path.join(OUTPUT_DIR, gradcam_filename)
135
+ cv2.imwrite(gradcam_file_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
136
+
137
+ # βœ… Save grayscale
138
+ gray_filename = f"gradcam_gray_{timestamp}.png"
139
+ gray_file_path = os.path.join(OUTPUT_DIR, gray_filename)
140
+ cv2.imwrite(gray_file_path, cam_normalized)
141
+
142
+ # βœ… Save results to database
143
+ conn = sqlite3.connect(DB_PATH)
144
+ cursor = conn.cursor()
145
+ cursor.execute("""
146
+ INSERT INTO results (image_filename, prediction, confidence, gradcam_filename, gradcam_gray_filename, timestamp)
147
+ VALUES (?, ?, ?, ?, ?, ?)
148
+ """, (uploaded_filename, result, confidence, gradcam_filename, gray_filename, datetime.now().isoformat()))
149
+ conn.commit()
150
+ conn.close()
151
+
152
+ # βœ… Return results
153
+ return jsonify({
154
+ 'prediction': result,
155
+ 'confidence': confidence,
156
+ 'normal_probability': float(probabilities[0]),
157
+ 'early_glaucoma_probability': float(probabilities[1]),
158
+ 'advanced_glaucoma_probability': float(probabilities[2]),
159
+ 'gradcam_image': gradcam_filename,
160
+ 'gradcam_gray_image': gray_filename,
161
+ 'image_filename': uploaded_filename
162
+ })
163
+
164
+ except Exception as e:
165
+ return jsonify({'error': str(e)}), 500
166
+
167
+
168
+ @app.route('/results', methods=['GET'])
169
+ def results():
170
+ """List all results from the SQLite database."""
171
+ conn = sqlite3.connect(DB_PATH)
172
+ cursor = conn.cursor()
173
+ cursor.execute("SELECT * FROM results ORDER BY timestamp DESC")
174
+ results_data = cursor.fetchall()
175
+ conn.close()
176
+
177
+ results_list = []
178
+ for record in results_data:
179
+ results_list.append({
180
+ 'id': record[0],
181
+ 'image_filename': record[1],
182
+ 'prediction': record[2],
183
+ 'confidence': record[3],
184
+ 'gradcam_filename': record[4],
185
+ 'gradcam_gray_filename': record[5],
186
+ 'timestamp': record[6]
187
+ })
188
+
189
+ return jsonify(results_list)
190
+
191
+
192
+ @app.route('/gradcam/<filename>')
193
+ def get_gradcam(filename):
194
+ """Serve the Grad-CAM overlay image."""
195
+ filepath = os.path.join(OUTPUT_DIR, filename)
196
+ if os.path.exists(filepath):
197
+ return send_file(filepath, mimetype='image/png')
198
+ else:
199
+ return jsonify({'error': 'File not found'}), 404
200
+
201
+
202
+ @app.route('/image/<filename>')
203
+ def get_image(filename):
204
+ """Serve the original uploaded image."""
205
+ filepath = os.path.join(OUTPUT_DIR, filename)
206
+ if os.path.exists(filepath):
207
+ return send_file(filepath, mimetype='image/png')
208
+ else:
209
+ return jsonify({'error': 'File not found'}), 404
210
+
211
+
212
+ if __name__ == '__main__':
213
+ app.run(host='0.0.0.0', port=7860)