JohanBeytell commited on
Commit
b6208cf
Β·
verified Β·
1 Parent(s): 6b1e835

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +595 -0
app.py ADDED
@@ -0,0 +1,595 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow.keras.models import load_model
5
+ from tensorflow.keras.preprocessing.image import img_to_array, array_to_img
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib.cm as cm
8
+ import cv2
9
+ import gradio as gr
10
+ from PIL import Image
11
+ import io
12
+ import tempfile
13
+ from datetime import datetime
14
+
15
+ # Global variables
16
+ model = None
17
+ class_labels = {0: 'no', 1: 'yes'}
18
+ IMG_WIDTH, IMG_HEIGHT = 128, 128
19
+
20
+ # --- MODEL LOADING FUNCTION ---
21
+ def load_brain_tumor_model():
22
+ """Load the brain tumor detection model from the file system"""
23
+ global model
24
+
25
+ # Common model file names to check
26
+ model_paths = [
27
+ 'brain_tumor_classifier_v3.h5',
28
+ 'model.h5',
29
+ 'brain_tumor_model.h5',
30
+ 'brain_tumor_classifier.h5'
31
+ ]
32
+
33
+ for model_path in model_paths:
34
+ if os.path.exists(model_path):
35
+ try:
36
+ model = load_model(model_path)
37
+ print(f"βœ… Model loaded successfully from {model_path}")
38
+ return True
39
+ except Exception as e:
40
+ print(f"❌ Error loading model from {model_path}: {str(e)}")
41
+ continue
42
+
43
+ print("❌ No valid model file found. Please ensure your model is in the root directory.")
44
+ return False
45
+
46
+ # Load model on startup
47
+ model_loaded = load_brain_tumor_model()
48
+
49
+ # --- IMAGE PREPROCESSING FUNCTIONS ---
50
+ def preprocess_image(image, target_size=(128, 128)):
51
+ """
52
+ Preprocess uploaded image for model prediction
53
+ """
54
+ if image is None:
55
+ return None, "No image provided"
56
+
57
+ try:
58
+ # Convert to PIL Image if needed
59
+ if not isinstance(image, Image.Image):
60
+ image = Image.fromarray(image)
61
+
62
+ # Convert to RGB if needed
63
+ if image.mode != 'RGB':
64
+ image = image.convert('RGB')
65
+
66
+ # Resize image
67
+ image_resized = image.resize(target_size, Image.Resampling.LANCZOS)
68
+
69
+ # Convert to grayscale for display (optional)
70
+ image_gray = image_resized.convert('L').convert('RGB')
71
+
72
+ # Convert to array and normalize
73
+ img_array = img_to_array(image_resized) / 255.0
74
+
75
+ return image_resized, image_gray, img_array, "βœ… Image preprocessed successfully"
76
+
77
+ except Exception as e:
78
+ return None, None, None, f"❌ Error preprocessing image: {str(e)}"
79
+
80
+ # --- ENHANCED GRAD-CAM++ FUNCTIONS ---
81
+ def make_gradcampp_heatmap(img_array, model, last_conv_layer_name='last_conv_layer', pred_index=None):
82
+ """
83
+ Creates an improved Grad-CAM++ heatmap with better numerical stability.
84
+ """
85
+ if model is None:
86
+ return None
87
+
88
+ try:
89
+ grad_model = tf.keras.models.Model(
90
+ inputs=model.input,
91
+ outputs=[model.get_layer(last_conv_layer_name).output, model.output]
92
+ )
93
+
94
+ with tf.GradientTape(persistent=True) as tape1:
95
+ with tf.GradientTape(persistent=True) as tape2:
96
+ with tf.GradientTape() as tape3:
97
+ conv_outputs, predictions = grad_model(img_array)
98
+ if pred_index is None:
99
+ pred_index = tf.argmax(predictions[0])
100
+ class_channel = predictions[:, pred_index]
101
+
102
+ grads = tape3.gradient(class_channel, conv_outputs)
103
+ first_derivative = tape2.gradient(class_channel, conv_outputs)
104
+ second_derivative = tape1.gradient(first_derivative, conv_outputs)
105
+
106
+ del tape1, tape2
107
+
108
+ eps = 1e-8
109
+ alpha_num = second_derivative
110
+ alpha_denom = 2.0 * second_derivative + tf.reduce_sum(conv_outputs * grads, axis=[1, 2], keepdims=True)
111
+ alpha_denom = tf.where(tf.abs(alpha_denom) < eps, tf.ones_like(alpha_denom) * eps, alpha_denom)
112
+ alphas = alpha_num / alpha_denom
113
+
114
+ weights = tf.reduce_sum(alphas * tf.nn.relu(grads), axis=[1, 2])
115
+ weights = tf.nn.softmax(weights, axis=-1)
116
+
117
+ weights_reshaped = tf.reshape(weights, (1, 1, 1, -1))
118
+ heatmap = tf.reduce_sum(weights_reshaped * conv_outputs, axis=-1)
119
+ heatmap = tf.squeeze(heatmap, axis=0)
120
+
121
+ heatmap = tf.nn.relu(heatmap)
122
+ heatmap_np = heatmap.numpy()
123
+
124
+ heatmap_min = np.min(heatmap_np)
125
+ heatmap_max = np.max(heatmap_np)
126
+ if heatmap_max > heatmap_min:
127
+ heatmap_np = (heatmap_np - heatmap_min) / (heatmap_max - heatmap_min)
128
+ else:
129
+ heatmap_np = np.zeros_like(heatmap_np)
130
+
131
+ return heatmap_np
132
+
133
+ except Exception as e:
134
+ print(f"Error in Grad-CAM++: {str(e)}")
135
+ return None
136
+
137
+ def create_heatmap_visualizations(heatmap, img_shape):
138
+ """Create multiple heatmap visualizations with different color schemes"""
139
+ heatmap_resized = cv2.resize(heatmap, (img_shape[1], img_shape[0]), interpolation=cv2.INTER_CUBIC)
140
+ heatmap_smooth = cv2.GaussianBlur(heatmap_resized, (5, 5), 0)
141
+ heatmap_enhanced = cv2.equalizeHist(np.uint8(255 * heatmap_smooth)) / 255.0
142
+
143
+ visualizations = {
144
+ 'jet': {'heatmap': heatmap_smooth, 'colormap': 'jet', 'title': 'Jet Heatmap'},
145
+ 'hot': {'heatmap': heatmap_smooth, 'colormap': 'hot', 'title': 'Hot Heatmap'},
146
+ 'plasma': {'heatmap': heatmap_enhanced, 'colormap': 'plasma', 'title': 'Plasma Heatmap'},
147
+ 'viridis': {'heatmap': heatmap_enhanced, 'colormap': 'viridis', 'title': 'Viridis Heatmap'},
148
+ 'inferno': {'heatmap': heatmap_smooth, 'colormap': 'inferno', 'title': 'Inferno Heatmap'},
149
+ 'cool': {'heatmap': heatmap_smooth, 'colormap': 'cool', 'title': 'Cool Heatmap'}
150
+ }
151
+
152
+ return visualizations
153
+
154
+ def superimpose_gradcam_enhanced(img, heatmap, colormap='jet', alpha=0.4):
155
+ """Enhanced superimposition with different colormaps"""
156
+ if not isinstance(img, np.ndarray):
157
+ img = img_to_array(img)
158
+ if img.max() > 1.0:
159
+ img = img / 255.0
160
+
161
+ heatmap_resized = cv2.resize(heatmap, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC)
162
+ heatmap_uint8 = np.uint8(255 * heatmap_resized)
163
+
164
+ if hasattr(plt, 'colormaps'):
165
+ cmap = plt.colormaps[colormap]
166
+ else:
167
+ cmap = cm.get_cmap(colormap)
168
+
169
+ colored_heatmap = cmap(heatmap_uint8)[:, :, :3]
170
+
171
+ gamma = 2.2
172
+ img_gamma = np.power(img, 1/gamma)
173
+ colored_heatmap_gamma = np.power(colored_heatmap, 1/gamma)
174
+
175
+ blended_gamma = (colored_heatmap_gamma * alpha) + (img_gamma * (1 - alpha))
176
+ superimposed_img_float = np.power(blended_gamma, gamma)
177
+ superimposed_img_float = np.clip(superimposed_img_float, 0, 1)
178
+
179
+ return superimposed_img_float
180
+
181
+ # --- PREDICTION AND VISUALIZATION FUNCTIONS ---
182
+ def predict_brain_tumor(image):
183
+ """Make prediction on uploaded image"""
184
+ if not model_loaded or model is None:
185
+ return "❌ Model not available. Please check if the model file exists in the space.", None, None
186
+
187
+ if image is None:
188
+ return "❌ No image provided.", None, None
189
+
190
+ try:
191
+ # Preprocess image
192
+ processed_img, gray_img, img_array, preprocess_msg = preprocess_image(image)
193
+ if processed_img is None:
194
+ return preprocess_msg, None, None
195
+
196
+ # Make prediction
197
+ img_batch = np.expand_dims(img_array, axis=0)
198
+ prediction = model.predict(img_batch, verbose=0)[0][0]
199
+
200
+ # Interpret results
201
+ predicted_class = int(round(prediction))
202
+ predicted_label = class_labels[predicted_class]
203
+ confidence = prediction if predicted_class == 1 else 1 - prediction
204
+
205
+ # Create result message
206
+ if predicted_class == 1:
207
+ status_emoji = "⚠️"
208
+ status_text = "**TUMOR DETECTED**"
209
+ status_color = "red"
210
+ else:
211
+ status_emoji = "βœ…"
212
+ status_text = "**NO TUMOR DETECTED**"
213
+ status_color = "green"
214
+
215
+ result_msg = f"""
216
+ ## 🧠 Brain Tumor Detection Results
217
+
218
+ **Prediction:** {predicted_label.upper()}
219
+ **Confidence:** {confidence:.1%}
220
+ **Raw Score:** {prediction:.4f}
221
+
222
+ {status_emoji} {status_text}
223
+ """
224
+
225
+ return result_msg, processed_img, gray_img
226
+
227
+ except Exception as e:
228
+ return f"❌ Error during prediction: {str(e)}", None, None
229
+
230
+ def create_detailed_analysis(image):
231
+ """Create comprehensive Grad-CAM++ analysis"""
232
+ if not model_loaded or model is None or image is None:
233
+ return "❌ Please upload an image for analysis."
234
+
235
+ try:
236
+ # Preprocess and predict
237
+ processed_img, gray_img, img_array, _ = preprocess_image(image)
238
+ img_batch = np.expand_dims(img_array, axis=0)
239
+ prediction = model.predict(img_batch, verbose=0)[0][0]
240
+
241
+ predicted_class = int(round(prediction))
242
+ predicted_label = class_labels[predicted_class]
243
+ confidence = prediction if predicted_class == 1 else 1 - prediction
244
+
245
+ # Generate heatmap
246
+ heatmap = make_gradcampp_heatmap(img_batch, model)
247
+ if heatmap is None:
248
+ return "❌ Error generating heatmap."
249
+
250
+ # Create visualizations
251
+ visualizations = create_heatmap_visualizations(heatmap, img_array.shape)
252
+
253
+ # Create comprehensive plot
254
+ fig = plt.figure(figsize=(20, 12))
255
+ color = 'green' if predicted_class == 0 else 'red'
256
+ fig.suptitle(f'Comprehensive Grad-CAM++ Analysis\nPredicted: {predicted_label.upper()} ({confidence:.2%})',
257
+ fontsize=16, fontweight='bold', color=color)
258
+
259
+ # Original image
260
+ plt.subplot(3, 5, 1)
261
+ plt.imshow(processed_img)
262
+ plt.title("Original Image", fontsize=12, fontweight='bold')
263
+ plt.axis('off')
264
+
265
+ # Different heatmap visualizations
266
+ viz_names = ['jet', 'hot', 'plasma', 'viridis']
267
+ for i, viz_name in enumerate(viz_names):
268
+ viz = visualizations[viz_name]
269
+ plt.subplot(3, 5, i + 2)
270
+ im = plt.imshow(viz['heatmap'], cmap=viz['colormap'])
271
+ plt.title(viz['title'], fontsize=12)
272
+ plt.axis('off')
273
+ plt.colorbar(im, fraction=0.046, pad=0.04)
274
+
275
+ # More heatmap styles
276
+ viz_names2 = ['inferno', 'cool']
277
+ for i, viz_name in enumerate(viz_names2):
278
+ viz = visualizations[viz_name]
279
+ plt.subplot(3, 5, i + 6)
280
+ im = plt.imshow(viz['heatmap'], cmap=viz['colormap'])
281
+ plt.title(viz['title'], fontsize=12)
282
+ plt.axis('off')
283
+ plt.colorbar(im, fraction=0.046, pad=0.04)
284
+
285
+ # Attention profile
286
+ plt.subplot(3, 5, 8)
287
+ attention_profile = np.mean(heatmap, axis=1)
288
+ plt.plot(attention_profile, range(len(attention_profile)), 'b-', linewidth=2)
289
+ plt.title('Vertical Attention Profile', fontsize=12)
290
+ plt.xlabel('Attention Intensity')
291
+ plt.ylabel('Image Height')
292
+ plt.gca().invert_yaxis()
293
+ plt.grid(True, alpha=0.3)
294
+
295
+ # Statistics
296
+ plt.subplot(3, 5, 9)
297
+ stats_text = f"""Heatmap Statistics:
298
+ Mean: {np.mean(heatmap):.3f}
299
+ Std: {np.std(heatmap):.3f}
300
+ Max: {np.max(heatmap):.3f}
301
+ Min: {np.min(heatmap):.3f}
302
+
303
+ Prediction:
304
+ Confidence: {confidence:.1%}
305
+ Raw Score: {prediction:.4f}
306
+ Class: {predicted_label}"""
307
+
308
+ plt.text(0.1, 0.5, stats_text, transform=plt.gca().transAxes, fontsize=10,
309
+ verticalalignment='center', bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue"))
310
+ plt.axis('off')
311
+
312
+ # Superimposed views
313
+ superimposed_colormaps = ['jet', 'hot', 'plasma', 'viridis', 'inferno']
314
+ for i, cmap_name in enumerate(superimposed_colormaps):
315
+ superimposed_img = superimpose_gradcam_enhanced(img_array, heatmap, colormap=cmap_name, alpha=0.4)
316
+ plt.subplot(3, 5, i + 11)
317
+ plt.imshow(superimposed_img)
318
+ plt.title(f'Superimposed ({cmap_name.title()})', fontsize=12)
319
+ plt.axis('off')
320
+
321
+ plt.tight_layout()
322
+ plt.subplots_adjust(top=0.92)
323
+
324
+ # Save to temporary file and return
325
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
326
+ plt.savefig(temp_file.name, dpi=300, bbox_inches='tight')
327
+ plt.close()
328
+
329
+ return temp_file.name
330
+
331
+ except Exception as e:
332
+ return f"❌ Error creating detailed analysis: {str(e)}"
333
+
334
+ def create_quick_analysis(image):
335
+ """Create quick 2x3 comparison view"""
336
+ if not model_loaded or model is None or image is None:
337
+ return "❌ Please upload an image for analysis."
338
+
339
+ try:
340
+ # Preprocess and predict
341
+ processed_img, gray_img, img_array, _ = preprocess_image(image)
342
+ img_batch = np.expand_dims(img_array, axis=0)
343
+ prediction = model.predict(img_batch, verbose=0)[0][0]
344
+
345
+ predicted_class = int(round(prediction))
346
+ predicted_label = class_labels[predicted_class]
347
+ confidence = prediction if predicted_class == 1 else 1 - prediction
348
+
349
+ # Generate heatmap
350
+ heatmap = make_gradcampp_heatmap(img_batch, model)
351
+ if heatmap is None:
352
+ return "❌ Error generating heatmap."
353
+
354
+ # Create quick visualization
355
+ fig, axes = plt.subplots(2, 3, figsize=(15, 10))
356
+ color = 'green' if predicted_class == 0 else 'red'
357
+ fig.suptitle(f'Quick Grad-CAM++ Analysis | Predicted: {predicted_label.upper()} ({confidence:.2%})',
358
+ fontsize=14, fontweight='bold', color=color)
359
+
360
+ # Original image
361
+ axes[0, 0].imshow(processed_img)
362
+ axes[0, 0].set_title("Original Image")
363
+ axes[0, 0].axis('off')
364
+
365
+ # Jet heatmap
366
+ heatmap_resized = cv2.resize(heatmap, (IMG_WIDTH, IMG_HEIGHT))
367
+ im1 = axes[0, 1].imshow(heatmap_resized, cmap='jet')
368
+ axes[0, 1].set_title("Jet Heatmap")
369
+ axes[0, 1].axis('off')
370
+ plt.colorbar(im1, ax=axes[0, 1], fraction=0.046)
371
+
372
+ # Plasma heatmap
373
+ im2 = axes[0, 2].imshow(heatmap_resized, cmap='plasma')
374
+ axes[0, 2].set_title("Plasma Heatmap")
375
+ axes[0, 2].axis('off')
376
+ plt.colorbar(im2, ax=axes[0, 2], fraction=0.046)
377
+
378
+ # Superimposed views
379
+ superimposed_jet = superimpose_gradcam_enhanced(img_array, heatmap, 'jet')
380
+ axes[1, 0].imshow(superimposed_jet)
381
+ axes[1, 0].set_title("Superimposed (Jet)")
382
+ axes[1, 0].axis('off')
383
+
384
+ superimposed_hot = superimpose_gradcam_enhanced(img_array, heatmap, 'hot')
385
+ axes[1, 1].imshow(superimposed_hot)
386
+ axes[1, 1].set_title("Superimposed (Hot)")
387
+ axes[1, 1].axis('off')
388
+
389
+ superimposed_viridis = superimpose_gradcam_enhanced(img_array, heatmap, 'viridis')
390
+ axes[1, 2].imshow(superimposed_viridis)
391
+ axes[1, 2].set_title("Superimposed (Viridis)")
392
+ axes[1, 2].axis('off')
393
+
394
+ plt.tight_layout()
395
+
396
+ # Save to temporary file and return
397
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
398
+ plt.savefig(temp_file.name, dpi=300, bbox_inches='tight')
399
+ plt.close()
400
+
401
+ return temp_file.name
402
+
403
+ except Exception as e:
404
+ return f"❌ Error creating quick analysis: {str(e)}"
405
+
406
+ # --- GRADIO APP INTERFACE ---
407
+ def create_gradio_app():
408
+ """Create the main Gradio interface"""
409
+
410
+ # Custom CSS for better styling
411
+ custom_css = """
412
+ .gradio-container {
413
+ font-family: 'Arial', sans-serif;
414
+ }
415
+ .main-header {
416
+ text-align: center;
417
+ background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
418
+ padding: 2rem;
419
+ border-radius: 10px;
420
+ color: white;
421
+ margin-bottom: 2rem;
422
+ }
423
+ .status-positive {
424
+ color: #22c55e;
425
+ font-weight: bold;
426
+ }
427
+ .status-negative {
428
+ color: #ef4444;
429
+ font-weight: bold;
430
+ }
431
+ """
432
+
433
+ with gr.Blocks(title="🧠 Brain Tumor Detection - Grad-CAM++", theme=gr.themes.Soft(), css=custom_css) as app:
434
+
435
+ gr.HTML("""
436
+ <div class="main-header">
437
+ <h1>🧠 Brain Tumor Detection with Enhanced Grad-CAM++</h1>
438
+ <p>Advanced AI-powered MRI analysis with explainable attention visualization</p>
439
+ </div>
440
+ """)
441
+
442
+ # Model status display
443
+ model_status = "βœ… Model loaded successfully" if model_loaded else "❌ Model not available"
444
+ gr.Markdown(f"**Model Status:** {model_status}")
445
+
446
+ if not model_loaded:
447
+ gr.Markdown("⚠️ **Warning**: Model file not found. Please ensure your trained model (.h5) is in the space's root directory.")
448
+
449
+ gr.Markdown("""
450
+ ## πŸ“– How to Use:
451
+ 1. **Upload an MRI brain scan** (JPEG, PNG, or other image formats)
452
+ 2. **View automatic preprocessing** and prediction results
453
+ 3. **Choose analysis type**: Quick for rapid assessment, Detailed for comprehensive visualization
454
+ 4. **Download results** for further analysis or documentation
455
+ """)
456
+
457
+ with gr.Row():
458
+ with gr.Column(scale=2):
459
+ input_image = gr.Image(
460
+ label="πŸ“€ Upload MRI Brain Scan",
461
+ type="pil",
462
+ height=400
463
+ )
464
+
465
+ with gr.Column(scale=1):
466
+ gr.Markdown("### πŸ”„ Preprocessing Preview")
467
+ processed_image = gr.Image(
468
+ label="Processed (128x128 RGB)",
469
+ height=180,
470
+ interactive=False
471
+ )
472
+ grayscale_image = gr.Image(
473
+ label="Grayscale Preview",
474
+ height=180,
475
+ interactive=False
476
+ )
477
+
478
+ # Prediction results
479
+ gr.Markdown("## 🎯 Prediction Results")
480
+ prediction_output = gr.Markdown(value="*Upload an image to see predictions...*")
481
+
482
+ # Analysis buttons
483
+ gr.Markdown("## πŸ”¬ Grad-CAM++ Analysis")
484
+ gr.Markdown("Choose your preferred analysis type:")
485
+
486
+ with gr.Row():
487
+ quick_btn = gr.Button(
488
+ "⚑ Quick Analysis (2x3 Grid)",
489
+ variant="secondary",
490
+ size="lg",
491
+ scale=1
492
+ )
493
+ detailed_btn = gr.Button(
494
+ "πŸ”¬ Detailed Analysis (3x5 Grid)",
495
+ variant="primary",
496
+ size="lg",
497
+ scale=1
498
+ )
499
+
500
+ # Analysis output
501
+ analysis_output = gr.Image(
502
+ label="πŸ“Š Analysis Results",
503
+ height=700,
504
+ interactive=False,
505
+ show_download_button=True
506
+ )
507
+
508
+ # Information sections
509
+ with gr.Row():
510
+ with gr.Column():
511
+ gr.Markdown("""
512
+ ### ⚑ Quick Analysis Features:
513
+ - **2x3 Grid Layout** for rapid evaluation
514
+ - **Original Image** with preprocessing
515
+ - **Jet & Plasma Heatmaps** with colorbars
516
+ - **3 Superimposed Views** (Jet, Hot, Viridis)
517
+ - **Fast Processing** (~2-3 seconds)
518
+ - **Perfect for screening** multiple images
519
+ """)
520
+
521
+ with gr.Column():
522
+ gr.Markdown("""
523
+ ### πŸ”¬ Detailed Analysis Features:
524
+ - **3x5 Grid Layout** for comprehensive analysis
525
+ - **6 Heatmap Color Schemes** with individual colorbars
526
+ - **Attention Profile Plot** showing vertical focus
527
+ - **Statistical Analysis Panel** with quantitative metrics
528
+ - **5 Enhanced Superimposed Views** with gamma correction
529
+ - **Clinical-grade visualization** for detailed examination
530
+ """)
531
+
532
+ gr.Markdown("""
533
+ ---
534
+ ### 🎨 Color Scheme Guide:
535
+ - **πŸ”₯ Jet**: Classic blue β†’ green β†’ yellow β†’ red progression (high contrast)
536
+ - **πŸŒ‹ Hot**: Black β†’ red β†’ orange β†’ yellow (heat-like visualization)
537
+ - **🌌 Plasma**: Purple β†’ pink β†’ yellow (scientifically accurate)
538
+ - **🌿 Viridis**: Dark blue β†’ green β†’ yellow (perceptually uniform)
539
+ - **πŸ”₯ Inferno**: Black β†’ purple β†’ red β†’ yellow (high contrast heat)
540
+ - **❄️ Cool**: Cyan β†’ blue β†’ magenta (cool color palette)
541
+
542
+ ### πŸ“Š Understanding the Results:
543
+ - **Bright regions** in heatmaps indicate areas the AI model focuses on
544
+ - **Different color schemes** can reveal different aspects of attention patterns
545
+ - **Confidence scores** above 80% are generally considered reliable
546
+ - **Superimposed views** help correlate AI attention with anatomical structures
547
+ """)
548
+
549
+ # Footer
550
+ gr.Markdown("""
551
+ ---
552
+ **⚠️ Medical Disclaimer**: This tool is for research and educational purposes only.
553
+ Always consult qualified medical professionals for clinical diagnosis and treatment decisions.
554
+ """)
555
+
556
+ # Event handlers
557
+ def predict_and_update(image):
558
+ result, processed, grayscale = predict_brain_tumor(image)
559
+ return result, processed, grayscale
560
+
561
+ def quick_analysis_handler(image):
562
+ if not model_loaded:
563
+ return None
564
+ return create_quick_analysis(image)
565
+
566
+ def detailed_analysis_handler(image):
567
+ if not model_loaded:
568
+ return None
569
+ return create_detailed_analysis(image)
570
+
571
+ # Connect event handlers
572
+ input_image.change(
573
+ fn=predict_and_update,
574
+ inputs=[input_image],
575
+ outputs=[prediction_output, processed_image, grayscale_image]
576
+ )
577
+
578
+ quick_btn.click(
579
+ fn=quick_analysis_handler,
580
+ inputs=[input_image],
581
+ outputs=[analysis_output]
582
+ )
583
+
584
+ detailed_btn.click(
585
+ fn=detailed_analysis_handler,
586
+ inputs=[input_image],
587
+ outputs=[analysis_output]
588
+ )
589
+
590
+ return app
591
+
592
+ # --- LAUNCH THE APP ---
593
+ if __name__ == "__main__":
594
+ app = create_gradio_app()
595
+ app.launch()