# Source page header (captured from the hosting site): app.py by dhruv2842,
# commit 3c9ef9c "Update app.py" (verified), 6.37 kB; raw / history / blame views.
import os
import sqlite3
import uuid
from contextlib import closing
from datetime import datetime

import cv2
import numpy as np
import tensorflow as tf
from flask import Flask, request, jsonify, send_file
from PIL import Image
from tensorflow.keras.models import load_model, Model
app = Flask(__name__)

# ✅ Directory and database path
# Uploaded images, Grad-CAM overlays and the SQLite database all live here.
OUTPUT_DIR = '/tmp/results'
# exist_ok=True replaces a separate exists() check, which races when several
# processes start at once (both see "missing", the second makedirs raises).
os.makedirs(OUTPUT_DIR, exist_ok=True)
DB_PATH = os.path.join(OUTPUT_DIR, 'results.db')
def init_db(db_path=None):
    """Create the ``results`` table if it does not already exist.

    Args:
        db_path: SQLite database file to initialise. When omitted, the
            module-level ``DB_PATH`` is used, so ``init_db()`` behaves as before.

    The statement uses ``IF NOT EXISTS``, so calling this repeatedly is safe.
    """
    if db_path is None:
        db_path = DB_PATH
    # closing() releases the connection even if the DDL raises; the inner
    # `conn` context commits on success and rolls back on error.
    with closing(sqlite3.connect(db_path)) as conn, conn:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS results (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                image_filename TEXT,
                prediction TEXT,
                confidence REAL,
                gradcam_filename TEXT,
                timestamp TEXT
            )
            """
        )
# Create the results table at import time, before any request is handled.
init_db()
# ✅ Load Model
# Loaded once at import and shared by /predict and Grad-CAM. compile=False
# skips restoring training config; this module never trains the model.
model = load_model('mobilenet_glaucoma_model.h5', compile=False)
# ✅ Preprocess Image
def preprocess_image(img, target_size=(224, 224)):
    """Convert a PIL image into a single-sample batch for the model.

    Args:
        img: PIL image. Images not in RGB mode (e.g. RGBA, greyscale) are
            converted to RGB, so the result always has 3 channels.
        target_size: ``(width, height)`` to resize to. The 224x224 default
            matches the overlay size used by ``save_gradcam_image``.

    Returns:
        float32 array of shape ``(1, height, width, 3)``. Pixels are divided
        by 255, giving [0, 1] for 8-bit RGB input.
    """
    if img.mode != 'RGB':
        img = img.convert('RGB')
    resized = img.resize(target_size)
    # float32 matches Keras' default float type and is half the size of
    # numpy's float64 default.
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    return np.expand_dims(pixels, axis=0)
# ✅ Grad-CAM Generation
def make_gradcam(img_array, model, last_conv_layer_name='Conv_1_bn'):
    """Generate a Grad-CAM heatmap for ``img_array``.

    Args:
        img_array: Batched model input, e.g. from ``preprocess_image``. Only
            the first sample's activations are used for the map.
        model: Keras model containing a layer named ``last_conv_layer_name``.
        last_conv_layer_name: Layer whose activations are weighted by the
            pooled gradients.

    Returns:
        2-D numpy array with the spatial size of that layer's output. Negative
        values are clipped to 0 and the map is scaled so its maximum is 1. If
        nothing contributes positively, an all-zero map is returned instead of
        dividing by zero (which would produce NaNs).
    """
    last_conv_layer = model.get_layer(last_conv_layer_name)
    grad_model = Model(inputs=model.inputs,
                       outputs=[last_conv_layer.output, model.output])

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        # NOTE(review): /predict treats output 0 as the "Normal" probability,
        # so this map highlights evidence for "Normal" even when the image is
        # classified as Glaucoma. Confirm that is intended.
        loss = predictions[:, 0]

    grads = tape.gradient(loss, conv_outputs)
    # One weight per channel: gradients averaged over batch and spatial axes.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2)).numpy()

    # Weight every channel of the first sample (broadcast over the last axis)
    # and average across channels.
    weighted = conv_outputs[0].numpy() * pooled_grads
    heatmap = np.maximum(np.mean(weighted, axis=-1), 0)

    peak = np.max(heatmap)
    if peak > 0:
        heatmap = heatmap / peak
    return heatmap
# ✅ Save Grad-CAM Overlay
def save_gradcam_image(original_img, heatmap, filename='gradcam.png', output_dir=OUTPUT_DIR):
    """Blend ``heatmap`` over ``original_img`` and write the result to disk.

    Args:
        original_img: PIL image, expected in RGB mode (``predict`` passes an
            RGB-converted image). It is resized to 224x224 for the overlay.
        heatmap: 2-D array, expected in [0, 1] (as from ``make_gradcam``).
            It is resized to the image and colour-mapped with JET.
        filename: Output file name inside ``output_dir``.
        output_dir: Directory to write into.

    Returns:
        Full path of the written file.

    Raises:
        OSError: if OpenCV reports that the file could not be written.
    """
    # PIL arrays are RGB; OpenCV's colormap output and imwrite are BGR.
    # Without this conversion the saved overlay has red and blue swapped.
    img = cv2.cvtColor(np.array(original_img.resize((224, 224))), cv2.COLOR_RGB2BGR)
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    # NaN -> 0 and clip to [0, 1] so the uint8 cast cannot wrap around.
    heatmap = np.uint8(255 * np.clip(np.nan_to_num(heatmap), 0.0, 1.0))
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    overlay = cv2.addWeighted(img, 0.6, heatmap, 0.4, 0)
    filepath = os.path.join(output_dir, filename)
    # imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(filepath, overlay):
        raise OSError(f"Failed to write Grad-CAM image to {filepath}")
    return filepath
@app.route('/')
def home():
    """Root endpoint: a plain-text liveness message for the API."""
    status_message = "Glaucoma Detection Flask API is running!"
    return status_message
@app.route("/test_file")
def test_file():
    """Report whether the model file is present and readable.

    Checks the same relative path that is passed to ``load_model`` at import
    time, resolved against the process's working directory. Returns a
    plain-text status message.
    """
    filepath = "mobilenet_glaucoma_model.h5"
    # isfile (not exists) so that a directory with this name is not reported
    # as the model.
    if not os.path.isfile(filepath):
        return "❌ Model file NOT found."
    # Existence alone does not catch permission problems.
    if not os.access(filepath, os.R_OK):
        return f"❌ Model file found but NOT readable: {filepath}"
    return f"✅ Model file found at: {filepath}"
def _classify(img_array):
    """Run the model on a preprocessed batch and derive the class decision.

    Output 0 of the model's first sample is treated as the "Normal"
    probability and ``1 - output0`` as the "Glaucoma" probability. This
    presumes a single-probability output head; confirm against the model.

    Returns:
        Tuple ``(result, confidence, normal_prob, glaucoma_prob)``: the label
        string followed by three Python floats.
    """
    prediction = model.predict(img_array, verbose=0)[0]
    normal_prob = prediction[0]
    glaucoma_prob = 1 - normal_prob
    result = 'Glaucoma' if glaucoma_prob > normal_prob else 'Normal'
    confidence = glaucoma_prob if result == 'Glaucoma' else normal_prob
    return result, float(confidence), float(normal_prob), float(glaucoma_prob)


def _store_result(image_filename, result, confidence, gradcam_filename, timestamp):
    """Insert one prediction row into the SQLite ``results`` table."""
    # closing() guarantees the connection is released; the inner `conn`
    # context commits on success and rolls back on error.
    with closing(sqlite3.connect(DB_PATH)) as conn, conn:
        conn.execute(
            "INSERT INTO results (image_filename, prediction, confidence, "
            "gradcam_filename, timestamp) VALUES (?, ?, ?, ?, ?)",
            (image_filename, result, confidence, gradcam_filename, timestamp),
        )


@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image, save the artifacts and record the result.

    Expects a multipart form field ``file``. The upload is decoded and stored
    in OUTPUT_DIR as an RGB PNG. It is then classified, a Grad-CAM overlay is
    saved next to it, and a row is added to the ``results`` table.

    Returns:
        JSON with the prediction, confidence, both probabilities and the
        stored image / Grad-CAM filenames. On failure, ``{'error': ...}`` with
        400 for a missing or undecodable upload, or 500 for any other error.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    uploaded_file = request.files['file']
    if uploaded_file.filename == '':
        return jsonify({'error': 'No file selected'}), 400

    # Decode before writing anything, so a non-image is a client error (400)
    # and leaves no file behind. UnidentifiedImageError subclasses OSError.
    try:
        img = Image.open(uploaded_file.stream).convert('RGB')
    except (OSError, Image.DecompressionBombError):
        return jsonify({'error': 'Uploaded file is not a valid image'}), 400

    try:
        now = datetime.now()
        # A seconds-resolution timestamp alone lets two uploads in the same
        # second overwrite each other's files; the uuid suffix makes it unique.
        file_id = f"{int(now.timestamp())}_{uuid.uuid4().hex}"
        uploaded_filename = f"uploaded_{file_id}.png"
        gradcam_filename = f"gradcam_{file_id}.png"

        # ✅ Save the uploaded image, re-encoded as PNG so the bytes match the
        # .png name and the image/png mimetype used when it is served back.
        img.save(os.path.join(OUTPUT_DIR, uploaded_filename), format='PNG')

        # ✅ Perform prediction
        img_array = preprocess_image(img)
        result, confidence, normal_prob, glaucoma_prob = _classify(img_array)

        # ✅ Grad-CAM
        heatmap = make_gradcam(img_array, model, last_conv_layer_name='Conv_1_bn')
        save_gradcam_image(img, heatmap, filename=gradcam_filename)

        # ✅ Save results to SQLite
        _store_result(uploaded_filename, result, confidence, gradcam_filename,
                      now.isoformat())
    except Exception as e:  # request boundary: log, then report as JSON
        app.logger.exception("Prediction failed")
        return jsonify({'error': str(e)}), 500

    return jsonify({
        'prediction': result,
        'confidence': confidence,
        'normal_probability': normal_prob,
        'glaucoma_probability': glaucoma_prob,
        'gradcam_image': gradcam_filename,
        'image_filename': uploaded_filename,
    })
@app.route('/results', methods=['GET'])
def results():
    """List all stored results as JSON, newest first.

    Each item has the keys id, image_filename, prediction, confidence,
    gradcam_filename and timestamp.
    """
    columns = ('id', 'image_filename', 'prediction', 'confidence',
               'gradcam_filename', 'timestamp')
    # closing() releases the connection even if the query raises. Columns are
    # named explicitly instead of relying on SELECT * order. The id
    # tie-breaker keeps rows with equal timestamps in a stable order.
    with closing(sqlite3.connect(DB_PATH)) as conn:
        rows = conn.execute(
            "SELECT id, image_filename, prediction, confidence, "
            "gradcam_filename, timestamp FROM results "
            "ORDER BY timestamp DESC, id DESC"
        ).fetchall()
    return jsonify([dict(zip(columns, row)) for row in rows])
def _send_result_png(filename, prefix):
    """Serve ``filename`` from OUTPUT_DIR if it is a bare ``<prefix>*.png`` name.

    Any other name gets the same 404 as a missing file. That covers names
    with path components, a different prefix, or another file type. Without
    this check, other files in OUTPUT_DIR (notably ``results.db``) could be
    downloaded.
    """
    if (os.path.basename(filename) != filename
            or not filename.startswith(prefix)
            or not filename.endswith('.png')):
        return jsonify({'error': 'File not found'}), 404
    filepath = os.path.join(OUTPUT_DIR, filename)
    if not os.path.isfile(filepath):
        return jsonify({'error': 'File not found'}), 404
    return send_file(filepath, mimetype='image/png')


@app.route('/gradcam/<filename>')
def get_gradcam(filename):
    """Serve a stored Grad-CAM overlay (``gradcam_*.png``)."""
    return _send_result_png(filename, 'gradcam_')


@app.route('/image/<filename>')
def get_image(filename):
    """Serve a stored uploaded image (``uploaded_*.png``)."""
    return _send_result_png(filename, 'uploaded_')
if __name__ == '__main__':
    # Development-server entry point: listen on every interface, port 7860.
    app.run(port=7860, host='0.0.0.0')