dhruv2842 commited on
Commit
2b314ce
Β·
verified Β·
1 Parent(s): d5b9446

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -29
app.py CHANGED
@@ -9,23 +9,18 @@ from datetime import datetime
9
  import sqlite3
10
  import torch.nn as nn
11
  import cv2
 
12
 
13
- # βœ… New Grad-CAM++ imports
14
  from pytorch_grad_cam import GradCAMPlusPlus
15
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
16
  from pytorch_grad_cam.utils.image import show_cam_on_image
17
 
18
- # βœ… Import Hugging Face-style GLAM EfficientNet model
19
- from glam_efficientnet_model import GLAMEfficientNetForClassification, GLAMEfficientNetConfig
20
- from glam_module import GLAM
21
- from swin_module import SwinWindowAttention
22
-
23
  app = Flask(__name__)
24
 
25
- # βœ… Directory and database path
26
  OUTPUT_DIR = '/tmp/results'
27
- if not os.path.exists(OUTPUT_DIR):
28
- os.makedirs(OUTPUT_DIR)
29
 
30
  DB_PATH = os.path.join(OUTPUT_DIR, 'results.db')
31
 
@@ -51,18 +46,30 @@ def init_db():
51
 
52
  init_db()
53
 
54
- # βœ… Load GLAM EfficientNet Model
55
- config = GLAMEfficientNetConfig()
56
- model = GLAMEfficientNetForClassification(
57
- config, glam_module_cls=GLAM, swin_module_cls=SwinWindowAttention
 
 
 
 
 
 
 
 
 
 
58
  )
59
- model.load_state_dict(torch.load('efficientnet_glam_best.pt', map_location='cpu'))
 
 
60
  model.eval()
61
 
62
  # βœ… Class Names
63
  CLASS_NAMES = ["Advanced", "Early", "Normal"]
64
 
