|
import os
import sqlite3
import threading
import uuid
from datetime import datetime

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from flask import Flask, request, jsonify, send_file
from PIL import Image
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import transforms

from glam_efficientnet_model import GLAMEfficientNetForClassification, GLAMEfficientNetConfig
|
|
|
|
app = Flask(__name__)

# Directory holding uploaded images, Grad-CAM outputs and the SQLite database.
OUTPUT_DIR = '/tmp/results'
# exist_ok=True replaces the racy "exists() then makedirs()" check: another
# process creating the directory in between no longer raises FileExistsError.
os.makedirs(OUTPUT_DIR, exist_ok=True)

DB_PATH = os.path.join(OUTPUT_DIR, 'results.db')
|
|
|
|
|
|
def init_db(db_path=None):
    """Create the ``results`` table in the SQLite database if it does not exist.

    Args:
        db_path: Path of the SQLite database file. Defaults to the module-level
            ``DB_PATH`` when omitted (the original call signature still works).

    The DDL uses ``CREATE TABLE IF NOT EXISTS``, so calling this on an already
    initialised database leaves it unchanged. The connection is closed even
    if the statement fails.
    """
    conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS results (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                image_filename TEXT,
                prediction TEXT,
                confidence REAL,
                gradcam_filename TEXT,
                gradcam_gray_filename TEXT,
                timestamp TEXT
            )
            """
        )
        conn.commit()
    finally:
        conn.close()
|
|
|
|
|
|
# Ensure the results table exists before the first request arrives.
init_db()

# Instantiate the GLAM-EfficientNet classifier and restore its trained
# weights, mapping every tensor onto the CPU.
# NOTE(review): torch.load unpickles the checkpoint file, so only point it at
# trusted files (weights_only=True is worth considering if the installed
# torch version supports it).
config = GLAMEfficientNetConfig()
model = GLAMEfficientNetForClassification(config)
model.load_state_dict(
    torch.load('efficientnet_glam_best.pt', map_location='cpu')
)
model.eval()  # switch to inference mode before serving predictions

# Human-readable label for each index of the model's logits.
# NOTE(review): alphabetical order — confirm it matches the training labels.
CLASS_NAMES = ["Advanced", "Early", "Normal"]

# Input pipeline: shorter side -> 256 px, central 224x224 crop, CHW float
# tensor, then per-channel normalisation with the ImageNet statistics.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
|
@app.route('/')
def home():
    """Liveness probe: reply with a fixed plain-text banner."""
    banner = "Glaucoma Detection Flask API (EfficientNet + GLAM) is running!"
    return banner
|
|
|
|
|
|
@app.route("/test_file")
def test_file():
    """Report whether the model checkpoint file is present and readable.

    Returns a plain-text status message (HTTP 200 in both cases).
    """
    filepath = "efficientnet_glam_best.pt"
    # Check for a regular file the process can actually read; exists() alone
    # would also accept a directory or an unreadable file.
    if os.path.isfile(filepath) and os.access(filepath, os.R_OK):
        return f"✅ Model file found at: {filepath}"
    return "❌ Model file NOT found."
|
|
|
|
|
|
# Grad-CAM registers hooks on the shared ``model`` and back-propagates through
# it. Flask may serve requests on several threads, so all model use is
# serialised; otherwise one request's activations/gradients could be captured
# by another request's hooks.
_INFERENCE_LOCK = threading.Lock()


class _LogitsOnly(nn.Module):
    """Wrap ``model`` so its forward returns the ``"logits"`` tensor directly.

    ``predict`` reads ``model(x)["logits"]``, so the model returns a mapping.
    pytorch_grad_cam iterates over the model output and indexes it like a
    (batch, classes) tensor, which a mapping does not support.
    """

    def __init__(self, wrapped):
        super().__init__()
        self.wrapped = wrapped

    def forward(self, x):
        return self.wrapped(x)["logits"]


def _display_image(img):
    """Return *img* resized and cropped exactly like ``transform``, as float32 RGB in [0, 1].

    The CAM is computed on the Resize(256) + CenterCrop(224) view fed to the
    model, so the overlay must use the same geometry or the heatmap drifts.
    """
    cropped = transforms.CenterCrop(224)(transforms.Resize(256)(img))
    return np.asarray(cropped, dtype=np.float32) / 255.0


def _compute_gradcam(input_tensor, class_index):
    """Return the Grad-CAM++ map of *class_index* for the single image in *input_tensor*."""
    target_layer = model.features[-1]
    grayscale_cam = None
    # The context manager removes the hooks GradCAM registers on ``model``.
    with GradCAMPlusPlus(model=_LogitsOnly(model), target_layers=[target_layer]) as cam:
        grayscale_cam = cam(input_tensor=input_tensor,
                            targets=[ClassifierOutputTarget(class_index)])[0]
    if grayscale_cam is None:
        # NOTE(review): the library's __exit__ may suppress some exceptions
        # (IndexError in current releases); surface that as a failure.
        raise RuntimeError("Grad-CAM++ computation failed")
    return grayscale_cam


def _write_png(path, image):
    """Write *image* with OpenCV, raising instead of silently returning False."""
    if not cv2.imwrite(path, image):
        raise OSError(f"Could not write image to {path}")


def _insert_result(image_filename, prediction, confidence, gradcam_filename, gray_filename):
    """Insert one prediction row into ``results``; the connection is always closed."""
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute(
            """
            INSERT INTO results (image_filename, prediction, confidence,
                                 gradcam_filename, gradcam_gray_filename, timestamp)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (image_filename, prediction, confidence, gradcam_filename,
             gray_filename, datetime.now().isoformat()),
        )
        conn.commit()
    finally:
        conn.close()


