IsmatS commited on
Commit
2bbade7
·
verified ·
1 Parent(s): 8b05cab

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -0
app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from PIL import Image
5
+ import matplotlib.pyplot as plt
6
+ import matplotlib
7
+ import cv2
8
+ import io
9
+ import os
10
+ matplotlib.use('Agg') # Use non-interactive backend
11
+
12
+ # Load the model using SavedModel format
13
+ MODEL_PATH = "chest_ct_binary_classifier_densenet_tf_20250427_182239"
14
+ model = tf.saved_model.load(MODEL_PATH)
15
+ infer = model.signatures["serving_default"] # Get the inference function
16
+
17
+ # Get input and output tensor names
18
+ input_tensor_name = list(infer.structured_input_signature[1].keys())[0]
19
+ output_tensor_name = list(infer.structured_outputs.keys())[0]
20
+
21
+ # Image size - matching what your model was trained on
22
+ IMG_SIZE = 256
23
+
24
+ # Function for preprocessing
25
+ def preprocess_image(image):
26
+ img = Image.fromarray(image).convert('RGB')
27
+ img = img.resize((IMG_SIZE, IMG_SIZE))
28
+ img_array = np.array(img) / 255.0
29
+ return np.expand_dims(img_array, axis=0).astype(np.float32) # Cast to float32 for TF
30
+
31
+ # Make prediction with the SavedModel
32
+ def predict_with_saved_model(image_tensor):
33
+ # Create the input tensor with the right name
34
+ input_dict = {input_tensor_name: image_tensor}
35
+ # Run inference
36
+ output = infer(**input_dict)
37
+ # Get the prediction value
38
+ prediction = output[output_tensor_name].numpy()[0][0]
39
+ return prediction
40
+
41
+ # Generate Grad-CAM using the SavedModel
42
+ # Note: Grad-CAM is more complex with SavedModel format, so we'll use a simplified approach
43
+ def generate_attention_map(img_array, prediction):
44
+ # Since getting Grad-CAM from SavedModel is complex, let's use a simplified heatmap
45
+ # This is a placeholder - in production you may want to implement a proper CAM
46
+
47
+ # For demo purposes, we'll create a simple attention map based on image features
48
+ gray = cv2.cvtColor(img_array[0].astype(np.float32), cv2.COLOR_RGB2GRAY)
49
+ blur = cv2.GaussianBlur(gray, (5, 5), 0)
50
+
51
+ # Use simple edge detection as a proxy for "interesting" regions
52
+ sobelx = cv2.Sobel(blur, cv2.CV_64F, 1, 0, ksize=3)
53
+ sobely = cv2.Sobel(blur, cv2.CV_64F, 0, 1, ksize=3)
54
+ magnitude = np.sqrt(sobelx**2 + sobely**2)
55
+
56
+ # Normalize to 0-1
57
+ magnitude = (magnitude - magnitude.min()) / (magnitude.max() - magnitude.min() + 1e-8)
58
+
59
+ # Apply sigmoid weighting based on prediction (higher probability = more intensity)
60
+ weight = 0.5 + (prediction - 0.5) * 0.5 # Scale between 0.5-1 based on prediction
61
+ magnitude = magnitude * weight
62
+
63
+ # Apply colormap
64
+ heatmap = cv2.applyColorMap(np.uint8(255 * magnitude), cv2.COLORMAP_JET)
65
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
66
+
67
+ return heatmap, magnitude
68
+
69
+ # Prediction function with visualization
70
+ def predict_and_explain(image):
71
+ if image is None:
72
+ return None, "Please upload an image.", 0.0
73
+
74
+ # Preprocess the image
75
+ preprocessed = preprocess_image(image)
76
+
77
+ # Make prediction
78
+ prediction = predict_with_saved_model(preprocessed)
79
+
80
+ # Generate attention map
81
+ heatmap, attention = generate_attention_map(preprocessed, prediction)
82
+
83
+ # Create overlay
84
+ original_resized = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
85
+ superimposed = (0.6 * original_resized) + (0.4 * heatmap)
86
+ superimposed = superimposed.astype(np.uint8)
87
+
88
+ # Create visualization with matplotlib
89
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5))
90
+
91
+ axes[0].imshow(original_resized)
92
+ axes[0].set_title("Original CT Scan")
93
+ axes[0].axis('off')
94
+
95
+ axes[1].imshow(heatmap)
96
+ axes[1].set_title("Feature Map")
97
+ axes[1].axis('off')
98
+
99
+ axes[2].imshow(superimposed)
100
+ axes[2].set_title(f"Overlay")
101
+ axes[2].axis('off')
102
+
103
+ # Add prediction information
104
+ result_text = f"{'Cancer' if prediction > 0.5 else 'Normal'} (Confidence: {abs(prediction if prediction > 0.5 else 1-prediction):.2%})"
105
+ fig.suptitle(result_text, fontsize=16)
106
+
107
+ # Convert plot to image
108
+ buf = io.BytesIO()
109
+ plt.tight_layout()
110
+ plt.savefig(buf, format='png')
111
+ plt.close(fig)
112
+ buf.seek(0)
113
+ result_image = np.array(Image.open(buf))
114
+
115
+ # Return prediction information
116
+ prediction_class = "Cancer" if prediction > 0.5 else "Normal"
117
+ confidence = float(prediction if prediction > 0.5 else 1-prediction)
118
+
119
+ return result_image, prediction_class, confidence
120
+
121
+ # Create Gradio interface
122
+ with gr.Blocks(title="Chest CT Scan Cancer Detection") as demo:
123
+ gr.Markdown("# Chest CT Scan Cancer Detection")
124
+ gr.Markdown("Upload a chest CT scan image to detect the presence of cancer.")
125
+
126
+ with gr.Row():
127
+ with gr.Column():
128
+ input_image = gr.Image(label="Upload CT Scan Image", type="numpy")
129
+ submit_btn = gr.Button("Analyze Image")
130
+
131
+ with gr.Column():
132
+ output_image = gr.Image(label="Analysis Results")
133
+ prediction_label = gr.Label(label="Prediction")
134
+ confidence_score = gr.Number(label="Confidence Score")
135
+
136
+ gr.Markdown("### How it works")
137
+ gr.Markdown("""
138
+ This application uses a deep learning model based on DenseNet121 architecture to classify chest CT scans as either 'Normal' or 'Cancer'.
139
+
140
+ The visualization shows:
141
+ - Left: Original CT scan
142
+ - Middle: Feature map highlighting areas with distinctive patterns
143
+ - Right: Overlay of the feature map on the original image
144
+
145
+ The model was trained on a dataset of chest CT scans containing normal images and various types of lung cancer (adenocarcinoma, squamous cell carcinoma, and large cell carcinoma).
146
+ """)
147
+
148
+ submit_btn.click(
149
+ predict_and_explain,
150
+ inputs=input_image,
151
+ outputs=[output_image, prediction_label, confidence_score]
152
+ )
153
+
154
+ demo.launch()