|
import json
import os
import sqlite3
import threading
import uuid
from contextlib import closing
from datetime import datetime

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from flask import Flask, jsonify, request, send_file
from PIL import Image
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import transforms
|
|
|
app = Flask(__name__)

# Every artefact this service produces -- uploaded images, Grad-CAM renderings
# and the SQLite results database -- is stored under this single directory.
OUTPUT_DIR = '/tmp/results'
DB_PATH = os.path.join(OUTPUT_DIR, 'results.db')
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
|
|
|
def init_db(db_path=None):
    """Create the ``results`` table used to store predictions, if it is missing.

    Safe to call repeatedly: the table is only created when it does not exist.

    Args:
        db_path: Path of the SQLite database file to initialise. Defaults to
            the module-level ``DB_PATH`` when omitted.
    """
    if db_path is None:
        db_path = DB_PATH
    conn = sqlite3.connect(db_path)
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS results (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                image_filename TEXT,
                prediction TEXT,
                confidence REAL,
                gradcam_filename TEXT,
                gradcam_gray_filename TEXT,
                timestamp TEXT
            )
        """)
        conn.commit()
    finally:
        # Close even if the DDL fails, so the database handle is never leaked.
        conn.close()
|
|
|
|
|
init_db()

# The network architecture is defined in a project-local module.
from efficientnet_transformer_glam import EfficientNetb0_TransformerGLAM

# NOTE(review): the checkpoint name mentions DenseNet-169, but it is loaded into
# the EfficientNet-B0 + Transformer/GLAM model below -- confirm it is the right file.
MODEL_WEIGHTS_PATH = 'densenet169_seed40_best.pt'

# Hyper-parameters must match those used when the checkpoint was trained,
# otherwise load_state_dict() will reject the weights.
_MODEL_KWARGS = dict(
    num_classes=3,
    embed_dim=512,
    num_heads=8,
    mlp_dim=512,
    dropout=0.5,
    window_size=7,
    reduction_ratio=8,
)
model = EfficientNetb0_TransformerGLAM(**_MODEL_KWARGS)
model.load_state_dict(torch.load(MODEL_WEIGHTS_PATH, map_location='cpu'))
# Evaluation mode (e.g. dropout disabled) for deterministic inference.
model.eval()

# CLASS_NAMES[i] is the label for output index i.
# NOTE(review): presumably the alphabetical order produced by torchvision's
# ImageFolder during training -- confirm against the training code.
CLASS_NAMES = ["Advanced", "Early", "Normal"]

# Standard ImageNet per-channel statistics used for input normalisation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing: scale the shorter side to 256, take the central 224x224 crop,
# convert to a tensor and normalise.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
|
|
|
@app.route('/')
def home():
    """Liveness probe: return a fixed banner so callers can confirm the API is up."""
    banner = "Glaucoma Detection Flask API (EfficientNetB0_TransformerGLAM Model) is running!"
    return banner
|
|
|
@app.route("/test_file")
def test_file():
    """Report whether the model weights file is present and readable.

    Returns a plain-text status line. The file counts as found only if it is a
    regular file that this process has permission to read.
    """
    filepath = "densenet169_seed40_best.pt"
    # isfile() rejects directories; access() confirms the "readable" half of the check.
    if os.path.isfile(filepath) and os.access(filepath, os.R_OK):
        return f"✅ Model file found at: {filepath}"
    return "❌ Model file NOT found."
|
|
|
# Grad-CAM registers forward/backward hooks on the shared `model`, and Flask's
# development server handles requests on several threads by default. Without
# serialisation one request's hooks can record another request's activations,
# so classification and CAM computation run one request at a time.
_INFERENCE_LOCK = threading.Lock()

# The spatial steps of `transform` (Resize + CenterCrop) without tensor
# conversion or normalisation. Applying them to the PIL image yields exactly the
# pixels the model saw, so the Grad-CAM heatmap lines up with the overlay.
_DISPLAY_TRANSFORM = transforms.Compose(
    [t for t in transform.transforms
     if isinstance(t, (transforms.Resize, transforms.CenterCrop))]
)


def _run_inference(img):
    """Classify an RGB PIL image and compute Grad-CAM++ for the predicted class.

    Returns:
        ``(probabilities, class_index, cam)``: the softmax vector for the single
        image, the argmax index as an ``int``, and the Grad-CAM++ map for the
        first (only) batch element.

    Raises:
        RuntimeError: if Grad-CAM++ did not produce a map.
    """
    input_tensor = transform(img).unsqueeze(0)
    with _INFERENCE_LOCK:
        with torch.no_grad():
            output = model(input_tensor)
        probabilities = F.softmax(output, dim=1).cpu().numpy()[0]
        class_index = int(np.argmax(probabilities))

        # Grad-CAM needs gradients, so it runs outside no_grad(). The context
        # manager removes the hooks it registers on the target layer.
        target_layer = model.feature_extractor[-1]
        cam = None
        with GradCAMPlusPlus(model=model, target_layers=[target_layer]) as cam_model:
            cam = cam_model(input_tensor=input_tensor,
                            targets=[ClassifierOutputTarget(class_index)])[0]
    # pytorch-grad-cam's __exit__ swallows IndexError raised inside the block
    # (it only prints it), so verify a map was actually produced.
    if cam is None:
        raise RuntimeError("Grad-CAM++ computation failed.")
    return probabilities, class_index, cam


def _write_png(path, image):
    """Write ``image`` with OpenCV, raising ``OSError`` if the write fails."""
    # cv2.imwrite signals failure through its return value instead of raising.
    if not cv2.imwrite(path, image):
        raise OSError(f"Could not write image to {path}")


def _save_result(record):
    """Insert one row into ``results``; ``record`` follows the INSERT column order."""
    with closing(sqlite3.connect(DB_PATH)) as conn:
        conn.execute(
            """
            INSERT INTO results (image_filename, prediction, confidence,
                                 gradcam_filename, gradcam_gray_filename, timestamp)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            record,
        )
        conn.commit()