65
- # βœ… Transformation for input images
66
  transform = transforms.Compose([
67
  transforms.Resize(256),
68
  transforms.CenterCrop(224),
@@ -74,27 +81,26 @@ transform = transforms.Compose([
74
  @app.route('/')
75
  def home():
76
  """Check that the API is working."""
77
- return "Glaucoma Detection Flask API (GLAM EfficientNet) is running!"
78
 
79
  @app.route("/test_file")
80
  def test_file():
81
  """Check if the .pt model file is present and readable."""
82
- filepath = "efficientnet_glam_best.pt"
83
  if os.path.exists(filepath):
84
  return f"βœ… Model file found at: {filepath}"
85
  else:
86
  return "❌ Model file NOT found."
87
 
88
-
89
  @app.route('/predict', methods=['POST'])
90
  def predict():
91
  """Perform prediction and save results (including Grad-CAM++) to the database."""
92
  if 'file' not in request.files:
93
- return jsonify({'error': 'No file uploaded'}), 400
94
 
95
  uploaded_file = request.files['file']
96
  if uploaded_file.filename == '':
97
- return jsonify({'error': 'No file selected'}), 400
98
 
99
  try:
100
  # βœ… Save the uploaded image
@@ -107,20 +113,20 @@ def predict():
107
  img = Image.open(uploaded_file_path).convert('RGB')
108
  input_tensor = transform(img).unsqueeze(0)
109
 
110
- # βœ… Get prediction
111
- output = model(input_tensor) # Dict with "logits"
112
- probabilities = F.softmax(output["logits"], dim=1).cpu().detach().numpy()[0]
 
113
  class_index = np.argmax(probabilities)
114
  result = CLASS_NAMES[class_index]
115
  confidence = float(probabilities[class_index])
116
 
117
  # βœ… Grad-CAM++ setup
118
- # Target the final convolutional output. In GLAM EfficientNet, this is `model.features`
119
- target_layer = dict(model.features.named_modules())["features.7"] # βœ… Adjust as needed
120
  cam_model = GradCAMPlusPlus(model=model, target_layers=[target_layer])
121
 
122
- # βœ… Get Grad-CAM++ map
123
- cam_output = cam_model(input_tensor=input_tensor, targets=[ClassifierOutputTarget(class_index)])[0]
124
 
125
  # βœ… Create RGB overlay
126
  original_img = np.asarray(img.resize((224, 224)), dtype=np.float32) / 255.0
@@ -196,7 +202,7 @@ def get_gradcam(filename):
196
  if os.path.exists(filepath):
197
  return send_file(filepath, mimetype='image/png')
198
  else:
199
- return jsonify({'error': 'File not found'}), 404
200
 
201
 
202
  @app.route('/image/<filename>')
@@ -206,8 +212,9 @@ def get_image(filename):
206
  if os.path.exists(filepath):
207
  return send_file(filepath, mimetype='image/png')
208
  else:
209
- return jsonify({'error': 'File not found'}), 404
210
 
211
 
212
  if __name__ == '__main__':
213
  app.run(host='0.0.0.0', port=7860)
 
 
9
  import sqlite3
10
  import torch.nn as nn
11
  import cv2
12
+ import json
13
 
14
+ # Grad-CAM++ imports
15
  from pytorch_grad_cam import GradCAMPlusPlus
16
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
17
  from pytorch_grad_cam.utils.image import show_cam_on_image
18
 
 
 
 
 
 
19
  app = Flask(__name__)
20
 
21
+ # βœ… Directory and database
22
  OUTPUT_DIR = '/tmp/results'
23
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
 
24
 
25
  DB_PATH = os.path.join(OUTPUT_DIR, 'results.db')
26
 
 
46
 
47
  init_db()
48
 
49
+
50
+ # βœ… Import your EfficientNetB0_TransformerGLAM model
51
+ from efficientnet_transformer_glam import EfficientNetb0_TransformerGLAM # Ensure this is in the path
52
+
53
+
54
+ # βœ… Instantiate the model
55
+ model = EfficientNetb0_TransformerGLAM(
56
+ num_classes=3,
57
+ embed_dim=512,
58
+ num_heads=8,
59
+ mlp_dim=512,
60
+ dropout=0.5,
61
+ window_size=7,
62
+ reduction_ratio=8
63
  )
64
+
65
+ # βœ… Load the trained checkpoint
66
+ model.load_state_dict(torch.load('densenet169_seed40_best.pt', map_location='cpu'))
67
  model.eval()
68
 
69
  # βœ… Class Names
70
  CLASS_NAMES = ["Advanced", "Early", "Normal"]
71
 
72
+ # βœ… Transforms
73
  transform = transforms.Compose([
74
  transforms.Resize(256),
75
  transforms.CenterCrop(224),
 
81
  @app.route('/')
82
  def home():
83
  """Check that the API is working."""
84
+ return "Glaucoma Detection Flask API (EfficientNetB0_TransformerGLAM Model) is running!"
85
 
86
  @app.route("/test_file")
87
  def test_file():
88
  """Check if the .pt model file is present and readable."""
89
+ filepath = "densenet169_seed40_best.pt"
90
  if os.path.exists(filepath):
91
  return f"βœ… Model file found at: {filepath}"
92
  else:
93
  return "❌ Model file NOT found."
94
 
 
95
  @app.route('/predict', methods=['POST'])
96
  def predict():
97
  """Perform prediction and save results (including Grad-CAM++) to the database."""
98
  if 'file' not in request.files:
99
+ return jsonify({'error': 'No file uploaded.'}), 400
100
 
101
  uploaded_file = request.files['file']
102
  if uploaded_file.filename == '':
103
+ return jsonify({'error': 'No file selected.'}), 400
104
 
105
  try:
106
  # βœ… Save the uploaded image
 
113
  img = Image.open(uploaded_file_path).convert('RGB')
114
  input_tensor = transform(img).unsqueeze(0)
115
 
116
+ # Model Inference
117
+ with torch.no_grad():
118
+ output = model(input_tensor)
119
+ probabilities = F.softmax(output, dim=1).cpu().numpy()[0]
120
  class_index = np.argmax(probabilities)
121
  result = CLASS_NAMES[class_index]
122
  confidence = float(probabilities[class_index])
123
 
124
  # βœ… Grad-CAM++ setup
125
+ target_layer = model.feature_extractor[-1] # Final block of EfficientNet feature extractor
 
126
  cam_model = GradCAMPlusPlus(model=model, target_layers=[target_layer])
127
 
128
+ cam_output = cam_model(input_tensor=input_tensor,
129
+ targets=[ClassifierOutputTarget(class_index)])[0]
130
 
131
  # βœ… Create RGB overlay
132
  original_img = np.asarray(img.resize((224, 224)), dtype=np.float32) / 255.0
 
202
  if os.path.exists(filepath):
203
  return send_file(filepath, mimetype='image/png')
204
  else:
205
+ return jsonify({'error': 'File not found.'}), 404
206
 
207
 
208
  @app.route('/image/<filename>')
 
212
  if os.path.exists(filepath):
213
  return send_file(filepath, mimetype='image/png')
214
  else:
215
+ return jsonify({'error': 'File not found.'}), 404
216
 
217
 
218
  if __name__ == '__main__':
219
  app.run(host='0.0.0.0', port=7860)
220
+