import os
import sqlite3
from contextlib import closing
from datetime import datetime

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from flask import Flask, request, jsonify, send_file
from PIL import Image
from torchvision import transforms

# Project-local module that builds the attention-augmented backbone.
from densenet_withglam import get_model_with_attention

app = Flask(__name__)

# Directory that holds uploaded images, Grad-CAM renders and the results DB.
OUTPUT_DIR = '/tmp/results'
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(OUTPUT_DIR, exist_ok=True)
DB_PATH = os.path.join(OUTPUT_DIR, 'results.db')

# Checkpoint loaded at start-up. The /test_file endpoint reports on this same
# path so the diagnostic cannot drift away from what the app actually loads.
MODEL_PATH = 'densenet169_seed40_best.pt'


def init_db():
    """Create the ``results`` table in the SQLite file at DB_PATH if missing.

    The connection is closed even when the statement fails; the inner
    ``with conn`` block commits on success and rolls back on error.
    """
    with closing(sqlite3.connect(DB_PATH)) as conn:
        with conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS results (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    image_filename TEXT,
                    prediction TEXT,
                    confidence REAL,
                    gradcam_filename TEXT,
                    timestamp TEXT
                )
            """)


init_db()

# Build the 3-class DenseNet-169 (with GLAM attention, per the original
# author's note) and load the trained weights on CPU.
model = get_model_with_attention('densenet169', num_classes=3)
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
model.eval()

# Class labels; the position in this list is the model's output index.
CLASS_NAMES = ["Advanced", "Early", "Normal"]

# Preprocessing applied to every uploaded image before inference.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # ImageNet channel statistics.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


# =========================
# GRAD-CAM IMPLEMENTATION
# =========================
class GradCAM:
    """Grad-CAM for a single named layer of ``model``.

    Constructing an instance attaches a forward hook and a full backward hook
    to the module whose ``named_modules()`` name equals ``target_layer_name``.
    The hooks stay attached until ``remove_hooks()`` is called, so create one
    instance per model/layer and reuse it rather than building one per call.

    Args:
        model: Module to inspect.
        target_layer_name: Dotted module name as yielded by
            ``model.named_modules()``.

    Raises:
        ValueError: If no module with that name exists in ``model``.
    """

    def __init__(self, model, target_layer_name):
        self.model = model
        self.target_layer_name = target_layer_name
        self.activations = None
        self.gradients = None
        # Handles returned by the register_* calls, kept so the hooks can be
        # detached again.
        self._hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Attach the forward and backward hooks to the target layer."""
        for name, module in self.model.named_modules():
            if name == self.target_layer_name:
                self._hook_handles.append(
                    module.register_forward_hook(self._forward_hook))
                self._hook_handles.append(
                    module.register_full_backward_hook(self._backward_hook))
        if not self._hook_handles:
            # Fail at construction instead of later with a misleading
            # "must run forward and backward passes" error from generate().
            raise ValueError(
                f"Target layer not found in model: {self.target_layer_name!r}")

    def remove_hooks(self):
        """Detach every hook this instance registered on the model."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def _forward_hook(self, module, input, output):
        """Save the target layer's output from the latest forward pass."""
        # Detached so the stored tensor does not keep the autograd graph alive.
        self.activations = output.detach()

    def _backward_hook(self, module, grad_in, grad_out):
        """Save the gradient w.r.t. the target layer's output."""
        self.gradients = grad_out[0].detach()

    def generate(self, class_index):
        """Build the class-activation map from the stored hook data.

        Args:
            class_index: Accepted for interface compatibility; the map is
                computed from whatever backward pass ran last.

        Returns:
            NumPy array min-max scaled to [0, 1]: the ReLU of the
            gradient-weighted channel sum with size-1 dimensions squeezed out.

        Raises:
            ValueError: If no forward and backward pass has been recorded.
        """
        if self.activations is None or self.gradients is None:
            raise ValueError("Must run forward and backward passes first.")

        # Channel weights: gradients averaged over the two trailing dims.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = F.relu(cam)
        cam = cam.squeeze().cpu().detach().numpy()
        # The epsilon keeps a constant map from dividing by zero.
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam


