import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend (set before pyplot is imported)
import matplotlib.pyplot as plt
import cv2
import io
import os

# Load the model in SavedModel format
MODEL_PATH = "chest_ct_binary_classifier_densenet_tf_20250427_182239"
model = tf.saved_model.load(MODEL_PATH)
infer = model.signatures["serving_default"]  # Get the inference function

# Get the input and output tensor names from the serving signature
input_tensor_name = list(infer.structured_input_signature[1].keys())[0]
output_tensor_name = list(infer.structured_outputs.keys())[0]
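# Optional debugging aid (not required at runtime): uncomment to print the
# serving signature and confirm the tensor names and shapes the SavedModel expects.
# print(infer.structured_input_signature)  # input key(s) and TensorSpec
# print(infer.structured_outputs)          # output key(s) and TensorSpec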
# Image size - matching what the model was trained on
IMG_SIZE = 256

# Function for preprocessing
def preprocess_image(image):
    img = Image.fromarray(image).convert('RGB')
    img = img.resize((IMG_SIZE, IMG_SIZE))
    img_array = np.array(img) / 255.0
    return np.expand_dims(img_array, axis=0).astype(np.float32)  # Cast to float32 for TF

# Make a prediction with the SavedModel
def predict_with_saved_model(image_tensor):
    # Build the input dict keyed by the signature's input tensor name
    input_dict = {input_tensor_name: image_tensor}
    # Run inference
    output = infer(**input_dict)
    # Extract the scalar prediction value
    prediction = output[output_tensor_name].numpy()[0][0]
    return prediction

# Generate an attention map
# Note: proper Grad-CAM is more involved with the SavedModel format (no direct
# access to intermediate conv layers), so this uses a simplified approach.
def generate_attention_map(img_array, prediction):
    # This is a placeholder - in production you may want to implement a proper CAM
    # (see the Grad-CAM sketch below). For demo purposes, build a simple attention
    # map from image features.
    gray = cv2.cvtColor(img_array[0].astype(np.float32), cv2.COLOR_RGB2GRAY)
    blur = cv2.GaussianBlur(gray, (5, 5), 0)
    # Use simple edge detection as a proxy for "interesting" regions
    sobelx = cv2.Sobel(blur, cv2.CV_64F, 1, 0, ksize=3)
    sobely = cv2.Sobel(blur, cv2.CV_64F, 0, 1, ksize=3)
    magnitude = np.sqrt(sobelx**2 + sobely**2)
    # Normalize to 0-1
    magnitude = (magnitude - magnitude.min()) / (magnitude.max() - magnitude.min() + 1e-8)
    # Weight the map by the prediction (higher probability = more intensity)
    weight = 0.5 + (prediction - 0.5) * 0.5  # Scale between 0.5-1 based on prediction
    magnitude = magnitude * weight
    # Apply colormap
    heatmap = cv2.applyColorMap(np.uint8(255 * magnitude), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    return heatmap, magnitude

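# A minimal Grad-CAM sketch, assuming the original tf.keras model (not just the
# SavedModel signature) is also available - e.g. re-loaded from a hypothetical
# "chest_ct_model.keras" file. "conv5_block16_concat" is the last concatenation
# layer of DenseNet121; adjust the layer name to your architecture. This function
# is not wired into the app; it only illustrates how the placeholder above could
# be replaced with a proper CAM.
def generate_gradcam(img_array, keras_model, conv_layer_name="conv5_block16_concat"):
    # Model that maps the input to (last conv feature maps, final prediction)
    grad_model = tf.keras.Model(
        inputs=keras_model.inputs,
        outputs=[keras_model.get_layer(conv_layer_name).output, keras_model.output],
    )
    with tf.GradientTape() as tape:
        conv_out, preds = grad_model(img_array)
        score = preds[:, 0]  # single-unit sigmoid output
    # Gradients of the score w.r.t. the conv feature maps
    grads = tape.gradient(score, conv_out)
    # Global-average-pool the gradients to get per-channel weights
    weights = tf.reduce_mean(grads, axis=(1, 2))
    # Weighted sum of the feature maps, followed by ReLU
    cam = tf.nn.relu(tf.reduce_sum(conv_out * weights[:, None, None, :], axis=-1))[0].numpy()
    # Normalize to 0-1 and resize to the input resolution
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    return cv2.resize(cam, (IMG_SIZE, IMG_SIZE))
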
# Prediction function with visualization
def predict_and_explain(image):
    if image is None:
        return None, "Please upload an image.", 0.0

    # Preprocess the image
    preprocessed = preprocess_image(image)

    # Make prediction
    prediction = predict_with_saved_model(preprocessed)

    # Generate attention map
    heatmap, attention = generate_attention_map(preprocessed, prediction)

    # Create overlay of the heatmap on the resized original
    original_resized = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
    superimposed = (0.6 * original_resized) + (0.4 * heatmap)
    superimposed = superimposed.astype(np.uint8)

    # Create visualization with matplotlib
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(original_resized)
    axes[0].set_title("Original CT Scan")
    axes[0].axis('off')
    axes[1].imshow(heatmap)
    axes[1].set_title("Feature Map")
    axes[1].axis('off')
    axes[2].imshow(superimposed)
    axes[2].set_title("Overlay")
    axes[2].axis('off')

    # Add prediction information
    result_text = f"{'Cancer' if prediction > 0.5 else 'Normal'} (Confidence: {(prediction if prediction > 0.5 else 1 - prediction):.2%})"
    fig.suptitle(result_text, fontsize=16)

    # Convert the plot to an image
    buf = io.BytesIO()
    plt.tight_layout()
    plt.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    result_image = np.array(Image.open(buf))

    # Return prediction information
    prediction_class = "Cancer" if prediction > 0.5 else "Normal"
    confidence = float(prediction if prediction > 0.5 else 1 - prediction)
    return result_image, prediction_class, confidence

# Create Gradio interface
with gr.Blocks(title="Chest CT Scan Cancer Detection") as demo:
    gr.Markdown("# Chest CT Scan Cancer Detection")
    gr.Markdown("Upload a chest CT scan image to detect the presence of cancer.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload CT Scan Image", type="numpy")
            submit_btn = gr.Button("Analyze Image")
        with gr.Column():
            output_image = gr.Image(label="Analysis Results")
            prediction_label = gr.Label(label="Prediction")
            confidence_score = gr.Number(label="Confidence Score")

    gr.Markdown("### How it works")
    gr.Markdown("""
    This application uses a deep learning model based on the DenseNet121 architecture to classify chest CT scans as either 'Normal' or 'Cancer'.

    The visualization shows:
    - Left: Original CT scan
    - Middle: Feature map highlighting areas with distinctive patterns
    - Right: Overlay of the feature map on the original image

    The model was trained on a dataset of chest CT scans containing normal images and various types of lung cancer (adenocarcinoma, squamous cell carcinoma, and large cell carcinoma).
    """)

    submit_btn.click(
        predict_and_explain,
        inputs=input_image,
        outputs=[output_image, prediction_label, confidence_score]
    )

demo.launch()