danielle2003 committed on
Commit e7b2a60 · verified · 1 Parent(s): 358a241

Create app.py

Files changed (1): app.py ADDED (+1545 lines)
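"""KidneyScan AI: a Streamlit app for kidney ultrasound classification.

Classifies kidney scans into four categories (Cyst, Normal, Stone, Tumor)
using either a TensorFlow VGG16 transfer-learning model or a custom PyTorch
CNN, and visualizes model attention with Grad-CAM overlays.
"""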
import cv2
import streamlit as st

st.set_page_config(layout="wide")

import streamlit.components.v1 as components
import time
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from PIL import Image
from tf_keras_vis.gradcam import Gradcam
from io import BytesIO
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    roc_curve,
    auc,
    precision_recall_curve,
    average_precision_score,
)
from sklearn.preprocessing import label_binarize
import seaborn as sns
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
import torch.nn.functional as F
from gradcam import GradCAM  # Local GradCAM implementation for the PyTorch model

# Session-state defaults: the TensorFlow model is loaded on first run, and the
# menu flags control which panel is rendered.
if "model" not in st.session_state:
    st.session_state.model = tf.keras.models.load_model("models/best_model.h5")
if "framework" not in st.session_state:
    st.session_state.framework = "TensorFlow"
if "menu" not in st.session_state:
    st.session_state.menu = "1"

if st.session_state.menu == "1":
    st.session_state.show_summary = True
    st.session_state.show_arch = False
    st.session_state.show_desc = False
elif st.session_state.menu == "2":
    st.session_state.show_arch = True
    st.session_state.show_summary = False
    st.session_state.show_desc = False
elif st.session_state.menu == "3":
    st.session_state.show_arch = False
    st.session_state.show_summary = False
    st.session_state.show_desc = True
else:
    st.session_state.show_desc = True

import base64
import os
import tf_keras_vis

# ********************************************* #
# GRAD-CAM
# ********************************************* #
if st.session_state.framework == "TensorFlow":
    gradcam = Gradcam(st.session_state.model, model_modifier=None, clone=False)

    def generate_gradcam(pil_image, target_class):
        # Convert PIL to array and preprocess for VGG16
        img_array = np.array(pil_image)
        img_preprocessed = tf.keras.applications.vgg16.preprocess_input(img_array.copy())
        img_tensor = tf.expand_dims(img_preprocessed, axis=0)

        # Score function: mean activation of the target class
        loss = lambda output: tf.reduce_mean(output[:, target_class])
        cam = gradcam(loss, img_tensor, penultimate_layer=-1)

        # Post-process the heatmap
        if cam.ndim > 2:
            cam = cam.squeeze()
        cam = np.maximum(cam, 0)
        cam = cv2.resize(cam, (224, 224))
        cam = cam / cam.max() if cam.max() > 0 else cam
        return cam

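# The heatmap comes back normalized to [0, 1], ready to be mapped through a
# matplotlib colormap and alpha-blended over the input image. A hypothetical
# call: generate_gradcam(pil_image, target_class=2) highlights the regions
# driving the "Stone" score.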
if st.session_state.framework == "PyTorch":
    # Target layer for CAM (note: conv4 is the deepest conv block; conv3 is used here)
    target_layer = st.session_state.model.conv3
    # gradcam = GradCAM(st.session_state.model, target_layer)

    def preprocess_image(image):
        preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        return preprocess(image).unsqueeze(0)  # Add batch dimension

    def generate_gradcams(image, target_class):
        # Preprocess the image and convert it to a tensor
        input_image = preprocess_image(image)

        # Instantiate GradCAM on the chosen layer
        gradcampy = GradCAM(st.session_state.model, target_layer)

        # Generate the CAM
        cam = gradcampy.generate(input_image, target_class)
        return cam


def convert_image_to_base64(pil_image):
    buffered = BytesIO()
    pil_image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


# -------------------------------------------------
# PyTorch model definition
class KidneyCNN(nn.Module):
    def __init__(self, num_classes=4):
        super(KidneyCNN, self).__init__()

        # Convolutional layers
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)

        # Batch normalization layers
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)

        # Max pooling layer (shared across blocks)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        # Fully connected layers
        self.fc1 = nn.Linear(256 * 14 * 14, 512)
        self.fc2 = nn.Linear(512, num_classes)

        # Dropout for regularization
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # Conv blocks 1-4: conv -> batch norm -> ReLU -> 2x2 max pool
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.pool(F.relu(self.bn4(self.conv4(x))))

        # Flatten
        x = x.view(x.size(0), -1)

        # Fully connected layers
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)
        return x
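# Sanity check on fc1's input size: four 2x2 max-pool stages halve the
# 224x224 input four times, giving 224 / 2**4 = 14, so conv4's output is
# 256 x 14 x 14, matching nn.Linear(256 * 14 * 14, 512) above.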


if st.session_state.framework == "PyTorch":
    # torch.load is expected to return the full module here (eval() is called on it)
    st.session_state.model = torch.load('models/kidney_model .pth', map_location=torch.device('cpu'))
    st.session_state.model.eval()
    print(type(st.session_state.model))

# *********************************************
# LOADING TEST DATASET
# *********************************************
if st.session_state.framework == "TensorFlow":
    test_dir = "test"
    BATCH_SIZE = 32
    IMG_SIZE = (224, 224)
    test_dataset = tf.keras.utils.image_dataset_from_directory(
        test_dir, shuffle=False, batch_size=BATCH_SIZE, image_size=IMG_SIZE
    )
    class_names = test_dataset.class_names
    class_labels = class_names
    num_classes = len(class_names)

    # One-hot encode the integer labels
    def one_hot_encode(image, label):
        label = tf.one_hot(label, num_classes)
        return image, label

    test_dataset = test_dataset.map(one_hot_encode)

elif st.session_state.framework == "PyTorch":
    test_dir = "test"
    BATCH_SIZE = 32
    IMG_SIZE = (224, 224)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    test_dataset = datasets.ImageFolder(root='test', transform=transform)
    class_names = test_dataset.classes
    class_labels = class_names
    num_classes = len(class_names)
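# Both loaders infer class names from the subdirectory layout, so the test
# set is expected to look like test/Cyst/..., test/Normal/..., test/Stone/...
# and test/Tumor/....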

#######################################################
# --------------------------------------------------#
class_labels = ["Cyst", "Normal", "Stone", "Tumor"]


def load_tensorflow_model():
    tf_model = tf.keras.models.load_model("models/best_model.h5")
    return tf_model


if st.session_state.framework == "TensorFlow":

    def predict_image(image):
        time.sleep(2)
        image = image.resize((224, 224))
        image = np.expand_dims(image, axis=0)
        predictions = st.session_state.model.predict(image)
        return predictions

if st.session_state.framework == "PyTorch":
    logo_path = "images/pytorch.png"
    bg_color = "#FF5733"      # Warm red/orange
    bg_color_iv = "orange"
    model = "TENSORFLOW"      # Label shown on the "switch model" toggle

    def predict_image(image):
        # Preprocess the image to match the model's input requirements
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Standard ImageNet normalization
        ])
        image = transform(image).unsqueeze(0)  # Add batch dimension

        # Set the model to evaluation mode and run a forward pass
        st.session_state.model.eval()
        with torch.no_grad():  # Disable gradient calculation
            outputs = st.session_state.model(image)

        if outputs.shape[1] == 1:
            # Single output unit: sigmoid for binary classification
            probs = torch.sigmoid(outputs)
            prob_class_1 = probs[0].item()
            prob_class_0 = 1 - prob_class_1
        else:
            # Two or more output units: softmax over classes
            probs = torch.nn.functional.softmax(outputs, dim=1)
            prob_class_0 = probs[0, 0].item()
            prob_class_1 = probs[0, 1].item()

        print("Raw model output (logits):", outputs)
        return prob_class_0, prob_class_1, probs
else:
    logo_path = "images/tensorflow.png"
    bg_color = "orange"
    bg_color_iv = "#FF5733"
    model = "PYTORCH"
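# For the four-class kidney model the softmax branch above runs, so the
# returned `probs` tensor has one column per class; the results UI reads all
# four scores from it, while prob_class_0/prob_class_1 only cover the first
# two classes.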


# ******************* PyTorch model summary helper *******************
def get_layers_data(model, prefix=""):
    layers_data = []
    for name, layer in model.named_children():  # Iterate over direct children
        full_name = f"{prefix}.{name}" if prefix else name  # Track hierarchy

        try:
            shape = str(list(layer.parameters())[0].shape)  # Shape of the first parameter
        except Exception:
            shape = "N/A"

        param_count = sum(p.numel() for p in layer.parameters())  # Count parameters

        layers_data.append((full_name, layer.__class__.__name__, shape, f"{param_count:,}"))

        # Recurse into nested modules
        layers_data.extend(get_layers_data(layer, full_name))

    return layers_data
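# Each row is (name, type, shape, params). For KidneyCNN the first row would
# look something like ("conv1", "Conv2d", "torch.Size([32, 3, 3, 3])", "896"),
# since a 3x3 conv from 3 to 32 channels has 32*3*3*3 + 32 = 896 parameters.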


###########################################
main_bg_ext = "png"
main_bg = "images/bg1.jpg"

# Read and encode the logo image
with open(logo_path, "rb") as image_file:
    encoded_logo = base64.b64encode(image_file.read()).decode()

# Custom CSS to style the logo above the sidebar
st.markdown(
    f"""
    <style>
    /* Container for logo and text */
    .logo-text-container {{
        position: fixed;
        top: 20px;                /* Adjust vertical position */
        left: 30px;               /* Align with sidebar */
        display: flex;
        align-items: center;
        gap: 5px;
        width: 70%;
        z-index: 1000;
    }}

    /* Logo styling */
    .logo-text-container img {{
        width: 50px;              /* Adjust logo size */
        border-radius: 10px;      /* Optional: round edges */
        margin-top: -10px;
        margin-left: -5px;
    }}

    /* Bold text styling */
    .logo-text-container h1 {{
        font-family: Nunito;
        color: #0175C2;
        font-size: 28px;
        font-weight: bold;
        margin-right: 100px;
        padding: 0;
    }}
    .logo-text-container i {{
        font-family: Nunito;
        color: {bg_color};
        font-size: 15px;
        margin-right: 10px;
        padding: 0;
        margin-left: -18.5%;
        margin-top: 1%;
    }}

    /* Sidebar styling */
    section[data-testid="stSidebar"][aria-expanded="true"] {{
        margin-top: 100px !important;              /* Space for the logo */
        border-radius: 0 60px 0 60px !important;   /* Top-left and bottom-right corners */
        width: 200px !important;
        background: none;
        /* box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2); */
        /* border: 1px solid #FFD700; */
        margin-bottom: 1px !important;
        color: white !important;
    }}
    header[data-testid="stHeader"] {{
        /* background: transparent !important; */
        background: rgba(255, 255, 255, 0.05);
        backdrop-filter: blur(10px);
        margin-top: 0.5px !important;
        z-index: 1 !important;
        color: orange;
        font-family: "Times New Roman" !important;
        font-size: 18px !important;
        font-weight: bold !important;
        padding: 10px 20px;
        border-radius: 1px;
        box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.2);
        transition: all 0.3s ease-in-out;
        align-items: left;
        justify-content: center;
        width: 100%;
        height: 80px;
        border: 2px solid rgba(255, 255, 255, 0.4);  /* Light border */
    }}
    div[data-testid="stDecoration"] {{
        background-image: none;
    }}
    div[data-testid="stApp"] {{
        height: 100vh;                               /* Full viewport height */
        width: 99.5%;
        border-radius: 2px !important;
        margin-left: 5px;
        margin-right: 5px;
        margin-top: 0px;
        background: url(data:image/{main_bg_ext};base64,{base64.b64encode(open(main_bg, "rb").read()).decode()});
        background-size: cover;                      /* Ensure the image covers the full page */
        background-position: center;
        overflow: hidden;
    }}
    .content-container {{
        background: rgba(255, 255, 255, 0.05);
        backdrop-filter: blur(10px);
        border-radius: 1px;
        width: 28%;
        margin-left: 150px;
        margin-bottom: 10px;
        margin-right: 10px;
        padding: 0;
        border: 1px solid transparent;
        overflow-y: auto;      /* Enable vertical scrolling for the content */
        position: fixed;
        top: 10%;
        left: 60%;
        height: 89.5vh;
    }}
    .content-container-principal img {{
        margin-top: 260px;
        margin-left: 30px;
    }}
    .content-container-principal {{
        background-color: rgba(173, 216, 230, 0.5);  /* Light blue, 50% transparency */
        backdrop-filter: blur(10px);
        border-radius: 1px;
        width: 20%;
        margin: 10px;
        border: 1px solid transparent;
        overflow-y: auto;
        position: fixed;
        top: 7%;
        height: 84vh;
    }}
    .content-container-principal-in {{
        background-color: rgba(173, 216, 230, 0.1);
        backdrop-filter: blur(10px);
        border-radius: 1px;
        width: 100%;
        margin: 1px;
        border: 1px solid transparent;
        overflow-y: auto;
        position: fixed;
        height: 100.5vh;
        left: 0%;
        top: 5%;
    }}
    div[data-testid="stText"] {{
        background-color: rgba(173, 216, 230, 0.1);
        backdrop-filter: blur(10px);
        border-radius: 1px;
        width: 132% !important;
        margin-top: -36px;
        margin-bottom: 10px;
        margin-left: -220px !important;
        padding: 50px;
        padding-bottom: 20px;
        padding-top: 50px;
        border: 1px solid transparent;
        overflow-y: auto;
        height: 85vh !important;
    }}
    .content-container2 {{
        background-color: rgba(0, 0, 0, 0.1);
        backdrop-filter: blur(10px);
        border-radius: 1px;
        width: 90%;
        margin-left: 10px;
        margin-bottom: 160px;
        margin-right: 10px;
        padding: 0;
        border: 1px solid transparent;
        overflow-y: auto;
        position: fixed;
        top: 3%;
        left: 2.5%;
        height: 78vh;
    }}
    .content-container4 {{
        background-color: rgba(0, 0, 0, 0.1);
        backdrop-filter: blur(10px);
        width: 40%;
        margin-left: 10px;
        margin-bottom: 160px;
        margin-right: 10px;
        padding: 0;
        overflow-y: auto;
        position: fixed;
        top: 60%;
        left: 2.5%;
        height: 10vh;
    }}
    .content-container4 h3, p {{
        font-family: "Times New Roman" !important;
        font-size: 1rem;
        font-weight: bold;
        text-align: center;
    }}
    .content-container5 h3, p {{
        font-family: "Times New Roman" !important;
        font-size: 1rem;
        font-weight: bold;
        text-align: center;
    }}
    .content-container6 h3, p {{
        font-family: "Times New Roman" !important;
        font-size: 1rem;
        font-weight: bold;
        text-align: center;
    }}
    .content-container7 h3, p {{
        font-family: "Times New Roman" !important;
        font-size: 1rem;
        font-weight: bold;
        text-align: center;
    }}
    .content-container5 {{
        background-color: rgba(0, 0, 0, 0.1);
        backdrop-filter: blur(10px);
        width: 40%;
        margin-left: 180px;
        margin-bottom: 130px;
        margin-right: 10px;
        padding: 0;
        overflow-y: auto;
        position: fixed;
        top: 60%;
        left: 5.5%;
        height: 10vh;
    }}
    .content-container3 {{
        background-color: rgba(216, 216, 230, 0.5);
        backdrop-filter: blur(10px);
        border-radius: 1px;
        width: 92%;
        margin-left: 10px;
        margin-bottom: 160px;
        margin-right: 10px;
        padding: 0;
        border: 10px solid white;
        overflow-y: auto;
        position: fixed;
        top: 3%;
        left: 1.5%;
        height: 40vh;
    }}
    .content-container6 {{
        background-color: rgba(0, 0, 0, 0.1);
        backdrop-filter: blur(10px);
        width: 40%;
        margin-left: 10px;
        margin-bottom: 160px;
        margin-right: 10px;
        padding: 0;
        overflow-y: auto;
        position: fixed;
        top: 80%;
        left: 2.5%;
        height: 10vh;
    }}
    .content-container7 {{
        background-color: rgba(0, 0, 0, 0.1);
        backdrop-filter: blur(10px);
        width: 40%;
        margin-left: 180px;
        margin-bottom: 130px;
        margin-right: 10px;
        padding: 0;
        overflow-y: auto;
        position: fixed;
        top: 80%;
        left: 5.5%;
        height: 10vh;
    }}
    .content-container2 img {{
        width: 99%;
        height: 50%;
    }}
    .content-container3 img {{
        width: 100%;
        height: 100%;
    }}
    div.stButton > button {{
        background: rgba(255, 255, 255, 0.2);
        color: orange !important;
        font-family: "Times New Roman" !important;
        font-size: 18px !important;
        font-weight: bold !important;
        padding: 1px 2px;
        border: none;
        border-radius: 5px;
        box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.2);
        transition: all 0.3s ease-in-out;
        display: flex;
        align-items: left;
        justify-content: left;
        margin-left: -50px;
        width: 250px;
        height: 50px;
        backdrop-filter: blur(10px);
        z-index: 1000;
        text-align: left;      /* Align text to the left */
        padding-left: 50px;
    }}
    div.stButton > button p {{
        color: {bg_color} !important;
    }}
    /* Hover effect */
    div.stButton > button:hover {{
        background: rgba(255, 255, 255, 0.2);
        transform: scale(1.1);                              /* Slight zoom on hover */
        box-shadow: 0px 4px 12px rgba(255, 255, 255, 0.4);  /* Glow effect */
    }}
    div.stButton > button:active {{
        background: rgba(199, 107, 26, 0.5);
        box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.4);
    }}
    .titles {{
        margin-top: 20px !important;
        margin-left: -150px !important;
    }}
    /* Title styling */
    .titles h1 {{
        font-size: 1.9rem;
        margin-left: 5px;
        margin-bottom: 50px;
        padding: 0;
        color: black;      /* Neutral color for text */
    }}
    .titles > div {{
        font-family: "Times New Roman" !important;
        font-size: 1.01rem;
        margin-left: -50px;
        margin-bottom: 1px;
        padding: 0;
        color: black;
    }}
    /* Recently viewed section */
    .recently-viewed {{
        display: flex;
        align-items: center;
        justify-content: flex-start;  /* Align items to the extreme left */
        margin-bottom: 10px;
        margin-top: 20px;
        gap: 10px;
        padding-left: 20px;
        margin-left: 35px;
        height: 100px;
    }}

    /* Style for the upload button */
    [class*="st-key-upload-btn"] {{
        position: absolute;
        top: 100%;
        left: -26%;
        padding: 10px 20px;
        color: red;
        border: none;
        border-radius: 20px;
        cursor: pointer;
        font-size: 35px !important;
        width: 30px;
        height: 20px;
    }}
    .upload-btn:hover {{
        background-color: rgba(0, 123, 255, 1);
    }}
    div[data-testid="stFileUploader"] label > div > p {{
        display: none;
        color: white !important;
    }}
    section[data-testid="stFileUploaderDropzone"] {{
        width: 200px;
        height: 60px;
        border-radius: 40px;
        display: flex;
        justify-content: center;
        align-items: center;
        margin: 20px;
        margin-top: -10px;
        box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.3);
        background-color: rgba(255, 255, 255, 0.7);  /* Transparent white background */
        color: white;
    }}
    div[data-testid="stFileUploaderDropzoneInstructions"] div > small {{
        color: white !important;
        display: none;
    }}
    div[data-testid="stFileUploaderDropzoneInstructions"] span {{
        margin-left: 65px;
        color: {bg_color};
    }}
    div[data-testid="stFileUploaderDropzoneInstructions"] div {{
        display: none;
    }}
    section[data-testid="stFileUploaderDropzone"] button {{
        display: none;
    }}
    div[data-testid="stMarkdownContainer"] p {{
        font-family: "Times New Roman" !important;
        color: white !important;
    }}
    .highlight {{
        border: 4px solid lime;
        font-weight: bold;
        background: radial-gradient(circle, rgba(0,255,0,0.3) 0%, rgba(0,0,0,0) 70%);
        box-shadow: 0px 0px 30px 10px rgba(0, 255, 0, 0.9),
                    0px 0px 60px 20px rgba(0, 255, 0, 0.6),
                    inset 0px 0px 15px rgba(0, 255, 0, 0.8);
        transition: all 0.3s ease-in-out;
    }}
    .highlight:hover {{
        transform: scale(1.05);
        background: radial-gradient(circle, rgba(0,255,0,0.6) 0%, rgba(0,0,0,0) 80%);
        box-shadow: 0px 0px 40px 15px rgba(0, 255, 0, 1),
                    0px 0px 70px 30px rgba(0, 255, 0, 0.7),
                    inset 0px 0px 20px rgba(0, 255, 0, 1);
    }}
    .stCheckbox > label > div {{
        width: 303px !important;
        height: 3rem;
        margin-top: 270px;
        margin-left: -72px;
        border-radius: 1px !important;
        background: {bg_color_iv} !important;
    }}
    .st-b1 {{
        width: 1.75rem;
        height: 1.75rem;
        display: none;
    }}
    .stCheckbox > label > div:after {{
        content: "SWITCH TO {model} MODEL";
        display: block;
        font-family: "Times New Roman", serif;
        margin-top: 0.5em;
        margin-left: 20px;
        font-weight: bold;
    }}
    .st-bj {{
        display: none;
    }}
    .stCheckbox label {{
        height: 0px;
    }}
    </style>
    <div class="logo-text-container">
        <img src="data:image/png;base64,{encoded_logo}" alt="Logo">
        <h1>KidneyScan AI<br></h1>
        <i>Empowering Early Diagnosis with AI</i>
    </div>
    """,
    unsafe_allow_html=True,
)
loading_html = """
    <style>
    .loader {
        border: 8px solid #f3f3f3;
        border-top: 8px solid #0175C2;  /* Blue color */
        border-radius: 50%;
        width: 50px;
        height: 50px;
        animation: spin 1s linear infinite;
        margin: auto;
    }
    @keyframes spin {
        0% { transform: rotate(0deg); }
        100% { transform: rotate(360deg); }
    }
    </style>
    <div class="loader"></div>
"""

# Sidebar content / navigation. A fixed page value is used for now:
# "pome" renders the full app layout below, "Home" the standalone demo page.
page = "pome"

# Display content based on the selected page
if page == "Home":
    image_path = "images/image.jpg"

    st.container()
    st.markdown(
        f"""
        <div class="titles">
            <h1>Kidney Disease Classification<br> Using Transfer Learning</h1>
            <div> This web application utilizes deep learning to classify kidney ultrasound images<br>
            into four categories: Normal, Cyst, Tumor, and Stone.
            Built with Streamlit and powered by <br>a TensorFlow transfer-learning
            model based on <strong>VGG16</strong>,
            the app provides a simple and efficient way for users <br>
            to upload kidney scans and receive instant predictions. The model analyzes the image
            and classifies it based <br>on learned patterns, offering a confidence score for better interpretation.
            </div>
        </div>
        """,
        unsafe_allow_html=True,
    )
    uploaded_file = st.file_uploader(
        "Choose a file", type=["png", "jpg", "jpeg"], key="upload-btn"
    )
    if uploaded_file is not None:
        images = Image.open(uploaded_file)
        # Rewind the file pointer to the beginning
        uploaded_file.seek(0)

        file_content = uploaded_file.read()  # Read file once
        # Convert to base64 for HTML display
        encoded_image = base64.b64encode(file_content).decode()
        # Read and process the image
        pil_image = Image.open(uploaded_file).convert("RGB").resize((224, 224))
        img_array = np.array(pil_image)

        prediction = predict_image(images)
        max_index = int(np.argmax(prediction[0]))
        print(f"max index: {max_index}")
        max_score = prediction[0][max_index]
        predicted_class = np.argmax(prediction[0])

        highlight_class = "highlight"  # Special class for the highest confidence score

        # Generate Grad-CAM
        cam = generate_gradcam(pil_image, predicted_class)

        # Create overlay
        heatmap = cm.jet(cam)[..., :3]
        heatmap = (heatmap * 255).astype(np.uint8)
        overlayed_image = cv2.addWeighted(img_array, 0.6, heatmap, 0.4, 0)

        # Convert to PIL, then to base64
        overlayed_pil = Image.fromarray(overlayed_image)
        orig_b64 = convert_image_to_base64(pil_image)
        overlay_b64 = convert_image_to_base64(overlayed_pil)
        content = f"""
        <div class="content-container">
            <div class="content-container2">
                <div class="content-container3">
                    <img src="data:image/png;base64,{orig_b64}" alt="Uploaded Image">
                </div>
                <div class="content-container3">
                    <img src="data:image/png;base64,{overlay_b64}" class="result-image">
                </div>
                <div class="content-container4 {'highlight' if max_index == 0 else ''}">
                    <h3>{class_labels[0]}</h3>
                    <p>T Score: {prediction[0][0]:.2f}</p>
                </div>
                <div class="content-container5 {'highlight' if max_index == 1 else ''}">
                    <h3>{class_labels[1]}</h3>
                    <p>T Score: {prediction[0][1]:.2f}</p>
                </div>
                <div class="content-container6 {'highlight' if max_index == 2 else ''}">
                    <h3>{class_labels[2]}</h3>
                    <p>T Score: {prediction[0][2]:.2f}</p>
                </div>
                <div class="content-container7 {'highlight' if max_index == 3 else ''}">
                    <h3>{class_labels[3]}</h3>
                    <p>T Score: {prediction[0][3]:.2f}</p>
                </div>
            </div>
        </div>
        """

        # Render the content after a brief loading animation
        placeholder = st.empty()
        placeholder.markdown(loading_html, unsafe_allow_html=True)
        time.sleep(5)  # Wait for 5 seconds
        placeholder.empty()
        st.markdown(content, unsafe_allow_html=True)
    else:
        with open(image_path, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode()

        st.markdown(
            f"""
            <div class="content-container">
                <div class="content-container2">
                    <div class="content-container3">
                        <img src="data:image/png;base64,{encoded_image}" alt="Default Image">
                    </div>
                </div>
            </div>
            """,
            unsafe_allow_html=True,
        )
if page == "pome":
    gif_path = "images/bg3.gif"
    with open(gif_path, "rb") as image_file:
        encode_image = base64.b64encode(image_file.read()).decode()
    st.markdown(
        f"""
        <div class="content-container-principal-in">
            <div class="content-container-principal">
                <img src="data:image/png;base64,{encode_image}" alt="Default Image">
            </div>
        </div>
        """,
        unsafe_allow_html=True,
    )
    col1, col2 = st.columns([1, 2])  # Adjust column widths
    with col1:
        if st.button("📄 Model Summary"):
            st.session_state.menu = "1"  # Store state
            st.rerun()

        if st.button("📊 Model Results Analysis", key="header"):
            st.session_state.menu = "2"
            st.rerun()

        if st.button("🧪 Model Testing"):
            st.session_state.menu = "3"
            st.rerun()
        # Toggle switch UI
        def framework_toggle():
            toggle = st.toggle("Enable PyTorch", value=(st.session_state.framework == "PyTorch"))

            if toggle and st.session_state.framework != "PyTorch":
                st.session_state.framework = "PyTorch"
                st.session_state.model = torch.load('models/kidney_model .pth', map_location=torch.device('cpu'))
                st.rerun()
            elif not toggle and st.session_state.framework != "TensorFlow":
                st.session_state.framework = "TensorFlow"
                st.session_state.model = tf.keras.models.load_model("models/best_model.h5")
                st.rerun()
            print(st.session_state.framework)

        framework_toggle()
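        # Toggling reloads the weights for the chosen framework into
        # st.session_state.model and calls st.rerun(), so the script
        # re-executes from the top and the matching framework branches
        # (Grad-CAM setup, dataset loading, predict_image) take effect.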

    # Custom CSS for table styling
    table_style = """
    <style>
    table {
        width: 110%;
        border-collapse: collapse;
        border-radius: 2px;
        overflow: hidden;
        box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.4);
        background: rgba(255, 255, 255, 0.05);
        backdrop-filter: blur(10px);
        font-family: "Times New Roman", serif;
        margin-left: -100px;
        margin-top: 10px;
    }
    thead {
        background: rgba(255, 255, 255, 0.2);
    }
    th {
        padding: 12px;
        text-align: left;
        font-weight: bold;
        backdrop-filter: blur(10px);
    }
    td {
        padding: 12px;
        border-bottom: 1px solid rgba(255, 255, 255, 0.1);
    }
    tr:hover {
        background-color: rgba(255, 255, 255, 0.1);
    }
    tbody {
        display: block;
        max-height: 580px;  /* Fixed height with scrolling */
        overflow-y: auto;
        width: 100%;
    }
    thead, tbody tr {
        display: table;
        width: 100%;
        table-layout: fixed;
    }
    </style>
    """

    with col2:
        if st.session_state.show_summary:
            layers_data = []
            if st.session_state.framework == "TensorFlow":
                for layer in st.session_state.model.layers:
                    try:
                        shape = layer.output.shape
                    except Exception:
                        shape = "N/A"

                    if isinstance(shape, tuple):
                        shape = str(shape)
                    elif isinstance(shape, list):
                        shape = ", ".join(str(s) for s in shape)
                    elif shape is None:
                        shape = "N/A"

                    param_count = f"{layer.count_params():,}"

                    layers_data.append(
                        (layer.name, layer.__class__.__name__, shape, param_count)
                    )

            elif st.session_state.framework == "PyTorch":
                layers_data = get_layers_data(st.session_state.model)  # Get layer information

            # Convert to an HTML table
            table_html = "<table><tr><th>Layer Name</th><th>Type</th><th>Output Shape</th><th>Param #</th></tr>"
            for name, layer_type, shape, params in layers_data:
                table_html += f"<tr><td>{name}</td><td>{layer_type}</td><td>{shape}</td><td>{params}</td></tr>"
            table_html += "</table>"

            # Render the table with custom styling
            st.markdown(table_style + table_html, unsafe_allow_html=True)
        if st.session_state.show_arch:
            if st.session_state.framework == "TensorFlow":
                y_true = np.concatenate([y.numpy() for _, y in test_dataset])

                # Get model predictions
                y_pred_probs = st.session_state.model.predict(test_dataset)
                y_pred = np.argmax(y_pred_probs, axis=1)

                # Convert one-hot true labels to class indices
                y_true = np.argmax(y_true, axis=1)

                # Class names (modify for your dataset)
                class_names = ["Cyst", "Normal", "Stone", "Tumor"]

                # Generate the classification report as a dictionary
                report_dict = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)

                # Convert to DataFrame
                report_df = pd.DataFrame(report_dict).transpose().round(2)

                accuracy = report_dict["accuracy"]
                precision = report_df.loc["weighted avg", "precision"]
                recall = report_df.loc["weighted avg", "recall"]
                f1_score = report_df.loc["weighted avg", "f1-score"]
            elif st.session_state.framework == "PyTorch":
                y_true = []
                y_pred = []
                for image, label in test_dataset:  # test_dataset is an ImageFolder instance
                    image = image.unsqueeze(0)  # Add batch dimension

                    with torch.no_grad():
                        output = st.session_state.model(image)  # Get model output
                        _, predicted = torch.max(output, 1)     # Get predicted class

                    y_true.append(label)             # Append true label
                    y_pred.append(predicted.item())  # Append predicted label

                # Generate the classification report
                report_dict = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)

                # Convert to DataFrame for better readability
                report_df = pd.DataFrame(report_dict).transpose().round(2)

                accuracy = report_dict["accuracy"]
                precision = report_df.loc["weighted avg", "precision"]
                recall = report_df.loc["weighted avg", "recall"]
                f1_score = report_df.loc["weighted avg", "f1-score"]
            st.markdown("""
                <style>
                .kpi-container {
                    display: flex;
                    justify-content: space-between;
                    margin-bottom: 20px;
                    margin-left: -80px;
                    margin-top: -30px;
                }
                .kpi-card {
                    width: 23%;
                    padding: 15px;
                    text-align: center;
                    border-radius: 10px;
                    font-size: 22px;
                    font-weight: bold;
                    font-family: "Times New Roman" !important;
                    color: #333;
                    background: rgba(255, 255, 255, 0.05);
                    box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
                    border: 5px solid rgba(173, 216, 230, 0.4);
                }
                </style>
                <div class="kpi-container">
                    <div class="kpi-card">Precision<br>""" + f"{precision:.2f}" + """</div>
                    <div class="kpi-card">Recall<br>""" + f"{recall:.2f}" + """</div>
                    <div class="kpi-card">Accuracy<br>""" + f"{accuracy:.2f}" + """</div>
                    <div class="kpi-card">F1-Score<br>""" + f"{f1_score:.2f}" + """</div>
                </div>
                """, unsafe_allow_html=True)

            # Remove the summary rows (accuracy/macro avg/weighted avg) and reset the index
            report_df = report_df.iloc[:-3].reset_index()
            report_df.rename(columns={"index": "Class"}, inplace=True)

            # Custom CSS for the report table
            st.markdown("""
                <style>
                .report-container {
                    max-height: 250px;
                    overflow-y: auto;
                    border-radius: 25px;
                    text-align: center;
                    border: 5px solid rgba(173, 216, 230, 0.4);
                    padding: 10px;
                    background: rgba(255, 255, 255, 0.05);
                    box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
                    width: 480px;
                    margin-left: -80px;
                    margin-top: -20px;
                }
                .report-container h4 {
                    font-family: "Times New Roman" !important;
                    font-size: 1rem;
                    margin-left: 5px;
                    margin-bottom: 1px;
                    padding: 10px;
                    color: #333;
                }
                .report-table {
                    width: 100%;
                    border-collapse: collapse;
                    font-family: 'Times New Roman', serif;
                    text-align: center;
                }
                .report-table th {
                    background: rgba(255, 255, 255, 0.05);
                    font-size: 16px;
                    padding: 10px;
                    border-bottom: 2px solid #444;
                }
                .report-table td {
                    font-size: 12px;
                    padding: 10px;
                    border-bottom: 1px solid #ddd;
                }
                </style>
                """, unsafe_allow_html=True)
            col1, col2 = st.columns([3, 3])
            with col1:
                # Convert the DataFrame to an HTML table
                report_html = report_df.to_html(index=False, classes="report-table", escape=False)
                st.markdown(f'<div class="report-container"><h4>Classification report</h4>{report_html}</div>', unsafe_allow_html=True)

                # Generate the confusion matrix (named conf_mat so it does not
                # shadow matplotlib.cm, which is used for the Grad-CAM overlay)
                conf_mat = confusion_matrix(y_true, y_pred)

                # Confusion-matrix heatmap
                fig, ax = plt.subplots(figsize=(1, 1))
                fig.patch.set_alpha(0)  # Make the figure background transparent

                sns.heatmap(conf_mat, annot=True, fmt="d", cmap="Blues",
                            xticklabels=class_names, yticklabels=class_names,
                            linewidths=1, linecolor="black",
                            cbar=False, square=True, alpha=0.9,
                            annot_kws={"size": 5, "family": "Times New Roman"})
                # Change the font of annotations and tick labels
                for text in ax.texts:
                    text.set_bbox(dict(facecolor='none', edgecolor='none', alpha=0))
                plt.xticks(fontsize=4, family="Times New Roman")
                plt.yticks(fontsize=4, family="Times New Roman")
                plt.title("Confusion Matrix", fontsize=5, family="Times New Roman", color="black", loc='center')

                # Transparent background and border styling for the rendered images
                st.markdown("""
                    <style>
                    div[data-testid="stImageContainer"] {
                        max-height: 250px;
                        overflow-y: auto;
                        border-radius: 25px;
                        text-align: center;
                        border: 5px solid rgba(173, 216, 230, 0.4);
                        padding: 10px;
                        background: rgba(255, 255, 255, 0.05);
                        box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
                        width: 480px !important;
                        margin-left: -80px;
                        margin-top: -20px;
                    }
                    div[data-testid="stImageContainer"] img {
                        margin-top: -10px !important;
                        width: 400px !important;
                        height: 250px !important;
                    }
                    [class*="st-key-roc"] div[data-testid="stImageContainer"] {
                        max-height: 250px;
                        overflow-y: auto;
                        border-radius: 25px;
                        text-align: center;
                        border: 5px solid rgba(173, 216, 230, 0.4);
                        background: rgba(255, 255, 255, 0.05);
                        box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
                        width: 480px;
                        margin-left: -35px;
                        margin-top: -15px;
                    }
                    [class*="st-key-roc"] div[data-testid="stImageContainer"] img {
                        width: 480px !important;
                        height: 250px !important;
                        margin-top: -20px !important;
                    }
                    [class*="st-key-precision"] div[data-testid="stImageContainer"] {
                        max-height: 250px;
                        overflow-y: auto;
                        border-radius: 25px;
                        text-align: center;
                        border: 5px solid rgba(173, 216, 230, 0.4);
                        background: rgba(255, 255, 255, 0.05);
                        box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
                        width: 480px;
                        margin-left: -35px;
                        margin-top: -5px;
                    }
                    [class*="st-key-precision"] div[data-testid="stImageContainer"] img {
                        width: 480px !important;
                        height: 250px !important;
                        margin-top: -20px !important;
                    }
                    </style>
                    """, unsafe_allow_html=True)

                # Show the plot inside a styled container
                st.markdown('<div class="confusion-matrix-container">', unsafe_allow_html=True)
                st.pyplot(fig)
                st.markdown("</div>", unsafe_allow_html=True)

            with col2:
                if st.session_state.framework == "TensorFlow":
                    # Binarize the true labels for multi-class ROC analysis
                    y_true_bin = label_binarize(y_true, classes=np.arange(len(class_names)))

                    # ROC curve and AUC for each class
                    fpr, tpr, roc_auc = {}, {}, {}
                    for i in range(len(class_names)):
                        fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_pred_probs[:, i])
                        roc_auc[i] = auc(fpr[i], tpr[i])

                    # Plot the ROC curve for each class
                    plt.figure(figsize=(11, 9))
                    for i in range(len(class_names)):
                        plt.plot(fpr[i], tpr[i], lw=2, label=f'{class_names[i]} (AUC = {roc_auc[i]:.2f})')

                    # Random-guess reference line
                    plt.plot([0, 1], [0, 1], color='navy', lw=5, linestyle='--')

                    # Labels and legend
                    plt.xlim([0.0, 1.0])
                    plt.ylim([0.0, 1.05])
                    plt.xlabel('False Positive Rate', fontsize=28, family="Times New Roman")
                    plt.ylabel('True Positive Rate', fontsize=28, family="Times New Roman")
                    plt.title('ROC Curve (One-vs-Rest) for Each Class', fontsize=30, family="Times New Roman", color="black", loc='center', pad=3)
                    plt.legend(loc='lower right', fontsize=18)

                    # Save the plot as an image and display it
                    plt.savefig('roc_curve.png', transparent=True)
                    plt.close()
                    with st.container(key="roc"):
                        st.image('roc_curve.png')
                elif st.session_state.framework == "PyTorch":
                    # Display a precomputed ROC curve
                    with st.container(key="roc"):
                        st.image('roc-py.png')

                with st.container(key="precision"):
                    st.image('precision_recall_curve.png')
        if st.session_state.show_desc:
            image_path = "images/image.jpg"

            st.container()
            st.markdown(
                f"""
                <div class="titles">
                    <h1>Kidney Disease Classification<br> Using Deep Learning</h1>
                    <div> This web application utilizes deep learning to classify kidney ultrasound images<br>
                    into four categories: Normal, Cyst, Tumor, and Stone.
                    Built with Streamlit and powered by <br>a TensorFlow transfer-learning
                    model based on <strong>VGG16</strong>,
                    the app provides a simple and efficient way for users <br>
                    to upload kidney scans and receive instant predictions. The model analyzes the image
                    and classifies it based <br>on learned patterns, offering a confidence score for better interpretation.
                    </div>
                </div>
                """,
                unsafe_allow_html=True,
            )
            uploaded_file = st.file_uploader(
                "Choose a file", type=["png", "jpg", "jpeg"], key="upload-btn"
            )
            if uploaded_file is not None:
                images = Image.open(uploaded_file)
                # Rewind the file pointer to the beginning
                uploaded_file.seek(0)

                file_content = uploaded_file.read()  # Read file once
                # Convert to base64 for HTML display
                encoded_image = base64.b64encode(file_content).decode()
                # Read and process the image
                pil_image = Image.open(uploaded_file).convert("RGB").resize((224, 224))
                img_array = np.array(pil_image)

                if st.session_state.framework == "TensorFlow":
                    prediction = predict_image(images)
                    max_index = int(np.argmax(prediction[0]))
                    print(f"max index: {max_index}")
                    max_score = prediction[0][max_index]
                    predicted_class = np.argmax(prediction[0])

                    highlight_class = "highlight"  # Special class for the highest confidence score

                    # Generate Grad-CAM
                    cam = generate_gradcam(pil_image, predicted_class)

                    # Create overlay
                    heatmap = cm.jet(cam)[..., :3]
                    heatmap = (heatmap * 255).astype(np.uint8)
                    overlayed_image = cv2.addWeighted(img_array, 0.6, heatmap, 0.4, 0)

                    # Convert to PIL, then to base64
                    overlayed_pil = Image.fromarray(overlayed_image)
                    orig_b64 = convert_image_to_base64(pil_image)
                    overlay_b64 = convert_image_to_base64(overlayed_pil)
                    content = f"""
                    <div class="content-container">
                        <div class="content-container3">
                            <img src="data:image/png;base64,{orig_b64}" alt="Uploaded Image">
                        </div>
                        <div class="content-container3">
                            <img src="data:image/png;base64,{overlay_b64}" class="result-image">
                        </div>
                        <div class="content-container4 {'highlight' if max_index == 0 else ''}">
                            <h3>{class_labels[0]}</h3>
                            <p>T Score: {prediction[0][0]:.2f}</p>
                        </div>
                        <div class="content-container5 {'highlight' if max_index == 1 else ''}">
                            <h3>{class_labels[1]}</h3>
                            <p>T Score: {prediction[0][1]:.2f}</p>
                        </div>
                        <div class="content-container6 {'highlight' if max_index == 2 else ''}">
                            <h3>{class_labels[2]}</h3>
                            <p>T Score: {prediction[0][2]:.2f}</p>
                        </div>
                        <div class="content-container7 {'highlight' if max_index == 3 else ''}">
                            <h3>{class_labels[3]}</h3>
                            <p>T Score: {prediction[0][3]:.2f}</p>
                        </div>
                    </div>
                    """
                elif st.session_state.framework == "PyTorch":
                    class0, class1, prediction = predict_image(images)
                    max_index = int(np.argmax(prediction[0]))
                    print(f"max index: {max_index}")
                    max_score = prediction[0][max_index]
                    predicted_class = np.argmax(prediction[0])
                    print(f"predicted class is: {predicted_class}")

                    # Grad-CAM overlay is currently disabled for the PyTorch model:
                    # cams = generate_gradcams(pil_image, predicted_class)
                    # heatmap = cm.jet(cams)[..., :3]
                    # heatmap = (heatmap * 255).astype(np.uint8)
                    # overlayed_image = cv2.addWeighted(img_array, 0.6, heatmap, 0.4, 0)
                    # overlayed_pil = Image.fromarray(overlayed_image)
                    # overlay_b64 = convert_image_to_base64(overlayed_pil)

                    highlight_class = "highlight"  # Special class for the highest confidence score
                    orig_b64 = convert_image_to_base64(pil_image)
                    content = f"""
                    <div class="content-container">
                        <div class="content-container3">
                            <img src="data:image/png;base64,{orig_b64}" alt="Uploaded Image">
                        </div>
                        <div class="content-container4 {'highlight' if max_index == 0 else ''}">
                            <h3>{class_labels[0]}</h3>
                            <p>T Score: {prediction[0][0]:.2f}</p>
                        </div>
                        <div class="content-container5 {'highlight' if max_index == 1 else ''}">
                            <h3>{class_labels[1]}</h3>
                            <p>T Score: {prediction[0][1]:.2f}</p>
                        </div>
                        <div class="content-container6 {'highlight' if max_index == 2 else ''}">
                            <h3>{class_labels[2]}</h3>
                            <p>T Score: {prediction[0][2]:.2f}</p>
                        </div>
                        <div class="content-container7 {'highlight' if max_index == 3 else ''}">
                            <h3>{class_labels[3]}</h3>
                            <p>T Score: {prediction[0][3]:.2f}</p>
                        </div>
                    </div>
                    """

                # Render the content after a brief loading animation
                placeholder = st.empty()
                placeholder.markdown(loading_html, unsafe_allow_html=True)
                time.sleep(5)  # Wait for 5 seconds
                placeholder.empty()
                st.markdown(content, unsafe_allow_html=True)
            else:
                with open(image_path, "rb") as image_file:
                    encoded_image = base64.b64encode(image_file.read()).decode()

                st.markdown(
                    f"""
                    <div class="content-container">
                        <div class="content-container3">
                            <img src="data:image/png;base64,{encoded_image}" alt="Default Image">
                        </div>
                    </div>
                    """,
                    unsafe_allow_html=True,
                )