dhruv2842 commited on
Commit
0cdf25a
Β·
verified Β·
1 Parent(s): 5a322a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -10
app.py CHANGED
@@ -9,6 +9,7 @@ from datetime import datetime
9
  import sqlite3
10
  import torch.nn as nn
11
  import torchvision.models as models
 
12
 
13
  app = Flask(__name__)
14
 
@@ -46,13 +47,11 @@ from densenet_withglam import get_model_with_attention
46
  # βœ… Instantiate the model
47
  model = get_model_with_attention('densenet169', num_classes=3) # Will have GLAM
48
  model.load_state_dict(torch.load('densenet169_seed40_best.pt', map_location='cpu'))
49
- # Load your trained weights
50
  model.eval()
51
 
52
  # βœ… Class Names
53
  CLASS_NAMES = ["Advanced", "Early", "Normal"]
54
 
55
-
56
  # βœ… Transformation for input images
57
  transform = transforms.Compose([
58
  transforms.Resize(256),
@@ -61,6 +60,46 @@ transform = transforms.Compose([
61
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
62
  std=[0.229, 0.224, 0.225]),
63
  ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  @app.route('/')
65
  def home():
66
  """Check that the API is working."""
@@ -75,9 +114,10 @@ def test_file():
75
  else:
76
  return "❌ Model file NOT found."
77
 
 
78
  @app.route('/predict', methods=['POST'])
79
  def predict():
80
- """Perform prediction using PyTorch (3-class), save results, and save to SQLite database."""
81
  if 'file' not in request.files:
82
  return jsonify({'error': 'No file uploaded'}), 400
83
 
@@ -96,22 +136,40 @@ def predict():
96
  img = Image.open(uploaded_file_path).convert('RGB')
97
  input_tensor = transform(img).unsqueeze(0)
98
 
99
- with torch.no_grad():
100
- output = model(input_tensor)
101
- probabilities = F.softmax(output, dim=1).cpu().numpy()[0]
 
 
 
 
102
 
103
- # βœ… Get result
104
  class_index = np.argmax(probabilities)
105
  result = CLASS_NAMES[class_index]
106
  confidence = float(probabilities[class_index])
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  # βœ… Save results to SQLite
109
  conn = sqlite3.connect(DB_PATH)
110
  cursor = conn.cursor()
111
  cursor.execute("""
112
  INSERT INTO results (image_filename, prediction, confidence, gradcam_filename, timestamp)
113
  VALUES (?, ?, ?, ?, ?)
114
- """, (uploaded_filename, result, confidence, '', datetime.now().isoformat()))
115
  conn.commit()
116
  conn.close()
117
 
@@ -121,13 +179,14 @@ def predict():
121
  'normal_probability': float(probabilities[0]),
122
  'early_glaucoma_probability': float(probabilities[1]),
123
  'advanced_glaucoma_probability': float(probabilities[2]),
124
- 'gradcam_image': '', # Not used for now
125
  'image_filename': uploaded_filename
126
  })
127
 
128
  except Exception as e:
129
  return jsonify({'error': str(e)}), 500
130
 
 
131
  @app.route('/results', methods=['GET'])
132
  def results():
133
  """List all results from the SQLite database."""
@@ -150,15 +209,17 @@ def results():
150
 
151
  return jsonify(results_list)
152
 
 
153
  @app.route('/gradcam/<filename>')
154
  def get_gradcam(filename):
155
- """Serve the Grad-CAM overlay image (no-op for now)."""
156
  filepath = os.path.join(OUTPUT_DIR, filename)
157
  if os.path.exists(filepath):
158
  return send_file(filepath, mimetype='image/png')
159
  else:
160
  return jsonify({'error': 'File not found'}), 404
161
 
 
162
  @app.route('/image/<filename>')
163
  def get_image(filename):
164
  """Serve the original uploaded image."""
@@ -168,5 +229,6 @@ def get_image(filename):
168
  else:
169
  return jsonify({'error': 'File not found'}), 404
170
 
 
171
  if __name__ == '__main__':
172
  app.run(host='0.0.0.0', port=7860)
 
9
  import sqlite3
10
  import torch.nn as nn
11
  import torchvision.models as models
12
+ import cv2
13
 
14
  app = Flask(__name__)
15
 
 
47
  # βœ… Instantiate the model
48
  model = get_model_with_attention('densenet169', num_classes=3) # Will have GLAM
49
  model.load_state_dict(torch.load('densenet169_seed40_best.pt', map_location='cpu'))
 
50
  model.eval()
51
 
52
  # βœ… Class Names
53
  CLASS_NAMES = ["Advanced", "Early", "Normal"]
54
 
 
55
  # βœ… Transformation for input images
56
  transform = transforms.Compose([
57
  transforms.Resize(256),
 
60
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
61
  std=[0.229, 0.224, 0.225]),
62
  ])
