Spaces:
Running
Running
import gradio as gr | |
import numpy as np | |
import tensorflow as tf | |
from tensorflow.keras.models import load_model | |
from tensorflow.keras.preprocessing import image as keras_image | |
from tensorflow.keras import backend as K | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
import io | |
import cv2 | |
# --- Load model and labels ---
model = load_model("checkpoints/keras_model.h5")
# One class name per line. Blank lines are skipped: an empty entry would become a
# phantom class with no matching model output and break the per-class lookups.
with open("labels.txt", "r", encoding="utf-8") as f:
    class_labels = [line.strip() for line in f if line.strip()]
# --- Preprocess input ---
def preprocess_input(img, target_size=(224, 224)):
    """Turn a PIL image into a normalised single-image batch for the model.

    Args:
        img: PIL image in any mode. It is converted to RGB first so RGBA PNGs
            or greyscale uploads still yield a 3-channel array (the model is
            assumed to take 3-channel input -- confirm against the checkpoint).
        target_size: (width, height) handed to ``Image.resize``; the default
            matches the 224x224 size used elsewhere in this file.

    Returns:
        float32 array of shape (1, height, width, 3) scaled by 1/255.
    """
    rgb = img.convert("RGB").resize(target_size)
    # Same result as keras img_to_array for an RGB PIL image (float32, HWC).
    pixels = np.asarray(rgb, dtype=np.float32) / 255.0
    return np.expand_dims(pixels, axis=0)
# --- Enhanced Grad-CAM implementation for Keras ---
def _find_last_conv_layer(model):
    """Return the deepest top-level layer whose name contains 'conv', or None."""
    # Grad-CAM explains the last convolutional feature map (closest to the
    # classifier); scanning from the end avoids picking the stem convolution.
    for layer in reversed(model.layers):
        if "conv" in layer.name.lower():
            return layer
    return None


def _gradcam_from_gradients(conv_outputs, grads):
    """Combine one image's feature map and its gradients into a heatmap.

    Args:
        conv_outputs: activations of shape (H, W, C).
        grads: gradient of the class score w.r.t. those activations, same shape.

    Returns:
        float32 array (H, W): ReLU of the gradient-weighted channel sum, scaled
        so its maximum is 1 (left all zeros if nothing contributes positively).
    """
    conv_outputs = np.asarray(conv_outputs, dtype=np.float32)
    grads = np.asarray(grads, dtype=np.float32)
    # One weight per channel: average the gradient over the spatial axes only.
    channel_weights = grads.mean(axis=(0, 1))
    heatmap = np.maximum(np.sum(conv_outputs * channel_weights, axis=-1), 0.0)
    peak = heatmap.max()
    if peak > 0:
        heatmap /= peak
    return heatmap


def get_gradcam_heatmap(img_array, model, class_index, last_conv_layer_name="conv5_block3_out"):
    """Compute a Grad-CAM heatmap for ``class_index`` on a single-image batch.

    Args:
        img_array: input batch (e.g. from ``preprocess_input``); only the first
            image's heatmap is returned.
        model: Keras model; ``model.output[:, class_index]`` is the class score.
        class_index: index of the class to explain.
        last_conv_layer_name: layer to explain. If ``model.get_layer`` cannot
            find it, the last top-level layer whose name contains "conv" is used.

    Returns:
        float32 array (H, W) with values in [0, 1], or None when no suitable
        layer exists or its output is not connected to the model output.
        Assumes the chosen layer outputs (batch, H, W, C) -- TODO confirm for
        non-default layer names.
    """
    try:
        target_layer = model.get_layer(last_conv_layer_name)
    except ValueError:  # Keras raises ValueError for an unknown layer name
        target_layer = _find_last_conv_layer(model)
    if target_layer is None:
        return None

    grad_model = tf.keras.models.Model(
        model.inputs, [target_layer.output, model.output]
    )
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array, training=False)
        loss = predictions[:, class_index]
    grads = tape.gradient(loss, conv_outputs)
    if grads is None:
        # The layer does not feed model.output, so there is nothing to explain.
        return None
    # Drop the batch axis only now, so channel weights are averaged over (H, W)
    # and not collapsed into a single scalar.
    return _gradcam_from_gradients(conv_outputs[0].numpy(), grads[0].numpy())
