dhruv2842 committed
Commit 2d20fbb · verified · 1 Parent(s): 1c0af0c

Upload 5 files

Files changed (5)
  1. Dockerfile +24 -0
  2. app.py +212 -0
  3. efficientnet_glam_best.pt +3 -0
  4. glam_efficient_model.py +103 -0
  5. requirements.txt +10 -0
Dockerfile ADDED
@@ -0,0 +1,24 @@
+ FROM python:3.9-slim
+
+ # 2️⃣ Set working directory
+ WORKDIR /app
+
+ # 3️⃣ Install required system dependencies (fixes libGL and libgthread errors)
+ RUN apt-get update && \
+     apt-get install -y libgl1-mesa-glx libglib2.0-0 && \
+     rm -rf /var/lib/apt/lists/*
+
+ # 4️⃣ Copy requirements
+ COPY requirements.txt .
+
+ # 5️⃣ Install Python dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # 6️⃣ Copy all files from the root of the project
+ COPY . .
+
+ # 7️⃣ Expose the port
+ EXPOSE 7860
+
+ # 8️⃣ Command to run the app
+ CMD ["python", "app.py"]
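
A minimal way to build and run this image locally (the tag glam-api is an illustrative choice, not part of the upload):

docker build -t glam-api .
docker run -p 7860:7860 glam-api

Port 7860 matches the EXPOSE line above and the app.run(...) call in app.py.
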
app.py ADDED
@@ -0,0 +1,212 @@
+ from flask import Flask, request, jsonify, send_file
+ from PIL import Image
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision import transforms
+ import os
+ import numpy as np
+ from datetime import datetime
+ import sqlite3
+ import cv2
+
+ # ✅ Grad-CAM++ imports
+ from pytorch_grad_cam import GradCAMPlusPlus
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+
+ # ✅ Import the Hugging Face-style GLAM EfficientNet model
+ # (the module name must match the uploaded file glam_efficient_model.py)
+ from glam_efficient_model import GLAMEfficientNetForClassification, GLAMEfficientNetConfig
+
+ app = Flask(__name__)
+
+ # ✅ Directory and database path
+ OUTPUT_DIR = '/tmp/results'
+ if not os.path.exists(OUTPUT_DIR):
+     os.makedirs(OUTPUT_DIR)
+
+ DB_PATH = os.path.join(OUTPUT_DIR, 'results.db')
+
+
+ def init_db():
+     """Initialize the SQLite database for storing results."""
+     conn = sqlite3.connect(DB_PATH)
+     cursor = conn.cursor()
+     cursor.execute("""
+         CREATE TABLE IF NOT EXISTS results (
+             id INTEGER PRIMARY KEY AUTOINCREMENT,
+             image_filename TEXT,
+             prediction TEXT,
+             confidence REAL,
+             gradcam_filename TEXT,
+             gradcam_gray_filename TEXT,
+             timestamp TEXT
+         )
+     """)
+     conn.commit()
+     conn.close()
+
+
+ init_db()
+
+ # ✅ Load the GLAM EfficientNet model
+ config = GLAMEfficientNetConfig()
+ model = GLAMEfficientNetForClassification(config)
+ model.load_state_dict(torch.load('efficientnet_glam_best.pt', map_location='cpu'))
+ model.eval()
+
+
+ class LogitsOnly(nn.Module):
+     """Wrap the model so Grad-CAM sees a plain logits tensor instead of a dict."""
+     def __init__(self, base):
+         super().__init__()
+         self.base = base
+
+     def forward(self, x):
+         return self.base(x)["logits"]
+
+
+ # ✅ Class names (index order must match the training label order)
+ CLASS_NAMES = ["Advanced", "Early", "Normal"]
+
+ # ✅ Transformation for input images
+ transform = transforms.Compose([
+     transforms.Resize(256),
+     transforms.CenterCrop(224),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                          std=[0.229, 0.224, 0.225]),
+ ])
+
+
+ @app.route('/')
+ def home():
+     """Check that the API is running."""
+     return "Glaucoma Detection Flask API (EfficientNet + GLAM) is running!"
+
+
+ @app.route("/test_file")
+ def test_file():
+     """Check whether the .pt model file is present."""
+     filepath = "efficientnet_glam_best.pt"
+     if os.path.exists(filepath):
+         return f"✅ Model file found at: {filepath}"
+     else:
+         return "❌ Model file NOT found."
+
+
+ @app.route('/predict', methods=['POST'])
+ def predict():
+     """Run prediction, generate Grad-CAM++ maps, and save the results to the database."""
+     if 'file' not in request.files:
+         return jsonify({'error': 'No file uploaded'}), 400
+
+     uploaded_file = request.files['file']
+     if uploaded_file.filename == '':
+         return jsonify({'error': 'No file selected'}), 400
+
+     try:
+         # ✅ Save the uploaded image
+         timestamp = int(datetime.now().timestamp())
+         uploaded_filename = f"uploaded_{timestamp}.png"
+         uploaded_file_path = os.path.join(OUTPUT_DIR, uploaded_filename)
+         uploaded_file.save(uploaded_file_path)
+
+         # ✅ Preprocess the image
+         img = Image.open(uploaded_file_path).convert('RGB')
+         input_tensor = transform(img).unsqueeze(0)
+
+         # ✅ Get prediction (the model returns a dict with "logits")
+         output = model(input_tensor)
+         probabilities = F.softmax(output["logits"], dim=1).cpu().detach().numpy()[0]
+         class_index = int(np.argmax(probabilities))
+         result = CLASS_NAMES[class_index]
+         confidence = float(probabilities[class_index])
+
+         # ✅ Grad-CAM++ setup
+         # IMPORTANT: the target layer must be a conv layer of the GLAM
+         # EfficientNet model; the 1x1 projection conv that feeds both
+         # branches is used here.
+         target_layer = model.conv1x1
+         cam_model = GradCAMPlusPlus(model=LogitsOnly(model), target_layers=[target_layer])
+
+         # ✅ Get the Grad-CAM++ map for the predicted class
+         cam_output = cam_model(input_tensor=input_tensor,
+                                targets=[ClassifierOutputTarget(class_index)])[0]
+
+         # ✅ Create the RGB overlay
+         original_img = np.asarray(img.resize((224, 224)), dtype=np.float32) / 255.0
+         overlay = show_cam_on_image(original_img, cam_output, use_rgb=True)
+
+         # ✅ Create the grayscale version
+         cam_normalized = np.uint8(255 * cam_output)
+
+         # ✅ Save the overlay
+         gradcam_filename = f"gradcam_{timestamp}.png"
+         gradcam_file_path = os.path.join(OUTPUT_DIR, gradcam_filename)
+         cv2.imwrite(gradcam_file_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
+
+         # ✅ Save the grayscale map
+         gray_filename = f"gradcam_gray_{timestamp}.png"
+         gray_file_path = os.path.join(OUTPUT_DIR, gray_filename)
+         cv2.imwrite(gray_file_path, cam_normalized)
+
+         # ✅ Save results to the database
+         conn = sqlite3.connect(DB_PATH)
+         cursor = conn.cursor()
+         cursor.execute("""
+             INSERT INTO results (image_filename, prediction, confidence, gradcam_filename, gradcam_gray_filename, timestamp)
+             VALUES (?, ?, ?, ?, ?, ?)
+         """, (uploaded_filename, result, confidence, gradcam_filename, gray_filename, datetime.now().isoformat()))
+         conn.commit()
+         conn.close()
+
+         # ✅ Return results (probability keys follow the CLASS_NAMES order)
+         return jsonify({
+             'prediction': result,
+             'confidence': confidence,
+             'advanced_glaucoma_probability': float(probabilities[0]),
+             'early_glaucoma_probability': float(probabilities[1]),
+             'normal_probability': float(probabilities[2]),
+             'gradcam_image': gradcam_filename,
+             'gradcam_gray_image': gray_filename,
+             'image_filename': uploaded_filename
+         })
+
+     except Exception as e:
+         return jsonify({'error': str(e)}), 500
+
+
+ @app.route('/results', methods=['GET'])
+ def results():
+     """List all results from the SQLite database."""
+     conn = sqlite3.connect(DB_PATH)
+     cursor = conn.cursor()
+     cursor.execute("SELECT * FROM results ORDER BY timestamp DESC")
+     results_data = cursor.fetchall()
+     conn.close()
+
+     results_list = []
+     for record in results_data:
+         results_list.append({
+             'id': record[0],
+             'image_filename': record[1],
+             'prediction': record[2],
+             'confidence': record[3],
+             'gradcam_filename': record[4],
+             'gradcam_gray_filename': record[5],
+             'timestamp': record[6]
+         })
+
+     return jsonify(results_list)
+
+
+ @app.route('/gradcam/<filename>')
+ def get_gradcam(filename):
+     """Serve a Grad-CAM image (overlay or grayscale)."""
+     filepath = os.path.join(OUTPUT_DIR, filename)
+     if os.path.exists(filepath):
+         return send_file(filepath, mimetype='image/png')
+     else:
+         return jsonify({'error': 'File not found'}), 404
+
+
+ @app.route('/image/<filename>')
+ def get_image(filename):
+     """Serve the original uploaded image."""
+     filepath = os.path.join(OUTPUT_DIR, filename)
+     if os.path.exists(filepath):
+         return send_file(filepath, mimetype='image/png')
+     else:
+         return jsonify({'error': 'File not found'}), 404
+
+
+ if __name__ == '__main__':
+     app.run(host='0.0.0.0', port=7860)
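
Once the server is up, the endpoints can be exercised with a short client script. A minimal sketch using the requests library (client-side only, so not in requirements.txt; fundus.png is a placeholder filename):

import requests

BASE = "http://localhost:7860"

# POST an image to /predict; the form field name must be "file"
with open("fundus.png", "rb") as f:  # placeholder test image
    resp = requests.post(f"{BASE}/predict", files={"file": f})
result = resp.json()
print(result["prediction"], result["confidence"])

# Fetch the Grad-CAM overlay saved for this prediction
overlay = requests.get(f"{BASE}/gradcam/{result['gradcam_image']}")
with open("overlay.png", "wb") as out:
    out.write(overlay.content)

The /results endpoint returns the full prediction history as JSON.
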
efficientnet_glam_best.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bcdc2e2bc5aef943b6658e2e2e1fd62a856d860aef97e7f2bdc2ca3b03a8fe5b
+ size 45758832
glam_efficient_model.py ADDED
@@ -0,0 +1,103 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel, PretrainedConfig
+ from transformers import EfficientNetModel
+
+
+ # --------------------------------------------------
+ # GLAM and SwinWindowAttention are used below but are not part of this
+ # upload. They must be importable as top-level modules (app.py imports
+ # this file as a top-level module, so relative imports will not work):
+ # from glam_module import GLAM
+ # from swin_module import SwinWindowAttention
+ # A minimal stand-in sketch follows this file.
+ # --------------------------------------------------
+
+
+ class GLAMEfficientNetConfig(PretrainedConfig):
+     """Hugging Face-style configuration for GLAM EfficientNet."""
+     model_type = "glam_efficientnet"
+
+     def __init__(self,
+                  num_classes: int = 3,
+                  embed_dim: int = 512,
+                  num_heads: int = 8,
+                  window_size: int = 7,
+                  reduction_ratio: int = 8,
+                  dropout: float = 0.5,
+                  **kwargs):
+         super().__init__(**kwargs)
+         self.num_classes = num_classes
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.window_size = window_size
+         self.reduction_ratio = reduction_ratio
+         self.dropout = dropout
+
+
+ class GLAMEfficientNetForClassification(PreTrainedModel):
+     """Hugging Face-style model for the EfficientNet + GLAM + Swin architecture."""
+
+     config_class = GLAMEfficientNetConfig
+
+     def __init__(self, config: GLAMEfficientNetConfig):
+         super().__init__(config)
+
+         # 1) EfficientNet backbone. transformers' EfficientNetModel exposes no
+         #    .features attribute, so the full model is kept here and its
+         #    last_hidden_state (B, 1280, 7, 7 for 224x224 inputs) is read in forward().
+         self.features = EfficientNetModel.from_pretrained("google/efficientnet-b0")
+         self.conv1x1 = nn.Conv2d(1280, config.embed_dim, kernel_size=1)
+
+         # 2) Swin attention block
+         self.swin_attn = SwinWindowAttention(
+             embed_dim=config.embed_dim,
+             window_size=config.window_size,
+             num_heads=config.num_heads,
+             dropout=config.dropout
+         )
+         self.pre_attn_norm = nn.LayerNorm(config.embed_dim)
+         self.post_attn_norm = nn.LayerNorm(config.embed_dim)
+
+         # 3) GLAM block
+         self.glam = GLAM(in_channels=config.embed_dim, reduction_ratio=config.reduction_ratio)
+
+         # 4) Self-adaptive gating
+         self.gate_fc = nn.Linear(config.embed_dim, 1)
+
+         # Final classification head
+         self.dropout = nn.Dropout(config.dropout)
+         self.classifier = nn.Linear(config.embed_dim, config.num_classes)
+
+     def forward(self, pixel_values, labels=None, **kwargs):
+         # 1) Extract EfficientNet features and project to embed_dim channels
+         feats = self.features(pixel_values).last_hidden_state
+         feats = self.conv1x1(feats)
+
+         B, C, H, W = feats.shape
+
+         # 2) Transformer branch: channel-wise LayerNorm, then windowed attention
+         x_perm = feats.permute(0, 2, 3, 1).contiguous()
+         x_norm = self.pre_attn_norm(x_perm).permute(0, 3, 1, 2).contiguous()
+         x_norm = self.dropout(x_norm)
+
+         T_out = self.swin_attn(x_norm)
+
+         T_out = self.post_attn_norm(T_out.permute(0, 2, 3, 1).contiguous())
+         T_out = T_out.permute(0, 3, 1, 2).contiguous()
+
+         # 3) GLAM branch
+         G_out = self.glam(feats)
+
+         # 4) Self-adaptive gating: a per-image scalar gate blends the two branches
+         gap_feats = F.adaptive_avg_pool2d(feats, (1, 1)).view(B, C)
+         g = torch.sigmoid(self.gate_fc(gap_feats)).view(B, 1, 1, 1)
+
+         F_out = g * T_out + (1 - g) * G_out
+
+         # 5) Final pooling + classifier
+         pooled = F.adaptive_avg_pool2d(F_out, (1, 1)).view(B, -1)
+         logits = self.classifier(self.dropout(pooled))
+
+         loss = None
+         if labels is not None:
+             loss = F.cross_entropy(logits, labels)
+
+         return {"loss": loss, "logits": logits}
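
The GLAM and SwinWindowAttention modules referenced above are not in this upload. Below is a minimal stand-in sketch, inferred only from the constructor arguments and the (B, C, H, W)-in / (B, C, H, W)-out shapes at the call sites; both classes are hypothetical placeholders, and the authors' real modules are required for efficientnet_glam_best.pt to load correctly:

import torch
import torch.nn as nn
import torch.nn.functional as F


class SwinWindowAttention(nn.Module):
    """Hypothetical stand-in: multi-head self-attention inside non-overlapping
    windows, matching the call signature in glam_efficient_model.py."""
    def __init__(self, embed_dim, window_size, num_heads, dropout=0.0):
        super().__init__()
        self.window_size = window_size
        self.attn = nn.MultiheadAttention(embed_dim, num_heads,
                                          dropout=dropout, batch_first=True)

    def forward(self, x):  # x: (B, C, H, W), H and W divisible by window_size
        B, C, H, W = x.shape
        ws = self.window_size
        # Partition the feature map into (ws x ws) windows of C-dim tokens.
        x = x.view(B, C, H // ws, ws, W // ws, ws)
        x = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, ws * ws, C)
        out, _ = self.attn(x, x, x)  # self-attention within each window
        # Reverse the window partition back to (B, C, H, W).
        out = out.reshape(B, H // ws, W // ws, ws, ws, C)
        out = out.permute(0, 5, 1, 3, 2, 4).reshape(B, C, H, W)
        return out


class GLAM(nn.Module):
    """Hypothetical stand-in: squeeze-and-excitation-style channel gating as a
    placeholder for the Global-Local Attention Module."""
    def __init__(self, in_channels, reduction_ratio=8):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction_ratio),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction_ratio, in_channels),
            nn.Sigmoid(),
        )

    def forward(self, x):  # x: (B, C, H, W)
        B, C, _, _ = x.shape
        w = self.fc(F.adaptive_avg_pool2d(x, 1).view(B, C)).view(B, C, 1, 1)
        return x * w  # channel-reweighted features

With window_size=7 and 224x224 inputs, the backbone's 7x7 feature map forms a single attention window.
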
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ Flask
+ torch
+ torchvision
+ Pillow
+ numpy
+ opencv-python
+ transformers
+ firebase-admin
+ psycopg2-binary
+ grad-cam