63
+
64
+ # =========================
65
+ # GRAD-CAM IMPLEMENTATION
66
+ # =========================
67
+ class GradCAM:
68
+ """Grad-CAM for the target layer."""
69
+ def __init__(self, model, target_layer_name):
70
+ self.model = model
71
+ self.target_layer_name = target_layer_name
72
+ self.activations = None
73
+ self.gradients = None
74
+ self._register_hooks()
75
+
76
+ def _register_hooks(self):
77
+ """Register forward and backward hooks."""
78
+ for name, module in self.model.named_modules():
79
+ if name == self.target_layer_name:
80
+ module.register_forward_hook(self._forward_hook)
81
+ module.register_full_backward_hook(self._backward_hook)
82
+
83
+ def _forward_hook(self, module, input, output):
84
+ """Save activations."""
85
+ self.activations = output
86
+
87
+ def _backward_hook(self, module, grad_in, grad_out):
88
+ """Save gradients."""
89
+ self.gradients = grad_out[0]
90
+
91
+ def generate(self, class_index):
92
+ """Generate the Grad-CAM."""
93
+ if self.activations is None or self.gradients is None:
94
+ raise ValueError("Must run forward and backward passes first.")
95
+ weights = self.gradients.mean(dim=(2, 3), keepdim=True)
96
+ cam = (weights * self.activations).sum(dim=1, keepdim=True)
97
+ cam = F.relu(cam)
98
+ cam = cam.squeeze().cpu().numpy()
99
+ cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
100
+ return cam
101
+
102
+
103
  @app.route('/')
104
  def home():
105
  """Check that the API is working."""
 
114
  else:
115
  return "❌ Model file NOT found."
116
 
117
+
118
  @app.route('/predict', methods=['POST'])
119
  def predict():
120
+ """Perform prediction and save results (including Grad-CAM) to the database."""
121
  if 'file' not in request.files:
122
  return jsonify({'error': 'No file uploaded'}), 400
123
 
 
136
  img = Image.open(uploaded_file_path).convert('RGB')
137
  input_tensor = transform(img).unsqueeze(0)
138
 
139
+ # Grad-CAM setup
140
+ target_layer_name = "features.2.global_spatial_conv"
141
+ gradcam = GradCAM(model, target_layer_name)
142
+
143
+ # Forward pass
144
+ input_tensor.requires_grad = True
145
+ output = model(input_tensor)
146
 
147
+ probabilities = F.softmax(output, dim=1).cpu().numpy()[0]
148
  class_index = np.argmax(probabilities)
149
  result = CLASS_NAMES[class_index]
150
  confidence = float(probabilities[class_index])
151
 
152
+ # Backward pass for Grad-CAM
153
+ model.zero_grad()
154
+ output[0, class_index].backward()
155
+ cam = gradcam.generate(class_index)
156
+
157
+ # βœ… Create overlay
158
+ original_img = np.asarray(img.resize((224, 224)))
159
+ heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
160
+ overlay = cv2.addWeighted(original_img, 0.6, heatmap, 0.4, 0)
161
+
162
+ gradcam_filename = f"gradcam_{timestamp}.png"
163
+ gradcam_file_path = os.path.join(OUTPUT_DIR, gradcam_filename)
164
+ cv2.imwrite(gradcam_file_path, overlay)
165
+
166
  # βœ… Save results to SQLite
167
  conn = sqlite3.connect(DB_PATH)
168
  cursor = conn.cursor()
169
  cursor.execute("""
170
  INSERT INTO results (image_filename, prediction, confidence, gradcam_filename, timestamp)
171
  VALUES (?, ?, ?, ?, ?)
172
+ """, (uploaded_filename, result, confidence, gradcam_filename, datetime.now().isoformat()))
173
  conn.commit()
174
  conn.close()
175
 
 
179
  'normal_probability': float(probabilities[0]),
180
  'early_glaucoma_probability': float(probabilities[1]),
181
  'advanced_glaucoma_probability': float(probabilities[2]),
182
+ 'gradcam_image': gradcam_filename,
183
  'image_filename': uploaded_filename
184
  })
185
 
186
  except Exception as e:
187
  return jsonify({'error': str(e)}), 500
188
 
189
+
190
  @app.route('/results', methods=['GET'])
191
  def results():
192
  """List all results from the SQLite database."""
 
209
 
210
  return jsonify(results_list)
211
 
212
+
213
  @app.route('/gradcam/<filename>')
214
  def get_gradcam(filename):
215
+ """Serve the Grad-CAM overlay image."""
216
  filepath = os.path.join(OUTPUT_DIR, filename)
217
  if os.path.exists(filepath):
218
  return send_file(filepath, mimetype='image/png')
219
  else:
220
  return jsonify({'error': 'File not found'}), 404
221
 
222
+
223
  @app.route('/image/<filename>')
224
  def get_image(filename):
225
  """Serve the original uploaded image."""
 
229
  else:
230
  return jsonify({'error': 'File not found'}), 404
231
 
232
+
233
  if __name__ == '__main__':
234
  app.run(host='0.0.0.0', port=7860)