IsmatS's picture
Create app.py
2bbade7 verified
raw
history blame
5.7 kB
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib
import cv2
import io
import os
# Render matplotlib figures off-screen with the non-interactive Agg backend;
# the app only ever saves figures into in-memory PNG buffers.
matplotlib.use('Agg')

# Directory holding the exported TensorFlow SavedModel. Being a relative path,
# it is resolved against the current working directory.
MODEL_PATH = "chest_ct_binary_classifier_densenet_tf_20250427_182239"

# Load the SavedModel once at import time and keep its default serving
# signature, which is the callable used for every prediction.
model = tf.saved_model.load(MODEL_PATH)
infer = model.signatures["serving_default"]

# The serving signature is called with keyword tensors and returns a dict of
# named outputs. Only the first input name and first output name are used,
# i.e. this assumes a single-input / single-output classifier
# (NOTE(review): confirm against the exported signature).
input_tensor_name = next(iter(infer.structured_input_signature[1]))
output_tensor_name = next(iter(infer.structured_outputs))

# Side length in pixels of the square images fed to the model; presumably the
# training resolution — TODO confirm against the training setup.
IMG_SIZE = 256
def preprocess_image(image):
    """Turn a raw uploaded image array into a model-ready batch of one.

    Args:
        image: Image as a NumPy array accepted by ``PIL.Image.fromarray``
            (the app passes the value of a ``gr.Image(type="numpy")`` input).

    Returns:
        ``float32`` array of shape ``(1, IMG_SIZE, IMG_SIZE, 3)``: the image
        converted to RGB, resized to ``IMG_SIZE`` x ``IMG_SIZE`` and divided
        by 255 (so pixel values lie in 0-1 for 8-bit input).
    """
    rgb = Image.fromarray(image).convert('RGB')
    resized = rgb.resize((IMG_SIZE, IMG_SIZE))
    scaled = np.asarray(resized) / 255.0
    # Add a leading batch axis; TensorFlow expects float32 input here.
    batch = scaled[np.newaxis, ...]
    return batch.astype(np.float32)
def predict_with_saved_model(image_tensor):
    """Run the serving signature on one preprocessed batch and return its score.

    Args:
        image_tensor: Batch produced by ``preprocess_image``.

    Returns:
        The element at ``[0][0]`` of the signature's output, i.e. the first
        output unit for the first image in the batch. Callers treat values
        above 0.5 as the "Cancer" class (presumably a sigmoid probability —
        confirm against the model head).
    """
    # The signature only accepts keyword arguments named after its inputs.
    outputs = infer(**{input_tensor_name: image_tensor})
    scores = outputs[output_tensor_name].numpy()
    return scores[0][0]
def generate_attention_map(img_array, prediction):
    """Build an edge-based heatmap to display alongside the prediction.

    This is NOT Grad-CAM: the map is computed purely from image content
    (Sobel gradient magnitude of the blurred grayscale image) and is only
    scaled by the model's score. It does not show which regions the model
    actually relied on. A proper CAM would need access to the model's
    internal activations, which the SavedModel signature does not expose here.

    Args:
        img_array: Batch from ``preprocess_image``; only the first image
            (``img_array[0]``, an RGB array with values in 0-1) is used.
        prediction: Model score; it is clamped to [0, 1] before use.

    Returns:
        ``(heatmap, magnitude)``. ``heatmap`` is a uint8 RGB image (JET
        colormap), and ``magnitude`` is the weighted gradient magnitude in
        [0, 1]. Both have the same height and width as the input image.
    """
    gray = cv2.cvtColor(img_array[0].astype(np.float32), cv2.COLOR_RGB2GRAY)
    blur = cv2.GaussianBlur(gray, (5, 5), 0)

    # Edge strength is used as a stand-in for "interesting" regions.
    sobel_x = cv2.Sobel(blur, cv2.CV_64F, 1, 0, ksize=3)
    sobel_y = cv2.Sobel(blur, cv2.CV_64F, 0, 1, ksize=3)
    magnitude = np.hypot(sobel_x, sobel_y)

    # Min-max normalise to [0, 1]; the epsilon avoids 0/0 on a flat image.
    magnitude = (magnitude - magnitude.min()) / (magnitude.max() - magnitude.min() + 1e-8)

    # Scale intensity linearly with the score: weight is 0.5 at a score of 0
    # and 1.0 at a score of 1. Clamping keeps 255 * magnitude within uint8
    # range even if the model emits a value outside [0, 1].
    score = float(np.clip(prediction, 0.0, 1.0))
    weight = 0.5 + 0.5 * score
    magnitude = magnitude * weight

    heatmap = cv2.applyColorMap(np.uint8(255 * magnitude), cv2.COLORMAP_JET)
    # OpenCV colormaps are BGR; convert so matplotlib/Gradio display correctly.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    return heatmap, magnitude
