import html
import io

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image as keras_image

# (width, height) every image is resized to before inference and display.
IMG_SIZE = (224, 224)

# --- Load model and labels ---
model = load_model("checkpoints/keras_model.h5")

with open("labels.txt", "r", encoding="utf-8") as f:
    # Skip blank lines (e.g. a trailing empty line) so the label list does not
    # grow longer than the model's output vector.
    class_labels = [line.strip() for line in f if line.strip()]


# --- Preprocess input ---
def preprocess_input(img, target_size=IMG_SIZE):
    """Turn a PIL image into a batch of one, ready for ``model.predict``.

    Args:
        img: PIL image to classify (any mode; it is converted to RGB).
        target_size: (width, height) the image is resized to.

    Returns:
        Array of shape (1, height, width, 3) with pixel values scaled to
        [0, 1] (channels-last, Keras' default layout).
    """
    # Force 3 channels: RGBA or grayscale uploads would otherwise produce
    # 4 or 1 channels and not match the model's input.
    img = img.convert("RGB").resize(target_size)
    arr = keras_image.img_to_array(img)
    arr = arr / 255.0
    return np.expand_dims(arr, axis=0)


# --- Enhanced Grad-CAM implementation for Keras ---
def _find_last_conv_layer(model):
    """Return the deepest layer named like a conv layer with a 4-D output, or None."""
    for layer in reversed(model.layers):
        if "conv" not in layer.name.lower():
            continue
        try:
            rank = len(layer.output.shape)
        except (AttributeError, ValueError):
            # `layer.output` is not defined for some layers (e.g. ones called
            # more than once) or may have unknown rank; skip those.
            continue
        if rank == 4:
            return layer
    return None


def get_gradcam_heatmap(img_array, model, class_index, last_conv_layer_name="conv5_block3_out"):
    """Compute a Grad-CAM heatmap for ``class_index``.

    Args:
        img_array: batch of one image, as returned by ``preprocess_input``.
        model: Keras model being explained.
        class_index: index of the output unit to explain.
        last_conv_layer_name: layer whose feature maps are weighted. The
            default matches ResNet50-style naming (presumably the backbone
            used here - confirm); if the model has no such layer, the deepest
            conv-named layer with a 4-D output is used instead.

    Returns:
        2-D numpy array with the target layer's spatial size, scaled into
        [0, 1], or None when no suitable layer or gradient is available.
    """
    try:
        target_layer = model.get_layer(last_conv_layer_name)
    except ValueError:
        # Grad-CAM is meant to use the *last* convolutional feature map, so
        # search from the output side of the model.
        target_layer = _find_last_conv_layer(model)
        if target_layer is None:
            return None

    grad_model = tf.keras.models.Model(
        model.inputs, [target_layer.output, model.output]
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        loss = predictions[:, class_index]

    grads = tape.gradient(loss, conv_outputs)
    if grads is None:
        # The chosen layer does not feed the prediction; nothing to explain.
        return None

    # One importance weight per channel: mean gradient over the batch and
    # spatial axes of the (batch, H, W, C) gradient tensor -> shape (C,).
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # Channel-weighted sum of the single image's feature maps -> shape (H, W).
    heatmap = tf.reduce_sum(conv_outputs[0] * pooled_grads, axis=-1)

    # Keep only positive evidence and normalize; epsilon avoids 0/0 on an
    # all-zero map. Stay in TF ops so `.numpy()` is valid.
    heatmap = tf.maximum(heatmap, 0)
    heatmap = heatmap / (tf.reduce_max(heatmap) + K.epsilon())
    return heatmap.numpy()


# --- Enhanced Overlay heatmap on image ---
def overlay_gradcam(original_img, heatmap, alpha=0.4):
    """Blend a colorized heatmap onto an image.

    Args:
        original_img: PIL image to draw on.
        heatmap: 2-D array of scores (e.g. from ``get_gradcam_heatmap``), or None.
        alpha: weight of the heatmap in the blend; the image gets ``1 - alpha``.

    Returns:
        New RGB PIL image, or ``original_img`` unchanged when heatmap is None.
    """
    if heatmap is None:
        return original_img

    # cv2.resize takes dsize as (width, height), the same order as PIL's size.
    heatmap = cv2.resize(np.asarray(heatmap, dtype=np.float32), original_img.size)

    # Normalize safely (an all-zero map stays zero).
    heatmap = np.maximum(heatmap, 0)
    peak = np.max(heatmap)
    if peak > 0:
        heatmap = heatmap / peak
    heatmap = np.uint8(255 * heatmap)

    # applyColorMap produces BGR; convert before blending with the RGB image,
    # otherwise red and blue are swapped in the overlay.
    heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)

    original_array = np.array(original_img.convert("RGB"))
    superimposed_img = cv2.addWeighted(original_array, 1 - alpha, heatmap_color, alpha, 0)
    return Image.fromarray(superimposed_img)


