import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras import backend as K
import matplotlib.pyplot as plt
from PIL import Image
import io
import cv2
import glob
import matplotlib
import torch
import tempfile
from gradio_imageslider import ImageSlider
import plotly.graph_objects as go
import plotly.express as px
import open3d as o3d
from depth_anything_v2.dpt import DepthAnythingV2

# --- Load models ---

# Wound classification model
try:
    wound_model = load_model("checkpoints/keras_model.h5")
    with open("labels.txt", "r") as f:
        class_labels = [line.strip() for line in f]
except Exception:
    wound_model = None
    class_labels = ["No model found"]

# Depth estimation model
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]},
}
encoder = 'vitl'

try:
    depth_model = DepthAnythingV2(**model_configs[encoder])
    state_dict = torch.load(f'checkpoints/depth_anything_v2_{encoder}.pth', map_location="cpu")
    depth_model.load_state_dict(state_dict)
    depth_model = depth_model.to(DEVICE).eval()
except Exception:
    depth_model = None


# --- Wound Classification Functions ---

def preprocess_input(img):
    """Resize a PIL image to the classifier's 224x224 input and scale pixels to [0, 1]."""
    img = img.resize((224, 224))
    arr = keras_image.img_to_array(img)
    arr = arr / 255.0
    return np.expand_dims(arr, axis=0)


def get_gradcam_heatmap(img_array, model, class_index, last_conv_layer_name="conv5_block3_out"):
    """Compute a Grad-CAM heatmap for the given class from the chosen conv layer."""
    try:
        target_layer = model.get_layer(last_conv_layer_name)
    except Exception:
        # Fall back to the last convolutional layer we can find.
        for layer in reversed(model.layers):
            if 'conv' in layer.name.lower():
                target_layer = layer
                break
        else:
            return None

    grad_model = tf.keras.models.Model(
        [model.inputs], [target_layer.output, model.output]
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        loss = predictions[:, class_index]

    grads = tape.gradient(loss, conv_outputs)
    if grads is None:
        return None

    # Drop the batch dimension, then average the gradients over the spatial
    # dimensions to get one importance weight per feature-map channel.
    grads = grads[0]
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1))
    conv_outputs = conv_outputs[0]

    # Weighted sum of the feature maps, ReLU, then normalize to [0, 1].
    heatmap = tf.reduce_sum(tf.multiply(pooled_grads, conv_outputs), axis=-1)
    heatmap = heatmap.numpy()
    heatmap = np.maximum(heatmap, 0)
    heatmap = heatmap / (np.max(heatmap) + K.epsilon())
    return heatmap


def overlay_gradcam(original_img, heatmap):
    """Blend the Grad-CAM heatmap over the original image."""
    if heatmap is None:
        return original_img

    heatmap = cv2.resize(heatmap, original_img.size)
    heatmap = np.maximum(heatmap, 0)
    if np.max(heatmap) != 0:
        heatmap /= np.max(heatmap)

    heatmap = np.uint8(255 * heatmap)
    heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # applyColorMap returns BGR; convert so it blends correctly with the RGB image.
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)

    original_array = np.array(original_img.convert("RGB"))
    superimposed_img = cv2.addWeighted(original_array, 0.6, heatmap_color, 0.4, 0)
    return Image.fromarray(superimposed_img)


def classify_and_explain(img):
    """Return (Grad-CAM image, per-class confidence dict), or (None, error message)."""
    if img is None or wound_model is None:
        return None, "No image provided or model not available"

    img_array = preprocess_input(img)
    predictions = wound_model.predict(img_array, verbose=0)[0]
    pred_idx = int(np.argmax(predictions))
    pred_class = class_labels[pred_idx]

    confidence_dict = {class_labels[i]: float(predictions[i]) for i in range(len(class_labels))}

    try:
        heatmap = get_gradcam_heatmap(img_array, wound_model, pred_idx)
        gradcam_img = overlay_gradcam(img.resize((224, 224)), heatmap)
    except Exception as e:
        print(f"Grad-CAM error: {e}")
        gradcam_img = img.resize((224, 224))

    return gradcam_img, confidence_dict