def predict_and_explain(image):
    """Classify an uploaded CT image and render a three-panel result figure.

    Args:
        image: Uploaded image as a NumPy array (from ``gr.Image(type="numpy")``),
            or ``None`` when nothing was uploaded.

    Returns:
        ``(figure_image, prediction_class, confidence)``:

        - ``figure_image``: the rendered matplotlib figure as a NumPy array
          (original / heatmap / overlay panels with the result as title).
        - ``prediction_class``: ``"Cancer"`` when the score is above 0.5,
          otherwise ``"Normal"``.
        - ``confidence``: the confidence of that class as a Python float
          (the score, or ``1 - score`` for "Normal").

        When ``image`` is ``None``, returns
        ``(None, "Please upload an image.", 0.0)``.
    """
    if image is None:
        return None, "Please upload an image.", 0.0

    preprocessed = preprocess_image(image)
    prediction = predict_with_saved_model(preprocessed)
    heatmap, _ = generate_attention_map(preprocessed, prediction)

    # Display the preprocessed image, which is already RGB and IMG_SIZE x
    # IMG_SIZE. This guarantees 3 channels matching the heatmap even for
    # grayscale or RGBA uploads, which would otherwise break the blend below.
    original_resized = np.rint(preprocessed[0] * 255.0).astype(np.uint8)
    superimposed = (0.6 * original_resized + 0.4 * heatmap).astype(np.uint8)

    # Decide the class and its confidence once; the title and return value share it.
    if prediction > 0.5:
        prediction_class, confidence = "Cancer", float(prediction)
    else:
        prediction_class, confidence = "Normal", float(1 - prediction)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        panels = (
            (original_resized, "Original CT Scan"),
            (heatmap, "Feature Map"),
            (superimposed, "Overlay"),
        )
        for ax, (panel, title) in zip(axes, panels):
            ax.imshow(panel)
            ax.set_title(title)
            ax.axis('off')
        fig.suptitle(f"{prediction_class} (Confidence: {confidence:.2%})", fontsize=16)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if rendering fails, so repeated
        # requests do not accumulate open figures.
        plt.close(fig)

    buf.seek(0)
    result_image = np.array(Image.open(buf))
    return result_image, prediction_class, confidence
# Build the Gradio UI. The left column holds the upload widget and the
# "Analyze Image" button. The right column shows the rendered figure, the
# predicted class and the confidence returned by predict_and_explain.
with gr.Blocks(title="Chest CT Scan Cancer Detection") as demo:
    gr.Markdown("# Chest CT Scan Cancer Detection")
    gr.Markdown("Upload a chest CT scan image to detect the presence of cancer.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload CT Scan Image", type="numpy")
            submit_btn = gr.Button("Analyze Image")
        with gr.Column():
            output_image = gr.Image(label="Analysis Results")
            prediction_label = gr.Label(label="Prediction")
            confidence_score = gr.Number(label="Confidence Score")
    gr.Markdown("### How it works")
    # Blank lines separate paragraphs from the list. Without them, the final
    # paragraph would render as a continuation of the last bullet.
    gr.Markdown("""
This application uses a deep learning model based on DenseNet121 architecture to classify chest CT scans as either 'Normal' or 'Cancer'.

The visualization shows:

- Left: Original CT scan
- Middle: Edge-based attention map (image gradient strength scaled by the prediction score; it is not derived from the model's internals)
- Right: Overlay of the attention map on the original image

The model was trained on a dataset of chest CT scans containing normal images and various types of lung cancer (adenocarcinoma, squamous cell carcinoma, and large cell carcinoma).
""")
    # Event listeners must be registered inside the Blocks context.
    submit_btn.click(
        predict_and_explain,
        inputs=input_image,
        outputs=[output_image, prediction_label, confidence_score],
    )

# Launch only when run as a script, so importing this module (e.g. from tests
# or Gradio's reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch()