import gradio as gr import numpy as np import tensorflow as tf from tensorflow.keras.models import load_model from tensorflow.keras.preprocessing import image as keras_image from tensorflow.keras import backend as K import matplotlib.pyplot as plt from PIL import Image import io import cv2 # --- Load model and labels --- model = load_model("checkpoints/keras_model.h5") with open("labels.txt", "r") as f: class_labels = [line.strip() for line in f] # --- Preprocess input --- def preprocess_input(img): img = img.resize((224, 224)) arr = keras_image.img_to_array(img) arr = arr / 255.0 return np.expand_dims(arr, axis=0) # --- Enhanced Grad-CAM implementation for Keras --- def get_gradcam_heatmap(img_array, model, class_index, last_conv_layer_name="conv5_block3_out"): try: # Try to find the specified layer target_layer = model.get_layer(last_conv_layer_name) except: # Fallback: find any convolutional layer for layer in model.layers: if 'conv' in layer.name.lower(): target_layer = layer break else: return None grad_model = tf.keras.models.Model( [model.inputs], [target_layer.output, model.output] ) with tf.GradientTape() as tape: conv_outputs, predictions = grad_model(img_array) loss = predictions[:, class_index] grads = tape.gradient(loss, conv_outputs)[0] pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2)) conv_outputs = conv_outputs[0] heatmap = tf.reduce_sum(tf.multiply(pooled_grads, conv_outputs), axis=-1) heatmap = np.maximum(heatmap, 0) heatmap = heatmap / np.max(heatmap + K.epsilon()) return heatmap.numpy() # --- Enhanced Overlay heatmap on image --- def overlay_gradcam(original_img, heatmap): if heatmap is None: return original_img # Resize heatmap heatmap = cv2.resize(heatmap, original_img.size) # Normalize safely heatmap = np.maximum(heatmap, 0) if np.max(heatmap) != 0: heatmap /= np.max(heatmap) heatmap = np.uint8(255 * heatmap) # Apply JET colormap for better medical visualization heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) # Convert PIL to array original_array = np.array(original_img.convert("RGB")) # Enhanced blend with better contrast superimposed_img = cv2.addWeighted(original_array, 0.6, heatmap_color, 0.4, 0) return Image.fromarray(superimposed_img) # --- Enhanced Prediction Function --- def classify_and_explain(img): if img is None: return None, {}, "No image provided" img_array = preprocess_input(img) predictions = model.predict(img_array, verbose=0)[0] pred_idx = int(np.argmax(predictions)) pred_class = class_labels[pred_idx] confidence_dict = {class_labels[i]: float(predictions[i]) for i in range(len(class_labels))} # Enhanced Grad-CAM try: heatmap = get_gradcam_heatmap(img_array, model, pred_idx) gradcam_img = overlay_gradcam(img.resize((224, 224)), heatmap) except Exception as e: print(f"Grad-CAM error: {e}") gradcam_img = img.resize((224, 224)) # fallback image return gradcam_img, confidence_dict # --- Custom CSS for Dark Mode Medical Interface --- css = """ .gradio-container { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; background: #1a1a1a; min-height: 100vh; padding: 20px; color: #ffffff; } .main-header { text-align: center; color: white; margin-bottom: 2rem; padding: 2rem 0; } .main-header h1 { font-size: 2.5rem; margin-bottom: 0.5rem; text-shadow: 2px 2px 4px rgba(0,0,0,0.5); color: #ffffff; } .confidence-bar { background: linear-gradient(90deg, #3498db 0%, #2ecc71 100%); height: 25px; border-radius: 12px; margin: 8px 0; transition: all 0.3s ease; box-shadow: 0 2px 4px rgba(0,0,0,0.3); } .confidence-container { margin: 15px 0; padding: 20px; border-radius: 12px; background: rgba(255,255,255,0.1); backdrop-filter: blur(10px); box-shadow: 0 8px 32px rgba(0,0,0,0.3); border: 1px solid rgba(255,255,255,0.1); } .input-section, .output-section { background: rgba(255,255,255,0.05); padding: 25px; border-radius: 15px; margin: 15px; backdrop-filter: blur(10px); box-shadow: 0 8px 32px rgba(0,0,0,0.3); border: 1px solid rgba(255,255,255,0.1); } .section-title { color: #ffffff; font-size: 1.3rem; font-weight: 600; margin-bottom: 15px; border-bottom: 2px solid #3498db; padding-bottom: 8px; } .gradio-button { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border: none; color: white; padding: 12px 24px; border-radius: 25px; font-weight: 600; transition: all 0.3s ease; box-shadow: 0 4px 15px rgba(0,0,0,0.3); } .gradio-button:hover { transform: translateY(-2px); box-shadow: 0 6px 20px rgba(0,0,0,0.4); } .gradio-image { border-radius: 12px; box-shadow: 0 4px 15px rgba(0,0,0,0.3); border: 1px solid rgba(255,255,255,0.1); } .gradio-textbox, .gradio-number { border-radius: 8px; border: 2px solid #333333; padding: 12px; font-size: 1rem; background: rgba(255,255,255,0.05); color: #ffffff; } .gradio-textbox:focus, .gradio-number:focus { border-color: #3498db; box-shadow: 0 0 0 0.2rem rgba(52,152,219,0.25); } .gradio-label { color: #ffffff !important; } .heatmap-container { background: rgba(255,255,255,0.05); padding: 15px; border-radius: 12px; border: 1px solid rgba(255,255,255,0.1); margin: 10px 0; } .prediction-container { background: rgba(52,152,219,0.1); padding: 20px; border-radius: 12px; border-left: 5px solid #3498db; margin: 15px 0; } """ # --- Function to create confidence bars HTML --- def create_confidence_bars(confidence_dict): html_content = "