# --- Enhanced Prediction Function ---
def classify_and_explain(img):
    """Classify an image and build its Grad-CAM overlay.

    Args:
        img: PIL image, or None.

    Returns:
        ``(gradcam_img, confidence_dict)``: ``gradcam_img`` is the IMG_SIZE
        image with the heatmap overlaid (or the plain resized image if
        Grad-CAM fails), and ``confidence_dict`` maps each label to its model
        score. ``(None, {})`` when ``img`` is None.
    """
    if img is None:
        return None, {}

    img_array = preprocess_input(img)
    predictions = model.predict(img_array, verbose=0)[0]
    pred_idx = int(np.argmax(predictions))
    # zip stops at the shorter sequence, so a label/output count mismatch
    # cannot raise IndexError here.
    confidence_dict = {
        label: float(score) for label, score in zip(class_labels, predictions)
    }

    display_img = img.resize(IMG_SIZE)
    try:
        heatmap = get_gradcam_heatmap(img_array, model, pred_idx)
        gradcam_img = overlay_gradcam(display_img, heatmap)
    except Exception as e:
        # Grad-CAM is a best-effort explanation: report the problem but still
        # return the prediction.
        print(f"Grad-CAM error: {e}")
        gradcam_img = display_img  # fallback image
    return gradcam_img, confidence_dict


# --- Custom CSS for Dark Mode Medical Interface ---
css = """
.gradio-container {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    background: #1a1a1a;
    min-height: 100vh;
    padding: 20px;
    color: #ffffff;
}
.main-header {
    text-align: center;
    color: white;
    margin-bottom: 2rem;
    padding: 2rem 0;
}
.main-header h1 {
    font-size: 2.5rem;
    margin-bottom: 0.5rem;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.5);
    color: #ffffff;
}
.confidence-bar {
    background: linear-gradient(90deg, #3498db 0%, #2ecc71 100%);
    height: 25px;
    border-radius: 12px;
    margin: 8px 0;
    transition: all 0.3s ease;
    box-shadow: 0 2px 4px rgba(0,0,0,0.3);
}
.confidence-container {
    margin: 15px 0;
    padding: 20px;
    border-radius: 12px;
    background: rgba(255,255,255,0.1);
    backdrop-filter: blur(10px);
    box-shadow: 0 8px 32px rgba(0,0,0,0.3);
    border: 1px solid rgba(255,255,255,0.1);
}
.input-section, .output-section {
    background: rgba(255,255,255,0.05);
    padding: 25px;
    border-radius: 15px;
    margin: 15px;
    backdrop-filter: blur(10px);
    box-shadow: 0 8px 32px rgba(0,0,0,0.3);
    border: 1px solid rgba(255,255,255,0.1);
}
.section-title {
    color: #ffffff;
    font-size: 1.3rem;
    font-weight: 600;
    margin-bottom: 15px;
    border-bottom: 2px solid #3498db;
    padding-bottom: 8px;
}
.gradio-button {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    border: none;
    color: white;
    padding: 12px 24px;
    border-radius: 25px;
    font-weight: 600;
    transition: all 0.3s ease;
    box-shadow: 0 4px 15px rgba(0,0,0,0.3);
}
.gradio-button:hover {
    transform: translateY(-2px);
    box-shadow: 0 6px 20px rgba(0,0,0,0.4);
}
.gradio-image {
    border-radius: 12px;
    box-shadow: 0 4px 15px rgba(0,0,0,0.3);
    border: 1px solid rgba(255,255,255,0.1);
}
.gradio-textbox, .gradio-number {
    border-radius: 8px;
    border: 2px solid #333333;
    padding: 12px;
    font-size: 1rem;
    background: rgba(255,255,255,0.05);
    color: #ffffff;
}
.gradio-textbox:focus, .gradio-number:focus {
    border-color: #3498db;
    box-shadow: 0 0 0 0.2rem rgba(52,152,219,0.25);
}
.gradio-label {
    color: #ffffff !important;
}
.heatmap-container {
    background: rgba(255,255,255,0.05);
    padding: 15px;
    border-radius: 12px;
    border: 1px solid rgba(255,255,255,0.1);
    margin: 10px 0;
}
.prediction-container {
    background: rgba(52,152,219,0.1);
    padding: 20px;
    border-radius: 12px;
    border-left: 5px solid #3498db;
    margin: 15px 0;
}
"""