" for class_name, confidence in confidence_dict.items(): percentage = confidence * 100 if percentage > 70: css_class = "confidence-high" elif percentage > 40: css_class = "confidence-medium" else: css_class = "confidence-low" html_content += f"""
{class_name} {percentage:.1f}%
""" html_content += "
" return html_content def enhanced_classify_and_explain(img): if img is None: return None, "No image provided", 0, "" gradcam_img, confidence_dict = classify_and_explain(img) if isinstance(confidence_dict, str): # Error case return None, confidence_dict, 0, "" pred_class = max(confidence_dict, key=confidence_dict.get) confidence = confidence_dict[pred_class] confidence_bars_html = create_confidence_bars(confidence_dict) return gradcam_img, pred_class, confidence, confidence_bars_html # --- Depth Estimation Functions --- def predict_depth(image): if depth_model is None: return None return depth_model.infer_image(image) def calculate_max_points(image): if image is None: return 10000 h, w = image.shape[:2] max_points = h * w * 3 return max(1000, min(max_points, 1000000)) def update_slider_on_image_upload(image): max_points = calculate_max_points(image) default_value = min(10000, max_points // 10) return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000, label=f"Number of 3D points (max: {max_points:,})") def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=100000): h, w = depth_map.shape step = max(1, int(np.sqrt(h * w / max_points))) y_coords, x_coords = np.mgrid[0:h:step, 0:w:step] x_cam = (x_coords - w / 2) / focal_length_x y_cam = (y_coords - h / 2) / focal_length_y depth_values = depth_map[::step, ::step] x_3d = x_cam * depth_values y_3d = y_cam * depth_values z_3d = depth_values points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1) image_colors = image[::step, ::step, :] colors = image_colors.reshape(-1, 3) / 255.0 pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(points) pcd.colors = o3d.utility.Vector3dVector(colors) return pcd def create_enhanced_3d_visualization(image, depth_map, max_points=10000): h, w = depth_map.shape step = max(1, int(np.sqrt(h * w / max_points))) y_coords, x_coords = np.mgrid[0:h:step, 0:w:step] focal_length = 470.4 x_cam = (x_coords - w / 2) / focal_length y_cam = (y_coords - h / 2) / focal_length depth_values = depth_map[::step, ::step] x_3d = x_cam * depth_values y_3d = y_cam * depth_values z_3d = depth_values x_flat = x_3d.flatten() y_flat = y_3d.flatten() z_flat = z_3d.flatten() image_colors = image[::step, ::step, :] colors_flat = image_colors.reshape(-1, 3) fig = go.Figure(data=[go.Scatter3d( x=x_flat, y=y_flat, z=z_flat, mode='markers', marker=dict( size=1.5, color=colors_flat, opacity=0.9 ), hovertemplate='3D Position: (%{x:.3f}, %{y:.3f}, %{z:.3f})
' + 'Depth: %{z:.2f}
' + '' )]) fig.update_layout( title="3D Point Cloud Visualization (Camera Projection)", scene=dict( xaxis_title="X (meters)", yaxis_title="Y (meters)", zaxis_title="Z (meters)", camera=dict( eye=dict(x=2.0, y=2.0, z=2.0), center=dict(x=0, y=0, z=0), up=dict(x=0, y=0, z=1) ), aspectmode='data' ), width=700, height=600 ) return fig def on_depth_submit(image, num_points, focal_x, focal_y): if image is None or depth_model is None: return None, None, None, None, None original_image = image.copy() h, w = image.shape[:2] depth = predict_depth(image[:, :, ::-1]) if depth is None: return None, None, None, None, None raw_depth = Image.fromarray(depth.astype('uint16')) tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False) raw_depth.save(tmp_raw_depth.name) depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 depth = depth.astype(np.uint8) cmap = matplotlib.colormaps.get_cmap('Spectral_r') colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8) gray_depth = Image.fromarray(depth) tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False) gray_depth.save(tmp_gray_depth.name) pcd = create_point_cloud(original_image, depth, focal_x, focal_y, max_points=num_points) tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False) o3d.io.write_point_cloud(tmp_pointcloud.name, pcd) depth_3d = create_enhanced_3d_visualization(original_image, depth, max_points=num_points) return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d] # --- Custom CSS for Unified Interface --- css = """ /* Minimal dark theme styling */ .main-header { text-align: center; margin-bottom: 2rem; padding: 2rem 0; } .main-header h1 { font-size: 2.5rem; margin-bottom: 0.5rem; font-weight: 600; } .main-header p { font-size: 1.1rem; opacity: 0.8; } .section-title { font-size: 1.2rem; font-weight: 600; margin-bottom: 15px; padding-bottom: 8px; border-bottom: 1px solid var(--border-color-primary); } .confidence-container { margin: 15px 0; padding: 15px; border-radius: 8px; background: var(--background-secondary); border: 1px solid var(--border-color-primary); } .confidence-bar { height: 20px; border-radius: 4px; margin: 6px 0; background: var(--primary-500); transition: width 0.3s ease; } /* Simple confidence bar colors */ .confidence-high { background: var(--success-500); } .confidence-medium { background: var(--warning-500); } .confidence-low { background: var(--error-500); } /* Minimal spacing and layout */ .gradio-container { max-width: 100%; margin: 0; padding: 20px; width: 100%; } /* Clean image styling */ .gradio-image { border-radius: 8px; border: 1px solid var(--border-color-primary); } /* Simple button styling */ .gradio-button { border-radius: 6px; font-weight: 500; } /* Clean form elements */ .gradio-textbox, .gradio-number, .gradio-slider { border-radius: 6px; border: 1px solid var(--border-color-primary); } /* Tab styling */ .gradio-tabs { border-radius: 8px; overflow: hidden; } /* File upload styling */ .gradio-file { border-radius: 6px; border: 1px solid var(--border-color-primary); } /* Plot styling */ .gradio-plot { border-radius: 8px; border: 1px solid var(--border-color-primary); } /* Full width and height layout */ body, html { margin: 0; padding: 0; width: 100%; height: 100%; } #root { width: 100%; height: 100%; } /* Ensure Gradio uses full width */ .gradio-container { min-height: 100vh; } /* Responsive adjustments */ @media (max-width: 768px) { .main-header h1 { font-size: 2rem; } .gradio-container { 
        padding: 10px;
    }
}
"""


# --- Create Unified Interface ---
with gr.Blocks(css=css, title="Medical AI Suite") as demo:
    gr.HTML("""
        <div class="main-header">
            <h1>Medical AI Suite</h1>
            <p>Advanced AI-powered medical image analysis and 3D visualization</p>
        </div>
