# gradui-1 / app2.py
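"""Gradio demo: wound-image classification with Grad-CAM visual explanations.

Loads a Keras classifier and its label list, predicts the wound type for an
uploaded image, and overlays a Grad-CAM heatmap showing which regions of the
image drove the prediction.
"""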
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras import backend as K
from PIL import Image
import cv2
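# Expected artifacts next to this script (paths taken from the code below):
# - checkpoints/keras_model.h5: the trained Keras classifier
# - labels.txt: one class name per line, in the model's output order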
# --- Load model and labels ---
model = load_model("checkpoints/keras_model.h5")
with open("labels.txt", "r") as f:
class_labels = [line.strip() for line in f]
# --- Preprocess input ---
def preprocess_input(img):
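    """Resize to the model's 224x224 input, scale pixels to [0, 1], and add a batch axis."""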
img = img.resize((224, 224))
arr = keras_image.img_to_array(img)
arr = arr / 255.0
return np.expand_dims(arr, axis=0)
# --- Enhanced Grad-CAM implementation for Keras ---
def get_gradcam_heatmap(img_array, model, class_index, last_conv_layer_name="conv5_block3_out"):
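    """Compute a Grad-CAM heatmap for class_index.

    Uses last_conv_layer_name if the model has it (the default matches the
    final conv block of a ResNet50-style backbone); otherwise falls back to
    the model's last convolutional layer. Returns a heatmap normalized to
    [0, 1], or None if no convolutional layer can be found.
    """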
    try:
        # Try to find the requested layer by name
        target_layer = model.get_layer(last_conv_layer_name)
    except ValueError:
        # Fallback: Grad-CAM wants the deepest conv features, so search
        # for a convolutional layer from the end of the model
        for layer in reversed(model.layers):
            if 'conv' in layer.name.lower():
                target_layer = layer
                break
        else:
            return None
grad_model = tf.keras.models.Model(
[model.inputs], [target_layer.output, model.output]
)
with tf.GradientTape() as tape:
conv_outputs, predictions = grad_model(img_array)
loss = predictions[:, class_index]
    # Gradient of the class score w.r.t. the conv feature maps
    grads = tape.gradient(loss, conv_outputs)
    # Average over the batch and spatial axes to get one weight per channel
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    # Weight each feature-map channel by its pooled gradient and sum
    conv_outputs = conv_outputs[0]
    heatmap = tf.reduce_sum(tf.multiply(pooled_grads, conv_outputs), axis=-1)
    # Keep positive contributions only, then normalize to [0, 1]
    heatmap = np.maximum(heatmap, 0)
    heatmap = heatmap / (np.max(heatmap) + K.epsilon())
    return heatmap
# --- Enhanced Overlay heatmap on image ---
def overlay_gradcam(original_img, heatmap):
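    """Resize the heatmap to the image, apply the JET colormap, and blend them 60/40."""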
if heatmap is None:
return original_img
# Resize heatmap
heatmap = cv2.resize(heatmap, original_img.size)
# Normalize safely
heatmap = np.maximum(heatmap, 0)
if np.max(heatmap) != 0:
heatmap /= np.max(heatmap)
heatmap = np.uint8(255 * heatmap)
    # Apply the JET colormap (OpenCV returns BGR, so convert to RGB for PIL)
    heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
# Convert PIL to array
original_array = np.array(original_img.convert("RGB"))
# Enhanced blend with better contrast
superimposed_img = cv2.addWeighted(original_array, 0.6, heatmap_color, 0.4, 0)
return Image.fromarray(superimposed_img)
# --- Enhanced Prediction Function ---
def classify_and_explain(img):
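    """Run the classifier and return (Grad-CAM overlay, per-class confidence dict)."""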
if img is None:
        # Match the arity of the success path: (overlay image, confidence dict)
        return None, {}
img_array = preprocess_input(img)
predictions = model.predict(img_array, verbose=0)[0]
pred_idx = int(np.argmax(predictions))
confidence_dict = {class_labels[i]: float(predictions[i]) for i in range(len(class_labels))}
# Enhanced Grad-CAM
try:
heatmap = get_gradcam_heatmap(img_array, model, pred_idx)
gradcam_img = overlay_gradcam(img.resize((224, 224)), heatmap)
except Exception as e:
print(f"Grad-CAM error: {e}")
gradcam_img = img.resize((224, 224)) # fallback image
return gradcam_img, confidence_dict
# --- Custom CSS for Dark Mode Medical Interface ---
css = """
.gradio-container {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: #1a1a1a;
min-height: 100vh;
padding: 20px;
color: #ffffff;
}
.main-header {
text-align: center;
color: white;
margin-bottom: 2rem;
padding: 2rem 0;
}
.main-header h1 {
font-size: 2.5rem;
margin-bottom: 0.5rem;
text-shadow: 2px 2px 4px rgba(0,0,0,0.5);
color: #ffffff;
}
.confidence-bar {
background: linear-gradient(90deg, #3498db 0%, #2ecc71 100%);
height: 25px;
border-radius: 12px;
margin: 8px 0;
transition: all 0.3s ease;
box-shadow: 0 2px 4px rgba(0,0,0,0.3);
}
.confidence-container {
margin: 15px 0;
padding: 20px;
border-radius: 12px;
background: rgba(255,255,255,0.1);
backdrop-filter: blur(10px);
box-shadow: 0 8px 32px rgba(0,0,0,0.3);
border: 1px solid rgba(255,255,255,0.1);
}
.input-section, .output-section {
background: rgba(255,255,255,0.05);
padding: 25px;
border-radius: 15px;
margin: 15px;
backdrop-filter: blur(10px);
box-shadow: 0 8px 32px rgba(0,0,0,0.3);
border: 1px solid rgba(255,255,255,0.1);
}
.section-title {
color: #ffffff;
font-size: 1.3rem;
font-weight: 600;
margin-bottom: 15px;
border-bottom: 2px solid #3498db;
padding-bottom: 8px;
}
.gradio-button {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border: none;
color: white;
padding: 12px 24px;
border-radius: 25px;
font-weight: 600;
transition: all 0.3s ease;
box-shadow: 0 4px 15px rgba(0,0,0,0.3);
}
.gradio-button:hover {
transform: translateY(-2px);
box-shadow: 0 6px 20px rgba(0,0,0,0.4);
}
.gradio-image {
border-radius: 12px;
box-shadow: 0 4px 15px rgba(0,0,0,0.3);
border: 1px solid rgba(255,255,255,0.1);
}
.gradio-textbox, .gradio-number {
border-radius: 8px;
border: 2px solid #333333;
padding: 12px;
font-size: 1rem;
background: rgba(255,255,255,0.05);
color: #ffffff;
}
.gradio-textbox:focus, .gradio-number:focus {
border-color: #3498db;
box-shadow: 0 0 0 0.2rem rgba(52,152,219,0.25);
}
.gradio-label {
color: #ffffff !important;
}
.heatmap-container {
background: rgba(255,255,255,0.05);
padding: 15px;
border-radius: 12px;
border: 1px solid rgba(255,255,255,0.1);
margin: 10px 0;
}
.prediction-container {
background: rgba(52,152,219,0.1);
padding: 20px;
border-radius: 12px;
border-left: 5px solid #3498db;
margin: 15px 0;
}
"""
# --- Function to create confidence bars HTML ---
def create_confidence_bars(confidence_dict):
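    """Render per-class confidence scores as HTML bars, color-coded by confidence."""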
html_content = "<div class='confidence-container'>"
for class_name, confidence in confidence_dict.items():
percentage = confidence * 100
# Color coding based on confidence
if percentage > 70:
color = "#28a745" # Green for high confidence
elif percentage > 40:
color = "#ffc107" # Yellow for medium confidence
else:
color = "#dc3545" # Red for low confidence
html_content += f"""
<div style='margin: 12px 0;'>
<div style='display: flex; justify-content: space-between; margin-bottom: 8px;'>
<span style='font-weight: bold; color: {color};'>{class_name}</span>
<span style='font-weight: bold; color: {color};'>{percentage:.1f}%</span>
</div>
<div class='confidence-bar' style='width: {percentage}%; background: {color};'></div>
</div>
"""
html_content += "</div>"
return html_content
# --- Enhanced Prediction Function with Dark Mode Interface ---
def enhanced_classify_and_explain(img):
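    """Wrap classify_and_explain for the UI: overlay, top class, confidence, HTML bars."""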
if img is None:
return None, "No image provided", 0, ""
gradcam_img, confidence_dict = classify_and_explain(img)
# Get predicted class and confidence
pred_class = max(confidence_dict, key=confidence_dict.get)
confidence = confidence_dict[pred_class]
# Create confidence bars HTML
confidence_bars_html = create_confidence_bars(confidence_dict)
return gradcam_img, pred_class, confidence, confidence_bars_html
# --- Enhanced Gradio Interface ---
with gr.Blocks(css=css, title="Wound Classification") as demo:
gr.HTML("""
<div class="main-header">
<h1>Wound Classification</h1>
</div>
""")
with gr.Row():
with gr.Column(scale=1):
gr.HTML("<div class='section-title'>Input Image</div>")
input_image = gr.Image(
label="Upload wound image",
type="pil",
height=350,
container=True
)
with gr.Column(scale=1):
gr.HTML("<div class='section-title'>Analysis Results</div>")
# Prediction results
prediction_output = gr.Textbox(
label="Predicted Wound Type",
interactive=False,
container=True
)
confidence_output = gr.Number(
label="Confidence Score",
interactive=False,
container=True
)
# Confidence bars for all classes
confidence_bars = gr.HTML(
label="Confidence Scores by Class",
container=True
)
with gr.Row():
with gr.Column():
gr.HTML("<div class='section-title'>Model Focus Visualization</div>")
cam_output = gr.Image(
label="Grad-CAM Heatmap - Shows which areas the model focused on",
height=350,
container=True
)
# Event handlers
input_image.change(
fn=enhanced_classify_and_explain,
inputs=[input_image],
outputs=[cam_output, prediction_output, confidence_output, confidence_bars]
)
# --- Launch the enhanced interface ---
if __name__ == "__main__":
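    # Note: share=True creates a temporary public link when running locally;
    # hosted platforms such as Hugging Face Spaces typically ignore this flag.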
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True,
show_error=True
)