@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image, save it with its Grad-CAM++ renderings, and record the result.

    Expects a multipart form field ``file``. On success, returns JSON with
    the predicted label, its confidence, per-class probabilities and the names
    of the three PNG files written to ``OUTPUT_DIR``. Returns 400 for a missing
    or undecodable upload, and 500 with the error message for any other failure.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded.'}), 400

    uploaded_file = request.files['file']
    if uploaded_file.filename == '':
        return jsonify({'error': 'No file selected.'}), 400

    try:
        # Decode before writing anything, so non-images are rejected as a client error.
        img = Image.open(uploaded_file.stream).convert('RGB')
    except OSError:
        return jsonify({'error': 'Uploaded file is not a valid image.'}), 400

    try:
        now = datetime.now()
        # A seconds-resolution timestamp alone collides when two uploads arrive
        # in the same second; the random suffix keeps each request's files apart.
        stem = f"{int(now.timestamp())}_{uuid.uuid4().hex[:12]}"
        uploaded_filename = f"uploaded_{stem}.png"
        gradcam_filename = f"gradcam_{stem}.png"
        gray_filename = f"gradcam_gray_{stem}.png"

        # Re-encode as PNG so the stored bytes match the .png name and the
        # image/png MIME type used when the file is served back.
        img.save(os.path.join(OUTPUT_DIR, uploaded_filename), format='PNG')

        probabilities, class_index, cam_output = _run_inference(img)
        result = CLASS_NAMES[class_index]
        confidence = float(probabilities[class_index])

        display_img = np.asarray(_DISPLAY_TRANSFORM(img), dtype=np.float32) / 255.0
        overlay = show_cam_on_image(display_img, cam_output, use_rgb=True)
        # OpenCV writes BGR, while the overlay was built in RGB.
        _write_png(os.path.join(OUTPUT_DIR, gradcam_filename),
                   cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
        _write_png(os.path.join(OUTPUT_DIR, gray_filename),
                   np.uint8(255 * cam_output))

        _save_result((uploaded_filename, result, confidence,
                      gradcam_filename, gray_filename, now.isoformat()))

        # Look probabilities up by class name, so each key reports the class that
        # CLASS_NAMES (and therefore 'prediction') assigns to that output index.
        prob_by_class = {name: float(p) for name, p in zip(CLASS_NAMES, probabilities)}
        return jsonify({
            'prediction': result,
            'confidence': confidence,
            'normal_probability': prob_by_class['Normal'],
            'early_glaucoma_probability': prob_by_class['Early'],
            'advanced_glaucoma_probability': prob_by_class['Advanced'],
            'gradcam_image': gradcam_filename,
            'gradcam_gray_image': gray_filename,
            'image_filename': uploaded_filename,
        })

    except Exception as e:
        app.logger.exception("Prediction failed")
        return jsonify({'error': str(e)}), 500
|
|
|
|
|
@app.route('/results', methods=['GET'])
def results():
    """Return every stored prediction as JSON, newest first.

    Each entry has the keys ``id``, ``image_filename``, ``prediction``,
    ``confidence``, ``gradcam_filename``, ``gradcam_gray_filename`` and
    ``timestamp``.
    """
    # closing() guarantees the connection is released even if the query fails.
    with closing(sqlite3.connect(DB_PATH)) as conn:
        # sqlite3.Row lets each record become a dict keyed by the selected
        # column names, instead of relying on SELECT * positional order.
        conn.row_factory = sqlite3.Row
        rows = conn.execute(
            "SELECT id, image_filename, prediction, confidence, "
            "gradcam_filename, gradcam_gray_filename, timestamp "
            "FROM results ORDER BY timestamp DESC"
        ).fetchall()

    return jsonify([dict(row) for row in rows])
|
|
|
|
|
def _send_output_png(filename):
    """Send ``OUTPUT_DIR/<filename>`` as a PNG, or a JSON 404 if it is not servable.

    Only regular ``.png`` files located directly in ``OUTPUT_DIR`` are served.
    This covers every image this service writes, while keeping the SQLite
    database and anything outside the directory (e.g. via ``..``) unreachable.
    """
    output_dir = os.path.realpath(OUTPUT_DIR)
    filepath = os.path.realpath(os.path.join(output_dir, filename))
    if (os.path.dirname(filepath) == output_dir
            and filepath.endswith('.png')
            and os.path.isfile(filepath)):
        return send_file(filepath, mimetype='image/png')
    return jsonify({'error': 'File not found.'}), 404


@app.route('/gradcam/<filename>')
def get_gradcam(filename):
    """Serve a Grad-CAM image (overlay or grayscale map) written by /predict."""
    return _send_output_png(filename)


@app.route('/image/<filename>')
def get_image(filename):
    """Serve an uploaded image as stored by /predict."""
    return _send_output_png(filename)
|
|
|
|
|
if __name__ == '__main__':
    # Start Flask's built-in server on every network interface, port 7860.
    app.run(port=7860, host='0.0.0.0')
|
|
|
|