""") with gr.Tabs() as tabs: # Tab 1: Wound Classification with gr.TabItem("Wound Classification", id=0): gr.HTML("
Wound Classification with Grad-CAM Visualization
") with gr.Row(): with gr.Column(scale=1): gr.HTML("
Input Image
") wound_input_image = gr.Image( label="Upload wound image", type="pil", height=350, container=True ) with gr.Column(scale=1): gr.HTML("
Analysis Results
") wound_prediction_output = gr.Textbox( label="Predicted Wound Type", interactive=False, container=True ) wound_confidence_output = gr.Number( label="Confidence Score", interactive=False, container=True ) wound_confidence_bars = gr.HTML( label="Confidence Scores by Class", container=True ) with gr.Row(): with gr.Column(): gr.HTML("
Model Focus Visualization
") wound_cam_output = gr.Image( label="Grad-CAM Heatmap - Shows which areas the model focused on", height=350, container=True ) # Event handlers for wound classification wound_input_image.change( fn=enhanced_classify_and_explain, inputs=[wound_input_image], outputs=[wound_cam_output, wound_prediction_output, wound_confidence_output, wound_confidence_bars] ) # Tab 2: Depth Estimation with gr.TabItem("Depth Estimation & 3D Visualization", id=1): gr.HTML("
Depth Estimation and 3D Point Cloud Generation
") with gr.Row(): depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input') depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output') with gr.Row(): depth_submit = gr.Button(value="Compute Depth", variant="primary") depth_points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000, label="Number of 3D points (upload image to update max)") with gr.Row(): depth_focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10, label="Focal Length X (pixels)") depth_focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10, label="Focal Length Y (pixels)") with gr.Row(): depth_gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download") depth_raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download") depth_point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download") gr.Markdown("### 3D Point Cloud Visualization") gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.") depth_3d_plot = gr.Plot(label="3D Point Cloud") # Event handlers for depth estimation depth_input_image.change( fn=update_slider_on_image_upload, inputs=[depth_input_image], outputs=[depth_points_slider] ) depth_submit.click( on_depth_submit, inputs=[depth_input_image, depth_points_slider, depth_focal_length_x, depth_focal_length_y], outputs=[depth_image_slider, depth_gray_depth_file, depth_raw_file, depth_point_cloud_file, depth_3d_plot] ) # Cross-tab image sharing functionality # When image is uploaded in wound classification, also update depth estimation wound_input_image.change( fn=lambda img: img, inputs=[wound_input_image], outputs=[depth_input_image] ) # When image is uploaded in depth estimation, also update wound classification depth_input_image.change( fn=lambda img: img, inputs=[depth_input_image], outputs=[wound_input_image] ) # --- Launch the unified interface --- if __name__ == "__main__": demo.queue().launch( server_name="0.0.0.0", server_port=7860, share=True, show_error=True )