import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras import backend as K
import matplotlib.pyplot as plt
from PIL import Image
import io
import cv2
import glob
import matplotlib
import torch
import tempfile
from gradio_imageslider import ImageSlider
import plotly.graph_objects as go
import plotly.express as px
import open3d as o3d
from depth_anything_v2.dpt import DepthAnythingV2
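
# Note: the DepthAnythingV2 import assumes the `depth_anything_v2` package from the
# Depth-Anything-V2 repository is available on the Python path (e.g. vendored next to
# this file); it is not installed via requirements like the other imports.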
# --- Load models ---
# Wound classification model
try:
    wound_model = load_model("checkpoints/keras_model.h5")
    with open("labels.txt", "r") as f:
        class_labels = [line.strip() for line in f]
except Exception as e:
    print(f"Could not load wound classification model: {e}")
    wound_model = None
    class_labels = ["No model found"]
# Depth estimation model
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
encoder = 'vitl'

try:
    depth_model = DepthAnythingV2(**model_configs[encoder])
    state_dict = torch.load(f'checkpoints/depth_anything_v2_{encoder}.pth', map_location="cpu")
    depth_model.load_state_dict(state_dict)
    depth_model = depth_model.to(DEVICE).eval()
except Exception as e:
    print(f"Could not load depth estimation model: {e}")
    depth_model = None
# --- Wound Classification Functions ---
def preprocess_input(img):
    img = img.resize((224, 224))
    arr = keras_image.img_to_array(img)
    arr = arr / 255.0
    return np.expand_dims(arr, axis=0)
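
# Grad-CAM, in outline: the gradient of the predicted class score with respect to the
# last convolutional feature map is averaged over the spatial dimensions to give one
# weight per channel; the weighted sum of the channels, passed through ReLU and
# normalized, highlights the image regions that drove the prediction. The default
# layer name below ("conv5_block3_out") matches a ResNet50-style backbone; if the
# classifier uses a different architecture, the fallback search for a layer whose
# name contains "conv" is used instead.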
def get_gradcam_heatmap(img_array, model, class_index, last_conv_layer_name="conv5_block3_out"):
    try:
        target_layer = model.get_layer(last_conv_layer_name)
    except ValueError:
        # Named layer not found: fall back to the first convolutional layer
        for layer in model.layers:
            if 'conv' in layer.name.lower():
                target_layer = layer
                break
        else:
            return None
    grad_model = tf.keras.models.Model(
        [model.inputs], [target_layer.output, model.output]
    )
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        loss = predictions[:, class_index]
    grads = tape.gradient(loss, conv_outputs)
    if grads is None:
        return None
    grads = grads[0]                                    # (H, W, C)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1))   # one weight per channel
    conv_outputs = conv_outputs[0]                      # (H, W, C)
    heatmap = tf.reduce_sum(tf.multiply(pooled_grads, conv_outputs), axis=-1)
    heatmap = heatmap.numpy()
    heatmap = np.maximum(heatmap, 0)
    heatmap = heatmap / (np.max(heatmap) + K.epsilon())
    return heatmap
def overlay_gradcam(original_img, heatmap):
    if heatmap is None:
        return original_img
    heatmap = cv2.resize(heatmap, original_img.size)
    heatmap = np.maximum(heatmap, 0)
    if np.max(heatmap) != 0:
        heatmap /= np.max(heatmap)
    heatmap = np.uint8(255 * heatmap)
    # applyColorMap returns BGR; convert to RGB before blending with the RGB original
    heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
    original_array = np.array(original_img.convert("RGB"))
    superimposed_img = cv2.addWeighted(original_array, 0.6, heatmap_color, 0.4, 0)
    return Image.fromarray(superimposed_img)
def classify_and_explain(img):
    if img is None or wound_model is None:
        return None, "No image provided or model not available"
    img_array = preprocess_input(img)
    predictions = wound_model.predict(img_array, verbose=0)[0]
    pred_idx = int(np.argmax(predictions))
    confidence_dict = {class_labels[i]: float(predictions[i]) for i in range(len(class_labels))}
    try:
        heatmap = get_gradcam_heatmap(img_array, wound_model, pred_idx)
        gradcam_img = overlay_gradcam(img.resize((224, 224)), heatmap)
    except Exception as e:
        print(f"Grad-CAM error: {e}")
        gradcam_img = img.resize((224, 224))
    return gradcam_img, confidence_dict
def create_confidence_bars(confidence_dict):
    html_content = "<div class='confidence-container'>"
    for class_name, confidence in confidence_dict.items():
        percentage = confidence * 100
        if percentage > 70:
            css_class = "confidence-high"
        elif percentage > 40:
            css_class = "confidence-medium"
        else:
            css_class = "confidence-low"
        html_content += f"""
        <div style='margin: 12px 0;'>
            <div style='display: flex; justify-content: space-between; margin-bottom: 8px;'>
                <span style='font-weight: bold;'>{class_name}</span>
                <span style='font-weight: bold;'>{percentage:.1f}%</span>
            </div>
            <div class='confidence-bar {css_class}' style='width: {percentage}%;'></div>
        </div>
        """
    html_content += "</div>"
    return html_content
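
# The HTML built above relies on the .confidence-container, .confidence-bar and
# .confidence-high/-medium/-low classes defined in the custom CSS string further down.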
def enhanced_classify_and_explain(img):
    if img is None:
        return None, "No image provided", 0, ""
    gradcam_img, confidence_dict = classify_and_explain(img)
    if isinstance(confidence_dict, str):  # error case: second value is a message
        return None, confidence_dict, 0, ""
    pred_class = max(confidence_dict, key=confidence_dict.get)
    confidence = confidence_dict[pred_class]
    confidence_bars_html = create_confidence_bars(confidence_dict)
    return gradcam_img, pred_class, confidence, confidence_bars_html
# --- Depth Estimation Functions ---
def predict_depth(image):
    if depth_model is None:
        return None
    return depth_model.infer_image(image)

def calculate_max_points(image):
    if image is None:
        return 10000
    h, w = image.shape[:2]
    max_points = h * w * 3
    return max(1000, min(max_points, 1000000))

def update_slider_on_image_upload(image):
    max_points = calculate_max_points(image)
    default_value = min(10000, max_points // 10)
    return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000,
                     label=f"Number of 3D points (max: {max_points:,})")
def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=100000):
    h, w = depth_map.shape
    step = max(1, int(np.sqrt(h * w / max_points)))
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
    x_cam = (x_coords - w / 2) / focal_length_x
    y_cam = (y_coords - h / 2) / focal_length_y
    depth_values = depth_map[::step, ::step]
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values
    points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1)
    image_colors = image[::step, ::step, :]
    colors = image_colors.reshape(-1, 3) / 255.0
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors)
    return pcd
def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
    h, w = depth_map.shape
    step = max(1, int(np.sqrt(h * w / max_points)))
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
    focal_length = 470.4
    x_cam = (x_coords - w / 2) / focal_length
    y_cam = (y_coords - h / 2) / focal_length
    depth_values = depth_map[::step, ::step]
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values
    x_flat = x_3d.flatten()
    y_flat = y_3d.flatten()
    z_flat = z_3d.flatten()
    image_colors = image[::step, ::step, :]
    colors_flat = image_colors.reshape(-1, 3)
    fig = go.Figure(data=[go.Scatter3d(
        x=x_flat,
        y=y_flat,
        z=z_flat,
        mode='markers',
        marker=dict(
            size=1.5,
            color=colors_flat,
            opacity=0.9
        ),
        hovertemplate='<b>3D Position:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
                      '<b>Depth:</b> %{z:.2f}<br>' +
                      '<extra></extra>'
    )])
    fig.update_layout(
        title="3D Point Cloud Visualization (Camera Projection)",
        scene=dict(
            xaxis_title="X (meters)",
            yaxis_title="Y (meters)",
            zaxis_title="Z (meters)",
            camera=dict(
                eye=dict(x=2.0, y=2.0, z=2.0),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1)
            ),
            aspectmode='data'
        ),
        width=700,
        height=600
    )
    return fig
def on_depth_submit(image, num_points, focal_x, focal_y):
    if image is None or depth_model is None:
        return None, None, None, None, None
    original_image = image.copy()
    depth = predict_depth(image[:, :, ::-1])  # flip RGB -> BGR channel order before inference
    if depth is None:
        return None, None, None, None, None
    raw_depth = Image.fromarray(depth.astype('uint16'))
    tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    raw_depth.save(tmp_raw_depth.name)
    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    depth = depth.astype(np.uint8)
    cmap = matplotlib.colormaps.get_cmap('Spectral_r')
    colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
    gray_depth = Image.fromarray(depth)
    tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    gray_depth.save(tmp_gray_depth.name)
    pcd = create_point_cloud(original_image, depth, focal_x, focal_y, max_points=num_points)
    tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
    o3d.io.write_point_cloud(tmp_pointcloud.name, pcd)
    depth_3d = create_enhanced_3d_visualization(original_image, depth, max_points=num_points)
    return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]
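
# Note: the point cloud and the interactive 3-D plot are built from the 8-bit
# normalized depth map (0-255 relative values), not from the raw model output, so the
# reconstructed geometry is relative rather than metric even though the plot axes are
# labeled in meters.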
# --- Custom CSS for Unified Interface ---
css = """
/* Minimal dark theme styling */
.main-header {
    text-align: center;
    margin-bottom: 2rem;
    padding: 2rem 0;
}
.main-header h1 {
    font-size: 2.5rem;
    margin-bottom: 0.5rem;
    font-weight: 600;
}
.main-header p {
    font-size: 1.1rem;
    opacity: 0.8;
}
.section-title {
    font-size: 1.2rem;
    font-weight: 600;
    margin-bottom: 15px;
    padding-bottom: 8px;
    border-bottom: 1px solid var(--border-color-primary);
}
.confidence-container {
    margin: 15px 0;
    padding: 15px;
    border-radius: 8px;
    background: var(--background-secondary);
    border: 1px solid var(--border-color-primary);
}
.confidence-bar {
    height: 20px;
    border-radius: 4px;
    margin: 6px 0;
    background: var(--primary-500);
    transition: width 0.3s ease;
}
/* Simple confidence bar colors */
.confidence-high {
    background: var(--success-500);
}
.confidence-medium {
    background: var(--warning-500);
}
.confidence-low {
    background: var(--error-500);
}
/* Minimal spacing and layout */
.gradio-container {
    max-width: 100%;
    margin: 0;
    padding: 20px;
    width: 100%;
}
/* Clean image styling */
.gradio-image {
    border-radius: 8px;
    border: 1px solid var(--border-color-primary);
}
/* Simple button styling */
.gradio-button {
    border-radius: 6px;
    font-weight: 500;
}
/* Clean form elements */
.gradio-textbox, .gradio-number, .gradio-slider {
    border-radius: 6px;
    border: 1px solid var(--border-color-primary);
}
/* Tab styling */
.gradio-tabs {
    border-radius: 8px;
    overflow: hidden;
}
/* File upload styling */
.gradio-file {
    border-radius: 6px;
    border: 1px solid var(--border-color-primary);
}
/* Plot styling */
.gradio-plot {
    border-radius: 8px;
    border: 1px solid var(--border-color-primary);
}
/* Full width and height layout */
body, html {
    margin: 0;
    padding: 0;
    width: 100%;
    height: 100%;
}
#root {
    width: 100%;
    height: 100%;
}
/* Ensure Gradio uses full width */
.gradio-container {
    min-height: 100vh;
}
/* Responsive adjustments */
@media (max-width: 768px) {
    .main-header h1 {
        font-size: 2rem;
    }
    .gradio-container {
        padding: 10px;
    }
}
"""
# --- Create Unified Interface ---
with gr.Blocks(css=css, title="Medical AI Suite") as demo:
    gr.HTML("""
        <div class="main-header">
            <h1>Medical AI Suite</h1>
            <p>Advanced AI-powered medical image analysis and 3D visualization</p>
        </div>
    """)

    with gr.Tabs() as tabs:
        # Tab 1: Wound Classification
        with gr.TabItem("Wound Classification", id=0):
            gr.HTML("<div class='section-title'>Wound Classification with Grad-CAM Visualization</div>")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.HTML("<div class='section-title'>Input Image</div>")
                    wound_input_image = gr.Image(
                        label="Upload wound image",
                        type="pil",
                        height=350,
                        container=True
                    )
                with gr.Column(scale=1):
                    gr.HTML("<div class='section-title'>Analysis Results</div>")
                    wound_prediction_output = gr.Textbox(
                        label="Predicted Wound Type",
                        interactive=False,
                        container=True
                    )
                    wound_confidence_output = gr.Number(
                        label="Confidence Score",
                        interactive=False,
                        container=True
                    )
                    wound_confidence_bars = gr.HTML(
                        label="Confidence Scores by Class",
                        container=True
                    )
            with gr.Row():
                with gr.Column():
                    gr.HTML("<div class='section-title'>Model Focus Visualization</div>")
                    wound_cam_output = gr.Image(
                        label="Grad-CAM Heatmap - Shows which areas the model focused on",
                        height=350,
                        container=True
                    )

            # Event handlers for wound classification
            wound_input_image.change(
                fn=enhanced_classify_and_explain,
                inputs=[wound_input_image],
                outputs=[wound_cam_output, wound_prediction_output, wound_confidence_output, wound_confidence_bars]
            )

        # Tab 2: Depth Estimation
        with gr.TabItem("Depth Estimation & 3D Visualization", id=1):
            gr.HTML("<div class='section-title'>Depth Estimation and 3D Point Cloud Generation</div>")
            with gr.Row():
                depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
                depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')
            with gr.Row():
                depth_submit = gr.Button(value="Compute Depth", variant="primary")
                depth_points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
                                                label="Number of 3D points (upload image to update max)")
            with gr.Row():
                depth_focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                                 label="Focal Length X (pixels)")
                depth_focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                                 label="Focal Length Y (pixels)")
            with gr.Row():
                depth_gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
                depth_raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
                depth_point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")
            gr.Markdown("### 3D Point Cloud Visualization")
            gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
            depth_3d_plot = gr.Plot(label="3D Point Cloud")

            # Event handlers for depth estimation
            depth_input_image.change(
                fn=update_slider_on_image_upload,
                inputs=[depth_input_image],
                outputs=[depth_points_slider]
            )
            depth_submit.click(
                on_depth_submit,
                inputs=[depth_input_image, depth_points_slider, depth_focal_length_x, depth_focal_length_y],
                outputs=[depth_image_slider, depth_gray_depth_file, depth_raw_file, depth_point_cloud_file, depth_3d_plot]
            )
    # Cross-tab image sharing functionality
    # When an image is uploaded in wound classification, also update depth estimation
    wound_input_image.change(
        fn=lambda img: img,
        inputs=[wound_input_image],
        outputs=[depth_input_image]
    )
    # When an image is uploaded in depth estimation, also update wound classification
    depth_input_image.change(
        fn=lambda img: img,
        inputs=[depth_input_image],
        outputs=[wound_input_image]
    )
# --- Launch the unified interface ---
if __name__ == "__main__":
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True
    )