@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image, save Grad-CAM++ maps and record the result.

    Expects a multipart form field ``file``. On success returns JSON with the
    predicted label, its confidence, per-class probabilities and the names of
    the saved PNG files (uploaded image, colour overlay, grayscale CAM).
    Returns 400 for a missing or undecodable upload and 500 for any other
    failure (the error is also logged).
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400

    uploaded_file = request.files['file']
    if not uploaded_file.filename:
        return jsonify({'error': 'No file selected'}), 400

    try:
        # .convert() forces a full decode, so truncated files fail here too.
        img = Image.open(uploaded_file.stream).convert('RGB')
    except (OSError, Image.DecompressionBombError) as e:
        # PIL.UnidentifiedImageError is a subclass of OSError.
        return jsonify({'error': f'Invalid image file: {e}'}), 400

    try:
        # A random suffix keeps two uploads in the same second from
        # overwriting each other's files.
        stem = f"{int(datetime.now().timestamp())}_{uuid.uuid4().hex[:8]}"
        uploaded_filename = f"uploaded_{stem}.png"
        gradcam_filename = f"gradcam_{stem}.png"
        gray_filename = f"gradcam_gray_{stem}.png"

        # Re-encode as PNG so the stored file matches its .png name and the
        # image/png mimetype it is served with.
        img.save(os.path.join(OUTPUT_DIR, uploaded_filename), format='PNG')

        input_tensor = transform(img).unsqueeze(0)
        with _INFERENCE_LOCK:
            with torch.no_grad():  # plain prediction needs no autograd graph
                logits = model(input_tensor)["logits"]
            probabilities = F.softmax(logits, dim=1).cpu().numpy()[0]
            class_index = int(np.argmax(probabilities))
            cam = _compute_gradcam(input_tensor, class_index)

        result = CLASS_NAMES[class_index]
        confidence = float(probabilities[class_index])

        overlay = show_cam_on_image(_display_image(img), cam, use_rgb=True)
        _write_png(os.path.join(OUTPUT_DIR, gradcam_filename),
                   cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
        _write_png(os.path.join(OUTPUT_DIR, gray_filename), np.uint8(255 * cam))

        _insert_result(uploaded_filename, result, confidence, gradcam_filename, gray_filename)

        # Probabilities are looked up by label, so these fields always agree
        # with ``prediction``. NOTE(review): assumes CLASS_NAMES lists labels
        # in the model's output order.
        by_label = {name: float(p) for name, p in zip(CLASS_NAMES, probabilities)}

        return jsonify({
            'prediction': result,
            'confidence': confidence,
            'normal_probability': by_label["Normal"],
            'early_glaucoma_probability': by_label["Early"],
            'advanced_glaucoma_probability': by_label["Advanced"],
            'gradcam_image': gradcam_filename,
            'gradcam_gray_image': gray_filename,
            'image_filename': uploaded_filename,
        })

    except Exception as e:
        app.logger.exception("Prediction failed")
        return jsonify({'error': str(e)}), 500
|
|
|
|
|
|
# Columns returned by /results, in response-key order. The SELECT below names
# them explicitly, so the mapping does not depend on the table's column order.
_RESULT_COLUMNS = ('id', 'image_filename', 'prediction', 'confidence',
                   'gradcam_filename', 'gradcam_gray_filename', 'timestamp')


@app.route('/results', methods=['GET'])
def results():
    """List all stored predictions as JSON objects keyed by column name.

    Rows are ordered by the stored timestamp text, descending, with ``id`` as
    a tie-breaker. The connection is closed even if the query fails.
    """
    # Safe to interpolate: the column names come from the constant tuple above.
    query = (f"SELECT {', '.join(_RESULT_COLUMNS)} FROM results "
             "ORDER BY timestamp DESC, id DESC")
    conn = sqlite3.connect(DB_PATH)
    try:
        rows = conn.execute(query).fetchall()
    finally:
        conn.close()

    return jsonify([dict(zip(_RESULT_COLUMNS, row)) for row in rows])
|
|
|
|
|
|
def _send_output_png(filename):
    """Send ``OUTPUT_DIR/<filename>`` as image/png, or a JSON 404 if it is not servable.

    Only bare ``.png`` file names that resolve to a regular file directly in
    OUTPUT_DIR are served. Names such as ``..`` (a directory, which previously
    caused a 500) or ``results.db`` (the SQLite database) get a 404.
    """
    not_found = (jsonify({'error': 'File not found'}), 404)
    if os.path.basename(filename) != filename or not filename.lower().endswith('.png'):
        return not_found
    filepath = os.path.join(OUTPUT_DIR, filename)
    if not os.path.isfile(filepath):
        return not_found
    return send_file(filepath, mimetype='image/png')


@app.route('/gradcam/<filename>')
def get_gradcam(filename):
    """Serve a Grad-CAM PNG (colour overlay or grayscale map) saved in OUTPUT_DIR."""
    return _send_output_png(filename)


@app.route('/image/<filename>')
def get_image(filename):
    """Serve an uploaded image PNG saved in OUTPUT_DIR."""
    return _send_output_png(filename)
|
|
|
|
|
|
if __name__ == '__main__':
    # Development entry point: listen on every network interface, port 7860.
    app.run(port=7860, host='0.0.0.0')
|
|
|