# --- Function to create confidence bars HTML ---
def create_confidence_bars(confidence_dict):
    """Render one colored progress bar per class.

    Args:
        confidence_dict: mapping of label -> score, presumably a probability
            in [0, 1] (it is multiplied by 100 for display).

    Returns:
        HTML string. Bars are green above 70%, yellow above 40%, red otherwise.
    """
    rows = []
    for class_name, confidence in confidence_dict.items():
        percentage = confidence * 100
        # Color coding based on confidence
        if percentage > 70:
            color = "#28a745"  # Green for high confidence
        elif percentage > 40:
            color = "#ffc107"  # Yellow for medium confidence
        else:
            color = "#dc3545"  # Red for low confidence
        bar_width = min(max(percentage, 0.0), 100.0)
        # Labels come from labels.txt; escape them so they render as text.
        rows.append(f"""
        <div style="margin: 10px 0;">
            <div style="display: flex; justify-content: space-between; color: #ffffff;">
                <span>{html.escape(str(class_name))}</span>
                <span>{percentage:.1f}%</span>
            </div>
            <div class="confidence-bar" style="width: {bar_width:.1f}%; background: {color};"></div>
        </div>
        """)
    return '<div class="confidence-container">' + "".join(rows) + "</div>"


# --- Enhanced Prediction Function with Dark Mode Interface ---
def enhanced_classify_and_explain(img):
    """Gradio callback: return (overlay image, top label, top score, bars HTML)."""
    if img is None:
        return None, "No image provided", 0, ""

    gradcam_img, confidence_dict = classify_and_explain(img)

    # Get predicted class and confidence
    pred_class = max(confidence_dict, key=confidence_dict.get)
    confidence = confidence_dict[pred_class]

    # Create confidence bars HTML
    confidence_bars_html = create_confidence_bars(confidence_dict)

    return gradcam_img, pred_class, confidence, confidence_bars_html


# --- Enhanced Gradio Interface ---
with gr.Blocks(css=css, title="Wound Classification") as demo:
    gr.HTML("""
        <div class="main-header">
            <h1>Wound Classification</h1>
        </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML('<div class="section-title">Input Image</div>')
            input_image = gr.Image(
                label="Upload wound image",
                type="pil",
                height=350,
                container=True,
            )

        with gr.Column(scale=1):
            gr.HTML('<div class="section-title">Analysis Results</div>')

            # Prediction results
            prediction_output = gr.Textbox(
                label="Predicted Wound Type",
                interactive=False,
                container=True,
            )
            confidence_output = gr.Number(
                label="Confidence Score",
                interactive=False,
                container=True,
            )

            # Confidence bars for all classes
            confidence_bars = gr.HTML(
                label="Confidence Scores by Class",
                container=True,
            )

    with gr.Row():
        with gr.Column():
            gr.HTML('<div class="section-title">Model Focus Visualization</div>')
            cam_output = gr.Image(
                label="Grad-CAM Heatmap - Shows which areas the model focused on",
                height=350,
                container=True,
            )

    # Event handlers: re-run the analysis whenever the uploaded image changes
    # (including when it is cleared, which yields img=None).
    input_image.change(
        fn=enhanced_classify_and_explain,
        inputs=[input_image],
        outputs=[cam_output, prediction_output, confidence_output, confidence_bars],
    )


# --- Launch the enhanced interface ---
if __name__ == "__main__":
    # NOTE(review): share=True publishes a public gradio.live URL and
    # server_name="0.0.0.0" listens on all interfaces; confirm that is
    # acceptable for the images this app handles.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
    )