# Spaces: Sleeping
import os
import sqlite3
import threading
import uuid
from contextlib import closing
from datetime import datetime

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from flask import Flask, request, jsonify, send_file
from PIL import Image, UnidentifiedImageError
from torchvision import transforms
app = Flask(__name__)

# Directory holding uploaded images, Grad-CAM outputs and the SQLite results
# database. NOTE(review): /tmp is presumably chosen because it is writable in
# the hosting container; contents are likely lost on restart — confirm.
OUTPUT_DIR = '/tmp/results'
# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(OUTPUT_DIR, exist_ok=True)
DB_PATH = os.path.join(OUTPUT_DIR, 'results.db')
def init_db():
    """Create the ``results`` table in DB_PATH if it does not already exist.

    Idempotent (``CREATE TABLE IF NOT EXISTS``). The connection is always
    closed, even when the statement fails.
    """
    with closing(sqlite3.connect(DB_PATH)) as conn:
        # The connection's own context manager commits on success and rolls
        # back on error; it does NOT close the connection, hence closing().
        with conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS results (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    image_filename TEXT,
                    prediction TEXT,
                    confidence REAL,
                    gradcam_filename TEXT,
                    timestamp TEXT
                )
            """)
init_db()

# Custom model factory from the project-local densenet_withglam module
# (the GLAM-augmented DenseNet).
from densenet_withglam import get_model_with_attention

# DenseNet-169 with GLAM attention and a 3-class head. The checkpoint is mapped
# onto the CPU, and the network is put in evaluation mode via model.eval().
# Autograd is left enabled because predict() runs a backward pass for Grad-CAM.
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
model = get_model_with_attention('densenet169', num_classes=3)
model.load_state_dict(
    torch.load('densenet169_seed40_best.pt', map_location='cpu')
)
model.eval()

# Model output index i is reported as CLASS_NAMES[i].
CLASS_NAMES = ["Advanced", "Early", "Normal"]

# Per-request preprocessing: shorter side resized to 256 px, central 224x224
# crop, conversion to a tensor, then ImageNet mean/std normalisation.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# =========================
# GRAD-CAM IMPLEMENTATION
# =========================
class GradCAM:
    """Grad-CAM heatmap for one named layer of a model.

    Hooks on the target layer record its forward output (activations) and the
    gradient with respect to that output. After a forward pass and a backward
    pass from a class score, ``generate()`` does the following:

    - weights each activation channel by its spatially averaged gradient;
    - sums over channels and applies ReLU;
    - min-max normalises the result to [0, 1].

    The hooks remain attached to ``model`` until ``remove()`` is called, or
    until the instance exits a ``with`` block. Reuse one instance per model
    rather than creating one per forward pass, otherwise hooks accumulate.
    """

    def __init__(self, model, target_layer_name):
        self.model = model
        self.target_layer_name = target_layer_name
        self.activations = None  # latest forward output of the target layer
        self.gradients = None    # latest gradient w.r.t. that output
        self._handles = []       # RemovableHandles for the registered hooks
        self._register_hooks()

    def _register_hooks(self):
        """Attach forward and backward hooks to the module named target_layer_name.

        Raises:
            ValueError: if the model has no submodule with that name.
        """
        layer = dict(self.model.named_modules()).get(self.target_layer_name)
        if layer is None:
            raise ValueError(
                f"Layer {self.target_layer_name!r} not found in model"
            )
        self._handles = [
            layer.register_forward_hook(self._forward_hook),
            layer.register_full_backward_hook(self._backward_hook),
        ]

    def remove(self):
        """Detach this instance's hooks from the model; safe to call repeatedly."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove()
        return False

    def _forward_hook(self, module, input, output):
        """Save activations."""
        self.activations = output

    def _backward_hook(self, module, grad_in, grad_out):
        """Save gradients (w.r.t. the layer output)."""
        self.gradients = grad_out[0]

    def generate(self, class_index):
        """Generate the Grad-CAM map from the last forward/backward pass.

        Args:
            class_index: Class whose score was back-propagated. It is not used
                in the computation: the map reflects whichever backward pass
                ran last. The argument is kept for API compatibility.

        Returns:
            A NumPy array with size-1 dimensions squeezed out and values in
            [0, 1]. It is 2-D for a single image.

        Raises:
            ValueError: if no activations/gradients have been recorded.
        """
        if self.activations is None or self.gradients is None:
            raise ValueError("Must run forward and backward passes first.")
        with torch.no_grad():
            # Channel importance = gradient averaged over the spatial dims
            # (dims 2, 3 of an N x C x H x W tensor).
            weights = self.gradients.mean(dim=(2, 3), keepdim=True)
            cam = F.relu((weights * self.activations).sum(dim=1, keepdim=True))
        cam = cam.squeeze().detach().cpu().numpy()
        # The epsilon avoids division by zero when the map is constant.
        return (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
def home():
    """Liveness check: reply with a fixed plain-text banner.

    NOTE(review): no route decorator is visible on this view in this copy of
    the file — confirm how it is registered with `app`.
    """
    banner = "Glaucoma Detection Flask API (3-Class Model) is running!"
    return banner
def test_file():
    """Report whether the model checkpoint exists on disk.

    Checks 'densenet169_seed40_best.pt', the same path the module-level
    torch.load call reads, so the probe reflects the file the app actually
    loads. The path is relative, so it is resolved against the process's
    working directory.

    NOTE(review): no route decorator is visible on this view in this copy of
    the file — confirm how it is registered with `app`.
    """
    filepath = "densenet169_seed40_best.pt"
    if os.path.exists(filepath):
        return f"β Model file found at: {filepath}"
    return "β Model file NOT found."
# Grad-CAM hooks record state on the shared global `model`. Every request
# therefore reuses one lazily created GradCAM instance, so hooks are attached
# exactly once. Inference runs under one lock, because Flask's server handles
# requests in threads by default and concurrent forward/backward passes would
# overwrite each other's captured activations and gradients.
_GRADCAM_LAYER = "features.2.local_spatial_conv3"
_INFERENCE_LOCK = threading.Lock()
_gradcam = None


def _get_gradcam():
    """Return the shared GradCAM, hooking `model` on first use (hold the lock)."""
    global _gradcam
    if _gradcam is None:
        _gradcam = GradCAM(model, _GRADCAM_LAYER)
    return _gradcam


def _save_upload(uploaded_file, stem):
    """Decode the upload and store it in OUTPUT_DIR as a genuine PNG.

    Re-encoding makes the '.png' filename, and the image/png mimetype used when
    serving it back, true whatever format was uploaded.

    Returns:
        (filename, RGB PIL image).

    Raises:
        PIL.UnidentifiedImageError: if the upload is not a decodable image.
    """
    img = Image.open(uploaded_file.stream).convert('RGB')
    filename = f"uploaded_{stem}.png"
    img.save(os.path.join(OUTPUT_DIR, filename), format='PNG')
    return filename, img


def _predict_with_gradcam(img):
    """Classify `img`; return (probabilities, class_index, 224x224 uint8 CAM)."""
    input_tensor = transform(img).unsqueeze(0)
    with _INFERENCE_LOCK:
        gradcam = _get_gradcam()
        # Drop the previous request's capture so generate() cannot reuse it.
        gradcam.activations = None
        gradcam.gradients = None
        output = model(input_tensor)
        probabilities = F.softmax(output, dim=1).detach().cpu().numpy()[0]
        class_index = int(np.argmax(probabilities))
        # Back-propagate the winning class score so the hooks see gradients.
        model.zero_grad()
        output[0, class_index].backward()
        cam = gradcam.generate(class_index)
    if cam.ndim == 3:  # keep a single 2-D map
        cam = cam[0]
    cam = cv2.resize(np.uint8(255 * cam), (224, 224))
    return probabilities, class_index, cam


def _save_gradcam_images(img, cam, stem):
    """Write the colour overlay and grayscale CAM PNGs; return their filenames."""
    # Overlay on the same Resize(256) + CenterCrop(224) view the model was fed
    # (see `transform`), so the heatmap lines up with the pixels it explains.
    model_view = transforms.CenterCrop(224)(transforms.Resize(256)(img))
    # PIL yields RGB; OpenCV colormaps and imwrite use BGR. Convert first so
    # the saved overlay does not have red and blue swapped.
    base_bgr = cv2.cvtColor(np.asarray(model_view), cv2.COLOR_RGB2BGR)
    heatmap = cv2.applyColorMap(cam, cv2.COLORMAP_JET)
    overlay = cv2.addWeighted(base_bgr, 0.6, heatmap, 0.4, 0)
    gradcam_filename = f"gradcam_{stem}.png"
    gray_filename = f"gradcam_gray_{stem}.png"
    for name, pixels in ((gradcam_filename, overlay), (gray_filename, cam)):
        # cv2.imwrite signals failure by returning False, not by raising.
        if not cv2.imwrite(os.path.join(OUTPUT_DIR, name), pixels):
            raise OSError(f"Could not write {name}")
    return gradcam_filename, gray_filename


def _record_result(image_filename, prediction, confidence, gradcam_filename):
    """Insert one prediction row into the results table (connection always closed)."""
    with closing(sqlite3.connect(DB_PATH)) as conn:
        with conn:  # commit on success, roll back on error
            conn.execute(
                """
                INSERT INTO results (image_filename, prediction, confidence, gradcam_filename, timestamp)
                VALUES (?, ?, ?, ?, ?)
                """,
                (image_filename, prediction, confidence, gradcam_filename,
                 datetime.now().isoformat()),
            )


def predict():
    """Classify an uploaded image and save a Grad-CAM explanation.

    Expects a multipart form field ``file``. The upload, a colour Grad-CAM
    overlay and a grayscale CAM are written to OUTPUT_DIR, and one row is
    added to the results table.

    Returns:
        - 400 JSON error if no file was sent, or it is not a decodable image.
        - 500 JSON error on any other failure.
        - Otherwise a JSON object with the predicted class, its confidence,
          the per-class probabilities and the generated filenames.

    NOTE(review): no route decorator is visible on this view in this copy of
    the file — confirm how it is registered with `app`.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    uploaded_file = request.files['file']
    if not uploaded_file.filename:
        return jsonify({'error': 'No file selected'}), 400

    # Second-resolution timestamps alone collide for concurrent uploads; the
    # random suffix gives every request its own set of files.
    stem = f"{int(datetime.now().timestamp())}_{uuid.uuid4().hex[:8]}"
    try:
        uploaded_filename, img = _save_upload(uploaded_file, stem)
        probabilities, class_index, cam = _predict_with_gradcam(img)
        gradcam_filename, gray_filename = _save_gradcam_images(img, cam, stem)
        result = CLASS_NAMES[class_index]
        confidence = float(probabilities[class_index])
        _record_result(uploaded_filename, result, confidence, gradcam_filename)
    except UnidentifiedImageError:
        return jsonify({'error': 'Uploaded file is not a valid image'}), 400
    except Exception as e:
        app.logger.exception("Prediction failed")
        return jsonify({'error': str(e)}), 500

    # Probabilities are ordered like CLASS_NAMES ("Advanced", "Early",
    # "Normal"), so look each label up by name rather than by a fixed index.
    return jsonify({
        'prediction': result,
        'confidence': confidence,
        'normal_probability': float(probabilities[CLASS_NAMES.index("Normal")]),
        'early_glaucoma_probability': float(probabilities[CLASS_NAMES.index("Early")]),
        'advanced_glaucoma_probability': float(probabilities[CLASS_NAMES.index("Advanced")]),
        'gradcam_image': gradcam_filename,
        'gradcam_gray_image': gray_filename,
        'image_filename': uploaded_filename
    })
def results():
    """Return every stored prediction as a JSON list, newest timestamp first.

    Columns are selected by name, so the output does not depend on the table's
    physical column order. The connection is closed even if the query fails.

    NOTE(review): no route decorator is visible on this view in this copy of
    the file — confirm how it is registered with `app`.
    """
    columns = ('id', 'image_filename', 'prediction', 'confidence',
               'gradcam_filename', 'timestamp')
    with closing(sqlite3.connect(DB_PATH)) as conn:
        rows = conn.execute(
            "SELECT id, image_filename, prediction, confidence, "
            "gradcam_filename, timestamp "
            "FROM results ORDER BY timestamp DESC"
        ).fetchall()
    return jsonify([dict(zip(columns, row)) for row in rows])
def _serve_output_png(filename):
    """Send OUTPUT_DIR/<filename> as image/png, or return a JSON 404.

    `filename` comes from the client, so only regular ``.png`` files located
    directly inside OUTPUT_DIR are served. The path is resolved with realpath
    (following '..', absolute paths and symlinks) and rejected if it leaves
    the directory. Non-PNG files, notably the SQLite database results.db, are
    refused.
    """
    output_dir = os.path.realpath(OUTPUT_DIR)
    filepath = os.path.realpath(os.path.join(output_dir, filename))
    if (os.path.dirname(filepath) == output_dir
            and filepath.lower().endswith('.png')
            and os.path.isfile(filepath)):
        return send_file(filepath, mimetype='image/png')
    return jsonify({'error': 'File not found'}), 404


def get_gradcam(filename):
    """Serve the Grad-CAM overlay image.

    NOTE(review): no route decorator is visible on this view in this copy of
    the file — confirm how it is registered with `app`.
    """
    return _serve_output_png(filename)


def get_image(filename):
    """Serve the original uploaded image.

    NOTE(review): no route decorator is visible on this view in this copy of
    the file — confirm how it is registered with `app`.
    """
    return _serve_output_png(filename)
if __name__ == '__main__':
    # Run the built-in server on every interface, port 7860, when executed directly.
    app.run(port=7860, host='0.0.0.0')