import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras import backend as K
import matplotlib.pyplot as plt
from PIL import Image
import io
import cv2
import glob
import matplotlib
import torch
import tempfile
from gradio_imageslider import ImageSlider
import plotly.graph_objects as go
import plotly.express as px
import open3d as o3d
from depth_anything_v2.dpt import DepthAnythingV2
# --- Load models ---
# Wound classification model
try:
    wound_model = load_model("checkpoints/keras_model.h5")
    with open("labels.txt", "r") as f:
        class_labels = [line.strip() for line in f]
except Exception as e:
    print(f"Wound classification model not loaded: {e}")
    wound_model = None
    class_labels = ["No model found"]
# Depth estimation model
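# Pick the best available accelerator: CUDA, then Apple-silicon MPS, then CPU.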
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
model_configs = {
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
encoder = 'vitl'
try:
    depth_model = DepthAnythingV2(**model_configs[encoder])
    state_dict = torch.load(f'checkpoints/depth_anything_v2_{encoder}.pth', map_location="cpu")
    depth_model.load_state_dict(state_dict)
    depth_model = depth_model.to(DEVICE).eval()
except Exception as e:
    print(f"Depth estimation model not loaded: {e}")
    depth_model = None
# --- Wound Classification Functions ---
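# Resize to the 224x224 input used by the classifier, scale pixel values to [0, 1],
# and add a batch dimension so the array can be passed straight to model.predict().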
def preprocess_input(img):
img = img.resize((224, 224))
arr = keras_image.img_to_array(img)
arr = arr / 255.0
return np.expand_dims(arr, axis=0)
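# Grad-CAM: take the gradient of the predicted class score with respect to the feature
# maps of a late convolutional layer, pool it into per-channel weights, and combine the
# weighted feature maps into a coarse heatmap of the regions driving the prediction.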
def get_gradcam_heatmap(img_array, model, class_index, last_conv_layer_name="conv5_block3_out"):
    # Locate the requested convolutional layer; otherwise fall back to the last conv layer.
    try:
        target_layer = model.get_layer(last_conv_layer_name)
    except ValueError:
        target_layer = None
        for layer in reversed(model.layers):
            if 'conv' in layer.name.lower():
                target_layer = layer
                break
        if target_layer is None:
            return None
    grad_model = tf.keras.models.Model(
        [model.inputs], [target_layer.output, model.output]
    )
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        loss = predictions[:, class_index]
    grads = tape.gradient(loss, conv_outputs)
    if grads is None:
        return None
    # Average gradients over batch and spatial dimensions to get one weight per channel,
    # then sum the weighted feature maps and keep only the positive contributions.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    conv_outputs = conv_outputs[0]
    heatmap = tf.reduce_sum(tf.multiply(pooled_grads, conv_outputs), axis=-1)
    heatmap = tf.maximum(heatmap, 0).numpy()
    heatmap = heatmap / (np.max(heatmap) + K.epsilon())
    return heatmap
def overlay_gradcam(original_img, heatmap):
    if heatmap is None:
        return original_img
    # Resize the heatmap to the image size (PIL's size is (width, height), as cv2.resize expects).
    heatmap = cv2.resize(heatmap, original_img.size)
    heatmap = np.maximum(heatmap, 0)
    if np.max(heatmap) != 0:
        heatmap /= np.max(heatmap)
    heatmap = np.uint8(255 * heatmap)
    # applyColorMap returns BGR; convert to RGB before blending with the RGB original.
    heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
    original_array = np.array(original_img.convert("RGB"))
    superimposed_img = cv2.addWeighted(original_array, 0.6, heatmap_color, 0.4, 0)
    return Image.fromarray(superimposed_img)
def classify_and_explain(img):
    if img is None or wound_model is None:
        # Return an error string in place of the confidence dict so callers can detect failure.
        return None, "No image provided or model not available"
img_array = preprocess_input(img)
predictions = wound_model.predict(img_array, verbose=0)[0]
pred_idx = int(np.argmax(predictions))
pred_class = class_labels[pred_idx]
confidence_dict = {class_labels[i]: float(predictions[i]) for i in range(len(class_labels))}
try:
heatmap = get_gradcam_heatmap(img_array, wound_model, pred_idx)
gradcam_img = overlay_gradcam(img.resize((224, 224)), heatmap)
except Exception as e:
print(f"Grad-CAM error: {e}")
gradcam_img = img.resize((224, 224))
return gradcam_img, confidence_dict
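# Render per-class confidence as simple HTML bars; the confidence-high/medium/low
# classes map onto the colour rules in the custom CSS defined further below.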
def create_confidence_bars(confidence_dict):
    html_content = "<div class='confidence-container'>"
    for class_name, confidence in confidence_dict.items():
        percentage = confidence * 100
        if percentage > 70:
            css_class = "confidence-high"
        elif percentage > 40:
            css_class = "confidence-medium"
        else:
            css_class = "confidence-low"
        html_content += f"""
        <div>
            <span>{class_name}</span> <span>{percentage:.1f}%</span>
            <div class="confidence-bar {css_class}" style="width: {percentage:.1f}%;"></div>
        </div>
        """
    html_content += "</div>"
    return html_content
def enhanced_classify_and_explain(img):
if img is None:
return None, "No image provided", 0, ""
gradcam_img, confidence_dict = classify_and_explain(img)
if isinstance(confidence_dict, str): # Error case
return None, confidence_dict, 0, ""
pred_class = max(confidence_dict, key=confidence_dict.get)
confidence = confidence_dict[pred_class]
confidence_bars_html = create_confidence_bars(confidence_dict)
return gradcam_img, pred_class, confidence, confidence_bars_html
# --- Depth Estimation Functions ---
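# Depth Anything V2's infer_image() takes a BGR image and returns a relative
# (unitless) depth map at the input resolution.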
def predict_depth(image):
if depth_model is None:
return None
return depth_model.infer_image(image)
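# Bound the point-cloud size by the image resolution (up to 3 points per pixel),
# clamped to the 1,000 - 1,000,000 range used by the slider.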
def calculate_max_points(image):
if image is None:
return 10000
h, w = image.shape[:2]
max_points = h * w * 3
return max(1000, min(max_points, 1000000))
def update_slider_on_image_upload(image):
max_points = calculate_max_points(image)
default_value = min(10000, max_points // 10)
return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000,
label=f"Number of 3D points (max: {max_points:,})")
def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=100000):
h, w = depth_map.shape
step = max(1, int(np.sqrt(h * w / max_points)))
y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
x_cam = (x_coords - w / 2) / focal_length_x
y_cam = (y_coords - h / 2) / focal_length_y
depth_values = depth_map[::step, ::step]
x_3d = x_cam * depth_values
y_3d = y_cam * depth_values
z_3d = depth_values
points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1)
image_colors = image[::step, ::step, :]
colors = image_colors.reshape(-1, 3) / 255.0
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(points)
pcd.colors = o3d.utility.Vector3dVector(colors)
return pcd
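# Same pinhole projection as create_point_cloud (with a fixed 470.4 px focal length),
# rendered as an interactive Plotly Scatter3d coloured by the original RGB image.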
def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
h, w = depth_map.shape
step = max(1, int(np.sqrt(h * w / max_points)))
y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
focal_length = 470.4
x_cam = (x_coords - w / 2) / focal_length
y_cam = (y_coords - h / 2) / focal_length
depth_values = depth_map[::step, ::step]
x_3d = x_cam * depth_values
y_3d = y_cam * depth_values
z_3d = depth_values
x_flat = x_3d.flatten()
y_flat = y_3d.flatten()
z_flat = z_3d.flatten()
image_colors = image[::step, ::step, :]
colors_flat = image_colors.reshape(-1, 3)
fig = go.Figure(data=[go.Scatter3d(
x=x_flat,
y=y_flat,
z=z_flat,
mode='markers',
marker=dict(
size=1.5,
color=colors_flat,
opacity=0.9
),
        hovertemplate='3D Position: (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
                      'Depth: %{z:.2f}<br>' +
                      '<extra></extra>'
    )])
fig.update_layout(
title="3D Point Cloud Visualization (Camera Projection)",
scene=dict(
xaxis_title="X (meters)",
yaxis_title="Y (meters)",
zaxis_title="Z (meters)",
camera=dict(
eye=dict(x=2.0, y=2.0, z=2.0),
center=dict(x=0, y=0, z=0),
up=dict(x=0, y=0, z=1)
),
aspectmode='data'
),
width=700,
height=600
)
return fig
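# Depth tab pipeline: run inference (Gradio's RGB input is flipped to BGR for the model),
# save the raw 16-bit depth and an 8-bit grayscale preview to temp files, build an Open3D
# point cloud (.ply) from the normalised depth, and return a Plotly figure for 3D viewing.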
def on_depth_submit(image, num_points, focal_x, focal_y):
if image is None or depth_model is None:
return None, None, None, None, None
original_image = image.copy()
h, w = image.shape[:2]
depth = predict_depth(image[:, :, ::-1])
if depth is None:
return None, None, None, None, None
raw_depth = Image.fromarray(depth.astype('uint16'))
tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
raw_depth.save(tmp_raw_depth.name)
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
depth = depth.astype(np.uint8)
cmap = matplotlib.colormaps.get_cmap('Spectral_r')
colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
gray_depth = Image.fromarray(depth)
tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
gray_depth.save(tmp_gray_depth.name)
pcd = create_point_cloud(original_image, depth, focal_x, focal_y, max_points=num_points)
tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
o3d.io.write_point_cloud(tmp_pointcloud.name, pcd)
depth_3d = create_enhanced_3d_visualization(original_image, depth, max_points=num_points)
return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]
# --- Custom CSS for Unified Interface ---
css = """
/* Minimal dark theme styling */
.main-header {
text-align: center;
margin-bottom: 2rem;
padding: 2rem 0;
}
.main-header h1 {
font-size: 2.5rem;
margin-bottom: 0.5rem;
font-weight: 600;
}
.main-header p {
font-size: 1.1rem;
opacity: 0.8;
}
.section-title {
font-size: 1.2rem;
font-weight: 600;
margin-bottom: 15px;
padding-bottom: 8px;
border-bottom: 1px solid var(--border-color-primary);
}
.confidence-container {
margin: 15px 0;
padding: 15px;
border-radius: 8px;
background: var(--background-secondary);
border: 1px solid var(--border-color-primary);
}
.confidence-bar {
height: 20px;
border-radius: 4px;
margin: 6px 0;
background: var(--primary-500);
transition: width 0.3s ease;
}
/* Simple confidence bar colors */
.confidence-high {
background: var(--success-500);
}
.confidence-medium {
background: var(--warning-500);
}
.confidence-low {
background: var(--error-500);
}
/* Minimal spacing and layout */
.gradio-container {
max-width: 100%;
margin: 0;
padding: 20px;
width: 100%;
}
/* Clean image styling */
.gradio-image {
border-radius: 8px;
border: 1px solid var(--border-color-primary);
}
/* Simple button styling */
.gradio-button {
border-radius: 6px;
font-weight: 500;
}
/* Clean form elements */
.gradio-textbox, .gradio-number, .gradio-slider {
border-radius: 6px;
border: 1px solid var(--border-color-primary);
}
/* Tab styling */
.gradio-tabs {
border-radius: 8px;
overflow: hidden;
}
/* File upload styling */
.gradio-file {
border-radius: 6px;
border: 1px solid var(--border-color-primary);
}
/* Plot styling */
.gradio-plot {
border-radius: 8px;
border: 1px solid var(--border-color-primary);
}
/* Full width and height layout */
body, html {
margin: 0;
padding: 0;
width: 100%;
height: 100%;
}
#root {
width: 100%;
height: 100%;
}
/* Ensure Gradio uses full width */
.gradio-container {
min-height: 100vh;
}
/* Responsive adjustments */
@media (max-width: 768px) {
.main-header h1 {
font-size: 2rem;
}
.gradio-container {
padding: 10px;
}
}
"""
# --- Create Unified Interface ---
with gr.Blocks(css=css, title="Medical AI Suite") as demo:
gr.HTML("""
Medical AI Suite
Advanced AI-powered medical image analysis and 3D visualization
""")
with gr.Tabs() as tabs:
# Tab 1: Wound Classification
with gr.TabItem("Wound Classification", id=0):
gr.HTML("Wound Classification with Grad-CAM Visualization
")
with gr.Row():
with gr.Column(scale=1):
gr.HTML("Input Image
")
wound_input_image = gr.Image(
label="Upload wound image",
type="pil",
height=350,
container=True
)
with gr.Column(scale=1):
gr.HTML("Analysis Results
")
wound_prediction_output = gr.Textbox(
label="Predicted Wound Type",
interactive=False,
container=True
)
wound_confidence_output = gr.Number(
label="Confidence Score",
interactive=False,
container=True
)
wound_confidence_bars = gr.HTML(
label="Confidence Scores by Class",
container=True
)
with gr.Row():
with gr.Column():
gr.HTML("Model Focus Visualization
")
wound_cam_output = gr.Image(
label="Grad-CAM Heatmap - Shows which areas the model focused on",
height=350,
container=True
)
# Event handlers for wound classification
wound_input_image.change(
fn=enhanced_classify_and_explain,
inputs=[wound_input_image],
outputs=[wound_cam_output, wound_prediction_output, wound_confidence_output, wound_confidence_bars]
)
# Tab 2: Depth Estimation
with gr.TabItem("Depth Estimation & 3D Visualization", id=1):
gr.HTML("Depth Estimation and 3D Point Cloud Generation
")
with gr.Row():
depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')
with gr.Row():
depth_submit = gr.Button(value="Compute Depth", variant="primary")
depth_points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
label="Number of 3D points (upload image to update max)")
with gr.Row():
depth_focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
label="Focal Length X (pixels)")
depth_focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
label="Focal Length Y (pixels)")
with gr.Row():
depth_gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
depth_raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
depth_point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")
gr.Markdown("### 3D Point Cloud Visualization")
gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
depth_3d_plot = gr.Plot(label="3D Point Cloud")
# Event handlers for depth estimation
depth_input_image.change(
fn=update_slider_on_image_upload,
inputs=[depth_input_image],
outputs=[depth_points_slider]
)
depth_submit.click(
on_depth_submit,
inputs=[depth_input_image, depth_points_slider, depth_focal_length_x, depth_focal_length_y],
outputs=[depth_image_slider, depth_gray_depth_file, depth_raw_file, depth_point_cloud_file, depth_3d_plot]
)
    # Cross-tab image sharing functionality.
    # Use .upload (fires only on a user upload) rather than .change so the two handlers
    # cannot retrigger each other in an endless update loop.
    # When an image is uploaded in wound classification, also update depth estimation
    wound_input_image.upload(
        fn=lambda img: img,
        inputs=[wound_input_image],
        outputs=[depth_input_image]
    )
    # When an image is uploaded in depth estimation, also update wound classification
    depth_input_image.upload(
        fn=lambda img: img,
        inputs=[depth_input_image],
        outputs=[wound_input_image]
    )
# --- Launch the unified interface ---
if __name__ == "__main__":
demo.queue().launch(
server_name="0.0.0.0",
server_port=7860,
share=True,
show_error=True
)