# glaucoma / app.py
from flask import Flask, request, jsonify, send_file
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision import transforms
import os
import numpy as np
from datetime import datetime
import sqlite3
import torch.nn as nn
import cv2
import json
# Grad-CAM++ imports
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
app = Flask(__name__)
# ✅ Directory and database
OUTPUT_DIR = '/tmp/results'
os.makedirs(OUTPUT_DIR, exist_ok=True)
DB_PATH = os.path.join(OUTPUT_DIR, 'results.db')
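# Note: /tmp is typically ephemeral in containerized deployments, so stored results
# and images may not survive a restart.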
def init_db():
"""Initialize SQLite database for storing results."""
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS results (
id INTEGER PRIMARY KEY AUTOINCREMENT,
image_filename TEXT,
prediction TEXT,
confidence REAL,
gradcam_filename TEXT,
gradcam_gray_filename TEXT,
timestamp TEXT
)
""")
conn.commit()
conn.close()
init_db()
# ✅ Import the custom EfficientNetB0_TransformerGLAM model
from efficientnet_transformer_glam import EfficientNetb0_TransformerGLAM  # efficientnet_transformer_glam.py must be importable (same directory or on PYTHONPATH)
# ✅ Instantiate the model
model = EfficientNetb0_TransformerGLAM(
num_classes=3,
embed_dim=512,
num_heads=8,
mlp_dim=512,
dropout=0.5,
window_size=7,
reduction_ratio=8
)
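# These constructor arguments are assumed to match the configuration used when the
# checkpoint below was trained; if they differ, load_state_dict() will report a mismatch.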
# ✅ Load the trained checkpoint (despite the DenseNet-style filename, this file must hold the state dict of the model defined above)
model.load_state_dict(torch.load('densenet169_seed40_best.pt', map_location='cpu'))
model.eval()
# ✅ Class Names
CLASS_NAMES = ["Advanced", "Early", "Normal"]
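# The probability fields returned by /predict follow this same index order. The
# alphabetical ordering (Advanced, Early, Normal) is assumed to match the label
# indices used at training time (e.g. torchvision ImageFolder's default ordering).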
# ✅ Preprocessing transforms (standard ImageNet mean/std normalization)
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
@app.route('/')
def home():
"""Check that the API is working."""
return "Glaucoma Detection Flask API (EfficientNetB0_TransformerGLAM Model) is running!"
@app.route("/test_file")
def test_file():
"""Check if the .pt model file is present and readable."""
filepath = "densenet169_seed40_best.pt"
if os.path.exists(filepath):
return f"βœ… Model file found at: {filepath}"
else:
return "❌ Model file NOT found."
@app.route('/predict', methods=['POST'])
def predict():
"""Perform prediction and save results (including Grad-CAM++) to the database."""
if 'file' not in request.files:
return jsonify({'error': 'No file uploaded.'}), 400
uploaded_file = request.files['file']
if uploaded_file.filename == '':
return jsonify({'error': 'No file selected.'}), 400
try:
        # ✅ Save the uploaded image
timestamp = int(datetime.now().timestamp())
uploaded_filename = f"uploaded_{timestamp}.png"
uploaded_file_path = os.path.join(OUTPUT_DIR, uploaded_filename)
uploaded_file.save(uploaded_file_path)
        # ✅ Perform prediction
img = Image.open(uploaded_file_path).convert('RGB')
input_tensor = transform(img).unsqueeze(0)
# Model Inference
with torch.no_grad():
output = model(input_tensor)
probabilities = F.softmax(output, dim=1).cpu().numpy()[0]
class_index = np.argmax(probabilities)
result = CLASS_NAMES[class_index]
confidence = float(probabilities[class_index])
        # ✅ Grad-CAM++ setup
target_layer = model.feature_extractor[-1] # Final block of EfficientNet feature extractor
cam_model = GradCAMPlusPlus(model=model, target_layers=[target_layer])
cam_output = cam_model(input_tensor=input_tensor,
targets=[ClassifierOutputTarget(class_index)])[0]
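        # GradCAMPlusPlus returns a (batch, H, W) array of per-pixel activations scaled
        # to [0, 1]; indexing with [0] selects the map for the single input image.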
        # ✅ Create RGB overlay
original_img = np.asarray(img.resize((224, 224)), dtype=np.float32) / 255.0
overlay = show_cam_on_image(original_img, cam_output, use_rgb=True)
        # ✅ Create grayscale version
cam_normalized = np.uint8(255 * cam_output)
        # ✅ Save overlay
gradcam_filename = f"gradcam_{timestamp}.png"
gradcam_file_path = os.path.join(OUTPUT_DIR, gradcam_filename)
cv2.imwrite(gradcam_file_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
        # ✅ Save grayscale
gray_filename = f"gradcam_gray_{timestamp}.png"
gray_file_path = os.path.join(OUTPUT_DIR, gray_filename)
cv2.imwrite(gray_file_path, cam_normalized)
        # ✅ Save results to database
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute("""
INSERT INTO results (image_filename, prediction, confidence, gradcam_filename, gradcam_gray_filename, timestamp)
VALUES (?, ?, ?, ?, ?, ?)
""", (uploaded_filename, result, confidence, gradcam_filename, gray_filename, datetime.now().isoformat()))
conn.commit()
conn.close()
        # ✅ Return results
return jsonify({
'prediction': result,
'confidence': confidence,
            # Probability indices follow CLASS_NAMES order: 0 = Advanced, 1 = Early, 2 = Normal
            'normal_probability': float(probabilities[2]),
            'early_glaucoma_probability': float(probabilities[1]),
            'advanced_glaucoma_probability': float(probabilities[0]),
'gradcam_image': gradcam_filename,
'gradcam_gray_image': gray_filename,
'image_filename': uploaded_filename
})
except Exception as e:
return jsonify({'error': str(e)}), 500
@app.route('/results', methods=['GET'])
def results():
"""List all results from the SQLite database."""
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute("SELECT * FROM results ORDER BY timestamp DESC")
results_data = cursor.fetchall()
conn.close()
results_list = []
for record in results_data:
results_list.append({
'id': record[0],
'image_filename': record[1],
'prediction': record[2],
'confidence': record[3],
'gradcam_filename': record[4],
'gradcam_gray_filename': record[5],
'timestamp': record[6]
})
return jsonify(results_list)
@app.route('/gradcam/<filename>')
def get_gradcam(filename):
"""Serve the Grad-CAM overlay image."""
    # Use only the basename so a crafted filename cannot escape OUTPUT_DIR
    filepath = os.path.join(OUTPUT_DIR, os.path.basename(filename))
if os.path.exists(filepath):
return send_file(filepath, mimetype='image/png')
else:
return jsonify({'error': 'File not found.'}), 404
@app.route('/image/<filename>')
def get_image(filename):
"""Serve the original uploaded image."""
    # Use only the basename so a crafted filename cannot escape OUTPUT_DIR
    filepath = os.path.join(OUTPUT_DIR, os.path.basename(filename))
if os.path.exists(filepath):
return send_file(filepath, mimetype='image/png')
else:
return jsonify({'error': 'File not found.'}), 404
if __name__ == '__main__':
app.run(host='0.0.0.0', port=7860)
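
# Example requests once the server is running (fundus.png is a placeholder for any
# local fundus image; the actual gradcam/image filenames are returned by /predict):
#   curl -X POST -F "file=@fundus.png" http://localhost:7860/predict
#   curl http://localhost:7860/results
#   curl -o overlay.png http://localhost:7860/gradcam/gradcam_<timestamp>.png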