File size: 5,704 Bytes
2bbade7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib
import cv2
import io
import os
matplotlib.use("Agg")  # non-interactive backend, so no display is required

# Directory holding the exported TensorFlow SavedModel
MODEL_PATH = "chest_ct_binary_classifier_densenet_tf_20250427_182239"

# Restore the SavedModel and keep its default serving signature, which is the
# callable used for inference everywhere below
model = tf.saved_model.load(MODEL_PATH)
infer = model.signatures["serving_default"]

# The signature is called with keyword arguments and returns a dict, so resolve
# the first input name and the first output name once at start-up
input_tensor_name = list(infer.structured_input_signature[1])[0]
output_tensor_name = list(infer.structured_outputs)[0]

# Side length in pixels of the square model input (matching what the model was
# trained on)
IMG_SIZE = 256

# Function for preprocessing
def preprocess_image(image):
    """Turn a raw image array into a single-item batch for the model.

    The array is converted to an RGB PIL image, resized to
    ``IMG_SIZE`` x ``IMG_SIZE``, divided by 255.0 and given a leading
    batch axis.

    Args:
        image: Image data accepted by ``PIL.Image.fromarray``.

    Returns:
        A ``float32`` NumPy array whose first axis has length 1.
    """
    rgb = Image.fromarray(image).convert('RGB')
    resized = rgb.resize((IMG_SIZE, IMG_SIZE))
    scaled = np.asarray(resized) / 255.0
    # Prepend the batch axis and cast to float32 for TF
    return scaled[np.newaxis, ...].astype(np.float32)

# Make prediction with the SavedModel
def predict_with_saved_model(image_tensor):
    """Run the serving signature on one batch and return its first score.

    Args:
        image_tensor: Batch handed to the signature under the keyword
            ``input_tensor_name``.

    Returns:
        Element ``[0][0]`` of the output selected by ``output_tensor_name``.
    """
    # The signature only accepts keyword arguments, keyed by input name
    outputs = infer(**{input_tensor_name: image_tensor})
    scores = outputs[output_tensor_name].numpy()
    return scores[0][0]