@app.route('/')
def home():
    """Liveness endpoint: return a fixed status string."""
    return "Glaucoma Detection Flask API (3-Class Model) is running!"


@app.route("/test_file")
def test_file():
    """Report whether the model checkpoint the app loads exists on disk."""
    filepath = MODEL_PATH
    if os.path.isfile(filepath):
        return f"✅ Model file found at: {filepath}"
    return "❌ Model file NOT found."
# Standard-library helpers used only by the request handlers below.
import threading
import uuid
from contextlib import closing

# Layer whose activations drive the Grad-CAM heatmap.
_GRADCAM_TARGET_LAYER = "features.2.local_spatial_conv3"

# One GradCAM instance is shared by all requests. Building a new one per
# request would attach another pair of hooks to the global model every time.
_gradcam = None
# The hooks write into shared state on that instance, so the
# forward/backward/generate sequence must not interleave between requests.
_gradcam_lock = threading.Lock()


def _get_gradcam():
    """Return the shared GradCAM instance, creating it on first use.

    Callers must hold ``_gradcam_lock``.
    """
    global _gradcam
    if _gradcam is None:
        _gradcam = GradCAM(model, _GRADCAM_TARGET_LAYER)
    return _gradcam


def _classify_with_cam(input_tensor):
    """Run the model on one preprocessed image and compute its Grad-CAM.

    Returns:
        Tuple ``(probabilities, class_index, cam)``: the softmax vector for
        the first batch item, the arg-max index as a Python int, and the 2-D
        activation map produced by ``GradCAM.generate``.
    """
    with _gradcam_lock:
        gradcam = _get_gradcam()
        # Clear stale data so a hook that failed to fire raises in generate()
        # instead of silently reusing the previous request's tensors.
        gradcam.activations = None
        gradcam.gradients = None

        output = model(input_tensor)
        probabilities = F.softmax(output, dim=1).cpu().detach().numpy()[0]
        class_index = int(np.argmax(probabilities))

        # Backward pass from the predicted class score feeds the hook.
        model.zero_grad()
        output[0, class_index].backward()
        cam = gradcam.generate(class_index)

    # Ensure the map is 2-D.
    if cam.ndim == 3:
        cam = cam[0]
    return probabilities, class_index, cam


def _write_png(path, image):
    """Write ``image`` with OpenCV, raising OSError if the write fails."""
    # cv2.imwrite reports failure by returning False rather than raising.
    if not cv2.imwrite(path, image):
        raise OSError(f"Could not write image: {path}")


def _save_gradcam_images(img, cam, run_id):
    """Render and save the colour overlay and the grayscale heatmap.

    Args:
        img: RGB PIL image as uploaded.
        cam: 2-D activation map scaled to [0, 1].
        run_id: Unique suffix shared by this request's files.

    Returns:
        Tuple ``(gradcam_filename, gray_filename)`` of names in OUTPUT_DIR.
    """
    heat_gray = cv2.resize(np.uint8(255 * cam), (224, 224))

    # Use the same Resize(256) + CenterCrop(224) view the model received, so
    # the heatmap lines up with the pixels underneath it.
    model_view = transforms.CenterCrop(224)(transforms.Resize(256)(img))
    # PIL yields RGB; applyColorMap and imwrite work in BGR. Convert before
    # blending, otherwise the photo's red and blue channels come out swapped.
    background_bgr = cv2.cvtColor(np.asarray(model_view), cv2.COLOR_RGB2BGR)

    heatmap = cv2.applyColorMap(heat_gray, cv2.COLORMAP_JET)
    overlay = cv2.addWeighted(background_bgr, 0.6, heatmap, 0.4, 0)

    gradcam_filename = f"gradcam_{run_id}.png"
    _write_png(os.path.join(OUTPUT_DIR, gradcam_filename), overlay)

    gray_filename = f"gradcam_gray_{run_id}.png"
    _write_png(os.path.join(OUTPUT_DIR, gray_filename), heat_gray)

    return gradcam_filename, gray_filename