# --- Enhanced Overlay heatmap on image ---
def overlay_gradcam(original_img, heatmap, alpha=0.4):
    """Blend a Grad-CAM heatmap over ``original_img`` with the JET colormap.

    Args:
        original_img: PIL image; converted to RGB for blending.
        heatmap: 2-D array of scores (any resolution), or None.
        alpha: weight of the coloured heatmap in the blend; the image gets
            ``1 - alpha``. The default reproduces the original 0.6/0.4 mix.

    Returns:
        A PIL RGB image at ``original_img``'s size, or ``original_img``
        unchanged when ``heatmap`` is None.
    """
    if heatmap is None:
        return original_img

    # cv2 takes dsize as (width, height), which is exactly PIL's Image.size.
    heatmap = cv2.resize(np.asarray(heatmap, dtype=np.float32), original_img.size)
    heatmap = np.maximum(heatmap, 0)
    peak = heatmap.max()
    if peak > 0:
        heatmap = heatmap / peak
    heatmap = np.uint8(255 * heatmap)

    # applyColorMap returns BGR; convert it so it matches the RGB pixels from
    # PIL, otherwise high-activation regions render blue instead of red.
    heatmap_color = cv2.cvtColor(
        cv2.applyColorMap(heatmap, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB
    )
    original_array = np.array(original_img.convert("RGB"))
    blended = cv2.addWeighted(original_array, 1.0 - alpha, heatmap_color, alpha, 0)
    return Image.fromarray(blended)
# --- Enhanced Prediction Function ---
def classify_and_explain(img):
    """Classify ``img`` and build a Grad-CAM overlay for the top-scoring class.

    Args:
        img: PIL image, or None.

    Returns:
        (gradcam_img, confidence_dict): ``gradcam_img`` is ``img`` resized to
        224x224, with the heatmap blended in when Grad-CAM succeeds;
        ``confidence_dict`` maps each label to its model score. For ``None``
        input the result is (None, {}), the same 2-tuple shape as normal.
    """
    if img is None:
        return None, {}

    img_array = preprocess_input(img)
    predictions = model.predict(img_array, verbose=0)[0]
    pred_idx = int(np.argmax(predictions))
    # zip stops at the shorter sequence, so a labels/outputs count mismatch
    # cannot raise IndexError here.
    confidence_dict = {
        label: float(score) for label, score in zip(class_labels, predictions)
    }

    display_img = img.resize((224, 224))
    try:
        heatmap = get_gradcam_heatmap(img_array, model, pred_idx)
        gradcam_img = overlay_gradcam(display_img, heatmap)
    except Exception as e:  # best effort: Grad-CAM must not break classification
        print(f"Grad-CAM error: {e}")
        gradcam_img = display_img  # fallback image
    return gradcam_img, confidence_dict
# --- Custom CSS for Dark Mode Medical Interface ---
# Stylesheet passed to gr.Blocks(css=...). The classes `.main-header`,
# `.section-title`, `.confidence-container` and `.confidence-bar` are emitted by
# the HTML built in this file. The `background` of `.confidence-bar` is
# overridden by the inline colour set in create_confidence_bars.
# `.input-section`, `.output-section`, `.heatmap-container` and
# `.prediction-container` are not referenced by any HTML here -- verify before removing.
# NOTE(review): the `.gradio-*` selectors assume Gradio renders elements with
# those class names; confirm against the installed Gradio version, otherwise
# those rules have no effect.
css = """
.gradio-container {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: #1a1a1a;
min-height: 100vh;
padding: 20px;
color: #ffffff;
}
.main-header {
text-align: center;
color: white;
margin-bottom: 2rem;
padding: 2rem 0;
}
.main-header h1 {
font-size: 2.5rem;
margin-bottom: 0.5rem;
text-shadow: 2px 2px 4px rgba(0,0,0,0.5);
color: #ffffff;
}
.confidence-bar {
background: linear-gradient(90deg, #3498db 0%, #2ecc71 100%);
height: 25px;
border-radius: 12px;
margin: 8px 0;
transition: all 0.3s ease;
box-shadow: 0 2px 4px rgba(0,0,0,0.3);
}
.confidence-container {
margin: 15px 0;
padding: 20px;
border-radius: 12px;
background: rgba(255,255,255,0.1);
backdrop-filter: blur(10px);
box-shadow: 0 8px 32px rgba(0,0,0,0.3);
border: 1px solid rgba(255,255,255,0.1);
}
.input-section, .output-section {
background: rgba(255,255,255,0.05);
padding: 25px;
border-radius: 15px;
margin: 15px;
backdrop-filter: blur(10px);
box-shadow: 0 8px 32px rgba(0,0,0,0.3);
border: 1px solid rgba(255,255,255,0.1);
}
.section-title {
color: #ffffff;
font-size: 1.3rem;
font-weight: 600;
margin-bottom: 15px;
border-bottom: 2px solid #3498db;
padding-bottom: 8px;
}
.gradio-button {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border: none;
color: white;
padding: 12px 24px;
border-radius: 25px;
font-weight: 600;
transition: all 0.3s ease;
box-shadow: 0 4px 15px rgba(0,0,0,0.3);
}
.gradio-button:hover {
transform: translateY(-2px);
box-shadow: 0 6px 20px rgba(0,0,0,0.4);
}
.gradio-image {
border-radius: 12px;
box-shadow: 0 4px 15px rgba(0,0,0,0.3);
border: 1px solid rgba(255,255,255,0.1);
}
.gradio-textbox, .gradio-number {
border-radius: 8px;
border: 2px solid #333333;
padding: 12px;
font-size: 1rem;
background: rgba(255,255,255,0.05);
color: #ffffff;
}
.gradio-textbox:focus, .gradio-number:focus {
border-color: #3498db;
box-shadow: 0 0 0 0.2rem rgba(52,152,219,0.25);
}
.gradio-label {
color: #ffffff !important;
}
.heatmap-container {
background: rgba(255,255,255,0.05);
padding: 15px;
border-radius: 12px;
border: 1px solid rgba(255,255,255,0.1);
margin: 10px 0;
}
.prediction-container {
background: rgba(52,152,219,0.1);
padding: 20px;
border-radius: 12px;
border-left: 5px solid #3498db;
margin: 15px 0;
}
"""
# --- Function to create confidence bars HTML ---
def create_confidence_bars(confidence_dict):
    """Render per-class confidences as coloured horizontal bars.

    Args:
        confidence_dict: mapping of class label -> score, where 1.0 is 100%.

    Returns:
        HTML string wrapped in ``<div class='confidence-container'>``, one row
        per entry in iteration order. Scores above 70% are green, above 40%
        yellow, otherwise red.
    """
    import html  # only this helper needs HTML escaping

    rows = []
    for class_name, confidence in confidence_dict.items():
        percentage = confidence * 100
        # Color coding based on confidence
        if percentage > 70:
            color = "#28a745"  # Green for high confidence
        elif percentage > 40:
            color = "#ffc107"  # Yellow for medium confidence
        else:
            color = "#dc3545"  # Red for low confidence
        # Labels are arbitrary text (labels.txt in this app); escape them so
        # '<' or '&' cannot break or inject markup.
        label = html.escape(str(class_name))
        # Keep the bar inside its container even for scores outside [0, 1].
        bar_width = min(max(percentage, 0.0), 100.0)
        rows.append(f"""
<div style='margin: 12px 0;'>
<div style='display: flex; justify-content: space-between; margin-bottom: 8px;'>
<span style='font-weight: bold; color: {color};'>{label}</span>
<span style='font-weight: bold; color: {color};'>{percentage:.1f}%</span>
</div>
<div class='confidence-bar' style='width: {bar_width:.2f}%; background: {color};'></div>
</div>
""")
    return "<div class='confidence-container'>" + "".join(rows) + "</div>"
# --- Enhanced Prediction Function with Dark Mode Interface ---
def enhanced_classify_and_explain(img):
    """Gradio handler: classify ``img`` and format the four UI outputs.

    Returns:
        (overlay image, top label, top score, confidence-bars HTML) -- the same
        order as the ``outputs=`` list wired to ``input_image.change``.
        For ``img is None``: (None, "No image provided", 0, "").
    """
    if img is None:
        return None, "No image provided", 0, ""

    overlay, scores = classify_and_explain(img)
    # First entry with the highest score wins ties, as max() keeps the first.
    top_label, top_score = max(scores.items(), key=lambda item: item[1])
    bars_html = create_confidence_bars(scores)
    return overlay, top_label, top_score, bars_html
# --- Enhanced Gradio Interface ---
# Layout, top to bottom: a header; a row with the input image (left column) and
# the prediction textbox, confidence number and per-class bars (right column);
# then a full-width row holding the Grad-CAM overlay.
with gr.Blocks(css=css, title="Wound Classification") as demo:
    gr.HTML("""
<div class="main-header">
<h1>Wound Classification</h1>
</div>
""")
    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<div class='section-title'>Input Image</div>")
            input_image = gr.Image(
                label="Upload wound image",
                type="pil",  # the handler receives a PIL image; it also guards against None
                height=350,
                container=True
            )
        with gr.Column(scale=1):
            gr.HTML("<div class='section-title'>Analysis Results</div>")
            # Prediction results
            prediction_output = gr.Textbox(
                label="Predicted Wound Type",
                interactive=False,
                container=True
            )
            confidence_output = gr.Number(
                label="Confidence Score",
                interactive=False,
                container=True
            )
            # Confidence bars for all classes
            confidence_bars = gr.HTML(
                label="Confidence Scores by Class",
                container=True
            )
    with gr.Row():
        with gr.Column():
            gr.HTML("<div class='section-title'>Model Focus Visualization</div>")
            cam_output = gr.Image(
                label="Grad-CAM Heatmap - Shows which areas the model focused on",
                height=350,
                container=True
            )
    # Event handlers
    # Run the pipeline on input_image's change event. The outputs list must stay
    # in the same order as the 4-tuple returned by enhanced_classify_and_explain.
    input_image.change(
        fn=enhanced_classify_and_explain,
        inputs=[input_image],
        outputs=[cam_output, prediction_output, confidence_output, confidence_bars]
    )
# --- Launch the enhanced interface ---
if __name__ == "__main__":
    # Bind to all interfaces on port 7860.
    # NOTE(review): share=True presumably requests a public share link when run
    # outside Hugging Face Spaces -- confirm that exposing uploaded wound images
    # this way is intended.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
    )
    demo.launch(**launch_options)