# Generate an edge-based attention map as a stand-in for Grad-CAM
# Note: Grad-CAM is more complex with the SavedModel format, so we use a simplified approach
def generate_attention_map(img_array, prediction):
    """Build a placeholder attention heatmap from image edges.

    This is NOT a real class-activation map: it highlights strong intensity
    gradients (Sobel edge magnitude) and scales their brightness with the
    model's predicted probability. In production a proper CAM should replace
    it.

    Args:
        img_array: Batched image as produced by ``preprocess_image``; only
            the first item (``img_array[0]``) is used and it is treated as
            an RGB image.
        prediction: Model score; clamped to ``[0, 1]`` before use.

    Returns:
        A ``(heatmap, magnitude)`` tuple: ``heatmap`` is the JET-coloured
        RGB ``uint8`` image and ``magnitude`` is the weighted edge strength
        it was rendered from.
    """
    gray = cv2.cvtColor(img_array[0].astype(np.float32), cv2.COLOR_RGB2GRAY)
    blur = cv2.GaussianBlur(gray, (5, 5), 0)

    # Use simple edge detection as a proxy for "interesting" regions
    sobelx = cv2.Sobel(blur, cv2.CV_64F, 1, 0, ksize=3)
    sobely = cv2.Sobel(blur, cv2.CV_64F, 0, 1, ksize=3)
    magnitude = np.hypot(sobelx, sobely)

    # Normalize to 0-1; the epsilon avoids dividing by zero on a flat image
    span = magnitude.max() - magnitude.min()
    magnitude = (magnitude - magnitude.min()) / (span + 1e-8)

    # Linear intensity weight in [0.5, 1]: higher probability = more intensity.
    # The previous formula, 0.5 + (p - 0.5) * 0.5, only spanned 0.25-0.75 and
    # so never reached the documented range or the top of the colormap.
    weight = 0.5 + 0.5 * float(np.clip(prediction, 0.0, 1.0))
    magnitude = magnitude * weight

    # Apply colormap (OpenCV returns BGR, so convert for RGB consumers)
    heatmap = cv2.applyColorMap(np.uint8(255 * magnitude), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    return heatmap, magnitude

# Prediction function with visualization
def predict_and_explain(image):
    """Classify an uploaded scan and render a three-panel explanation figure.

    Args:
        image: Uploaded image as a NumPy array, or ``None`` when nothing was
            uploaded.

    Returns:
        A ``(result_image, prediction_class, confidence)`` tuple:
        ``result_image`` is the rendered figure as a NumPy array (``None``
        when no image was supplied), ``prediction_class`` is ``"Cancer"`` or
        ``"Normal"`` (or a prompt to upload an image), and ``confidence`` is
        the probability of the reported class as a ``float``.
    """
    if image is None:
        return None, "Please upload an image.", 0.0

    # Preprocess the image and run the model
    preprocessed = preprocess_image(image)
    prediction = predict_with_saved_model(preprocessed)

    # Derive the label and its confidence once; both the figure title and the
    # returned values use them
    is_cancer = prediction > 0.5
    prediction_class = "Cancer" if is_cancer else "Normal"
    confidence = float(prediction if is_cancer else 1 - prediction)

    # Generate attention map (the raw magnitude is not needed here)
    heatmap, _ = generate_attention_map(preprocessed, prediction)

    # Build the overlay from the same RGB view the model saw, so grayscale or
    # RGBA uploads still blend with the 3-channel heatmap
    original_rgb = np.array(Image.fromarray(image).convert('RGB'))
    original_resized = cv2.resize(original_rgb, (IMG_SIZE, IMG_SIZE))
    superimposed = (0.6 * original_resized + 0.4 * heatmap).astype(np.uint8)

    panels = (
        (original_resized, "Original CT Scan"),
        (heatmap, "Feature Map"),
        (superimposed, "Overlay"),
    )

    buf = io.BytesIO()
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        for ax, (panel, title) in zip(axes, panels):
            ax.imshow(panel)
            ax.set_title(title)
            ax.axis('off')

        fig.suptitle(
            f"{prediction_class} (Confidence: {confidence:.2%})",
            fontsize=16,
        )

        # Use Figure methods rather than pyplot's "current figure" helpers:
        # the global pyplot state may point at another request's figure when
        # handlers run concurrently
        fig.tight_layout()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if drawing fails
        plt.close(fig)

    # Convert the rendered PNG back into an array for the UI
    buf.seek(0)
    result_image = np.array(Image.open(buf))

    return result_image, prediction_class, confidence

# Create Gradio interface
with gr.Blocks(title="Chest CT Scan Cancer Detection") as demo:
    # Page header
    gr.Markdown("# Chest CT Scan Cancer Detection")
    gr.Markdown("Upload a chest CT scan image to detect the presence of cancer.")

    with gr.Row():
        # Left column: image upload and the trigger button
        with gr.Column():
            input_image = gr.Image(type="numpy", label="Upload CT Scan Image")
            submit_btn = gr.Button("Analyze Image")

        # Right column: rendered figure, predicted class and confidence
        with gr.Column():
            output_image = gr.Image(label="Analysis Results")
            prediction_label = gr.Label(label="Prediction")
            confidence_score = gr.Number(label="Confidence Score")

    # Static explanation shown beneath the controls
    gr.Markdown("### How it works")
    gr.Markdown("""
    This application uses a deep learning model based on DenseNet121 architecture to classify chest CT scans as either 'Normal' or 'Cancer'.
    
    The visualization shows:
    - Left: Original CT scan
    - Middle: Feature map highlighting areas with distinctive patterns
    - Right: Overlay of the feature map on the original image
    
    The model was trained on a dataset of chest CT scans containing normal images and various types of lung cancer (adenocarcinoma, squamous cell carcinoma, and large cell carcinoma).
    """)

    # Clicking the button runs the prediction; its three return values map
    # onto the three output components in order
    submit_btn.click(
        fn=predict_and_explain,
        inputs=[input_image],
        outputs=[output_image, prediction_label, confidence_score],
    )

# Start the Gradio server
demo.launch()