def _store_result(image_filename, prediction, confidence, gradcam_filename):
    """Insert one prediction row into the ``results`` table."""
    with closing(sqlite3.connect(DB_PATH)) as conn:
        with conn:  # commit on success, roll back on error
            conn.execute("""
                INSERT INTO results
                    (image_filename, prediction, confidence,
                     gradcam_filename, timestamp)
                VALUES (?, ?, ?, ?, ?)
            """, (image_filename, prediction, confidence, gradcam_filename,
                  datetime.now().isoformat()))


@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image and persist the result with its Grad-CAM.

    Expects a multipart upload under the form field ``file``.

    Returns:
        JSON with the predicted label, its confidence, the per-class
        probabilities and the names of the stored images. Responds 400 when
        no file is supplied and 500 (with the error text) on any failure.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400

    uploaded_file = request.files['file']
    if uploaded_file.filename == '':
        return jsonify({'error': 'No file selected'}), 400

    try:
        # The random suffix keeps two uploads in the same second from
        # overwriting each other's files.
        run_id = f"{int(datetime.now().timestamp())}_{uuid.uuid4().hex[:8]}"
        uploaded_filename = f"uploaded_{run_id}.png"
        uploaded_file_path = os.path.join(OUTPUT_DIR, uploaded_filename)
        uploaded_file.save(uploaded_file_path)

        img = Image.open(uploaded_file_path).convert('RGB')
        input_tensor = transform(img).unsqueeze(0)

        probabilities, class_index, cam = _classify_with_cam(input_tensor)
        result = CLASS_NAMES[class_index]
        confidence = float(probabilities[class_index])

        gradcam_filename, gray_filename = _save_gradcam_images(
            img, cam, run_id)

        _store_result(uploaded_filename, result, confidence, gradcam_filename)

        # Look probabilities up by label: index 0 is "Advanced" and index 2
        # is "Normal", so positional access mislabels them.
        prob_by_class = {
            name: float(prob)
            for name, prob in zip(CLASS_NAMES, probabilities)
        }
        return jsonify({
            'prediction': result,
            'confidence': confidence,
            'normal_probability': prob_by_class["Normal"],
            'early_glaucoma_probability': prob_by_class["Early"],
            'advanced_glaucoma_probability': prob_by_class["Advanced"],
            'gradcam_image': gradcam_filename,
            'gradcam_gray_image': gray_filename,
            'image_filename': uploaded_filename
        })

    except Exception as e:
        # Request boundary: keep the traceback in the server log and return
        # the same JSON error shape as before.
        app.logger.exception("Prediction failed")
        return jsonify({'error': str(e)}), 500


@app.route('/results', methods=['GET'])
def results():
    """Return every stored result as a JSON list, newest timestamp first."""
    columns = ('id', 'image_filename', 'prediction', 'confidence',
               'gradcam_filename', 'timestamp')
    with closing(sqlite3.connect(DB_PATH)) as conn:
        # Columns are named explicitly so the key mapping does not depend on
        # the table's physical column order.
        rows = conn.execute(
            f"SELECT {', '.join(columns)} FROM results ORDER BY timestamp DESC"
        ).fetchall()
    return jsonify([dict(zip(columns, row)) for row in rows])


def _send_output_png(filename):
    """Serve a PNG from OUTPUT_DIR, or a JSON 404 if it is not available.

    Only bare ``.png`` file names that are regular files are served, so
    directory components and other files kept in OUTPUT_DIR (such as the
    results database) cannot be fetched through these endpoints.
    """
    not_found = (jsonify({'error': 'File not found'}), 404)
    if os.path.basename(filename) != filename or not filename.endswith('.png'):
        return not_found
    filepath = os.path.join(OUTPUT_DIR, filename)
    if not os.path.isfile(filepath):
        return not_found
    return send_file(filepath, mimetype='image/png')


# The <filename> URL variable is required: without it Flask calls the view
# with no arguments and every request fails.
@app.route('/gradcam/<filename>')
def get_gradcam(filename):
    """Serve a Grad-CAM image by file name."""
    return _send_output_png(filename)


@app.route('/image/<filename>')
def get_image(filename):
    """Serve an uploaded image by file name."""
    return _send_output_png(filename)


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)