danielle2003 commited on
Commit
93c8930
·
verified ·
1 Parent(s): 2208322

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +1381 -0
app.py ADDED
@@ -0,0 +1,1381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import barcode
2
+ import streamlit as st
3
+ st.set_page_config(layout="wide")
4
+ import streamlit.components.v1 as components
5
+ import cv2
6
+ from PIL import Image
7
+ import base64
8
+ import os
9
+ import time
10
+ import numpy as np
11
+ import pandas as pd
12
+ import matplotlib.pyplot as plt
13
+ import matplotlib.cm as cm
14
+ from torch.utils.data import DataLoader
15
+ from PIL import Image
16
+ from io import BytesIO
17
+ from gradcam import GradCAM # Import your GradCAM class
18
+ from sklearn.metrics import classification_report,confusion_matrix, roc_curve, auc,precision_recall_curve, average_precision_score
19
+ from sklearn.preprocessing import label_binarize
20
+ import seaborn as sns
21
+ import torch
22
+ import torch.nn as nn
23
+ import torchvision.models as models
24
+ from torchvision import datasets, transforms
25
+ import torchvision.transforms as transforms
26
+
27
+ import io
28
+ import warnings
29
+ warnings.filterwarnings("ignore")
30
+ from barcode.writer import ImageWriter
31
+
32
+ showWarningOnDirectExecution = False
33
+ # Path to your logo image
34
+ logo_path = "images/pytorch.png"
35
+ main_bg_ext = 'png'
36
+ image = ["images/back1 (1).jpg","images/back1 (2).jpg","images/back1 (3).jpg","images/back1 (4).jpg","images/back1 (5).jpg"]
37
+ # Read and encode the logo image
38
+ with open(logo_path, "rb") as image_file:
39
+ encoded_logo = base64.b64encode(image_file.read()).decode()
40
+
41
+ if "framework" not in st.session_state:
42
+ st.session_state.framework = "Tensorflow"
43
+ if "menu" not in st.session_state:
44
+ st.session_state.menu = "1"
45
+ if st.session_state.menu =="1":
46
+ st.session_state.show_summary = True
47
+ st.session_state.show_arch = False
48
+ st.session_state.show_desc = False
49
+ elif st.session_state.menu =="2":
50
+ st.session_state.show_arch = True
51
+ st.session_state.show_summary = False
52
+ st.session_state.show_desc = False
53
+ elif st.session_state.menu =="3":
54
+ st.session_state.show_arch = False
55
+ st.session_state.show_summary = False
56
+ st.session_state.show_desc = True
57
+ else:
58
+ st.session_state.show_desc = True
59
+ def encode_image(image_path):
60
+ with open(image_path, "rb") as img_file:
61
+ return base64.b64encode(img_file.read()).decode()
62
+
63
+ #**************************************************
64
+ # loading pytorch model
65
+ #********************************************
66
+
67
+
68
+ # Define the CustomVGG16 model
69
+ class CustomVGG16(nn.Module):
70
+ def __init__(self, num_classes=2):
71
+ super(CustomVGG16, self).__init__()
72
+ base_model = models.vgg16(pretrained=False)
73
+ self.features = base_model.features
74
+ self.avgpool = nn.AdaptiveAvgPool2d((2, 2))
75
+ self.flatten = nn.Flatten()
76
+ self.fc1 = nn.Linear(512 * 2 * 2, 512)
77
+ self.bn1 = nn.BatchNorm1d(512)
78
+ self.dropout = nn.Dropout(0.5)
79
+ self.fc2 = nn.Linear(512, num_classes)
80
+ self.softmax = nn.Softmax(dim=1)
81
+
82
+ def forward(self, x):
83
+ x = self.features(x)
84
+ x = self.avgpool(x)
85
+ x = self.flatten(x)
86
+ x = self.fc1(x)
87
+ x = self.bn1(x)
88
+ x = torch.relu(x)
89
+ x = self.dropout(x)
90
+ x = self.fc2(x)
91
+ x = self.softmax(x)
92
+ return x
93
+
94
+ # Load the model
95
+ model = CustomVGG16(num_classes=2)
96
+
97
+ # Load the state_dict (weights only)
98
+ model.load_state_dict(torch.load('brain_model.pth', map_location=torch.device('cpu')))
99
+
100
+ model.eval()#model.eval() # Set the model to evaluation mode
101
+ target_layer = model.features[-1] # Typically last convolutional layer
102
+ gradcam = GradCAM(model, target_layer)
103
+ def preprocess_image(image):
104
+ preprocess = transforms.Compose([
105
+ transforms.Resize((224, 224)),
106
+ transforms.ToTensor(),
107
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # For pretrained models like VGG16
108
+ ])
109
+ return preprocess(image).unsqueeze(0) # Add batch dimension
110
+
111
+ def generate_gradcam(image, target_class):
112
+ # Preprocess the image and convert it to a tensor
113
+ input_image = preprocess_image(image)
114
+
115
+ # Instantiate GradCAM
116
+ gradcam = GradCAM(model, target_layer)
117
+
118
+ # Generate the CAM
119
+ cam = gradcam.generate(input_image, target_class)
120
+
121
+ return cam
122
+ # Function to get layer information
123
+ def get_layers_data(model, prefix=""):
124
+ layers_data = []
125
+ for name, layer in model.named_children(): # Iterate over layers
126
+ full_name = f"{prefix}.{name}" if prefix else name # Track hierarchy
127
+
128
+ try:
129
+ shape = str(list(layer.parameters())[0].shape) # Get shape of the first param
130
+ except Exception:
131
+ shape = "N/A"
132
+
133
+ param_count = sum(p.numel() for p in layer.parameters()) # Count parameters
134
+
135
+ layers_data.append((full_name, layer.__class__.__name__, shape, f"{param_count:,}"))
136
+
137
+ # Recursively get layers inside this layer (for nested structures)
138
+ layers_data.extend(get_layers_data(layer, full_name))
139
+
140
+ return layers_data
141
+
142
+ def convert_image_to_base64(pil_image):
143
+ buffered = BytesIO()
144
+ pil_image.save(buffered, format="PNG")
145
+ return base64.b64encode(buffered.getvalue()).decode()
146
+
147
+
148
+ def predict_image(image):
149
+ # Preprocess the image to match the model input requirements
150
+ transform = transforms.Compose([
151
+ transforms.Resize((224, 224)),
152
+ transforms.ToTensor(),
153
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Standard VGG16 normalization
154
+ ])
155
+
156
+ image = transform(image).unsqueeze(0) # Add batch dimension
157
+
158
+ # Move image to the same device as the model (GPU or CPU)
159
+ image = image
160
+
161
+ # Set the model to evaluation mode
162
+ model.eval()
163
+
164
+ with torch.no_grad(): # Disable gradient calculation
165
+ outputs = model(image) # Forward pass
166
+
167
+ # Get predicted probabilities (softmax for multi-class)
168
+ if outputs.shape[1] == 1:
169
+ probs = torch.sigmoid(outputs) # Apply sigmoid activation for binary classification
170
+ prob_class_1 = probs[0].item() # Probability for class 1
171
+ prob_class_0 = 1 - prob_class_1 # Probability for class 0
172
+
173
+ # If the output has two units (binary classification with softmax)
174
+ else:
175
+ probs = torch.nn.functional.softmax(outputs, dim=1)
176
+ prob_class_0 = probs[0, 0].item()
177
+ prob_class_1 = probs[0, 1].item()
178
+ # Get the predicted class
179
+ print("Raw model output (logits):", outputs)
180
+
181
+ return prob_class_0, prob_class_1, probs
182
+ # /#*********************************************/
183
+ # LOADING TEST DATASET
184
+
185
+ # *************************************************
186
+ test_dir = "test"
187
+ BATCH_SIZE = 32
188
+ IMG_SIZE = (224, 224)
189
+ transform = transforms.Compose([
190
+ transforms.Resize((224, 224)),
191
+ transforms.ToTensor(),
192
+ # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
193
+ ])
194
+
195
+ test_dataset = datasets.ImageFolder(root='test', transform=transform)
196
+ test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
197
+ # One-hot encode labels using CategoryEncoding
198
+ class_names = test_dataset.classes
199
+ class_labels = class_names
200
+
201
+ # One-hot encode labels using CategoryEncoding
202
+ #num_classes = len(class_names)
203
+
204
+
205
+ #def one_hot_encode(image, label):
206
+ ##label = tf.one_hot(label, num_classes)
207
+ #return image, label
208
+
209
+
210
+ #test_dataset = test_dataset.map(one_hot_encode)
211
+
212
+
213
+ #######################################################
214
+ # Custom CSS to style the logo above the sidebar
215
+ st.markdown(
216
+ f"""
217
+ <style>
218
+ /* Container for logo and text */
219
+ .logo-text-container {{
220
+ position: fixed;
221
+ top: 30px; /* Adjust vertical position */
222
+ left: 50px; /* Align with sidebar */
223
+ display: flex;
224
+ align-items: center;
225
+ gap: 15px;
226
+ justify-content: space-between;
227
+ width: 100%;
228
+ }}
229
+
230
+ /* Logo styling */
231
+ .logo-text-container img {{
232
+ width: 100px; /* Adjust logo size */
233
+ border-radius: 10px; /* Optional: round edges */
234
+ margin-top:10px;
235
+ margin-left:20px;
236
+
237
+
238
+ }}
239
+
240
+ /* Bold text styling */
241
+ .logo-text-container h1 {{
242
+ font-family: 'Times New Roman', serif;
243
+ font-size: 24px;
244
+ font-weight: bold;
245
+ margin:-right 100px;;
246
+ text-align: center;
247
+ align-items: center;
248
+
249
+ margin: 0 auto; /* Center the text */
250
+
251
+ flex-grow:1;
252
+ color: #FFD700; /* Golden color for text */
253
+ }}
254
+ /* Sidebar styling */
255
+ section[data-testid="stSidebar"][aria-expanded="true"] {{
256
+ margin-top: 100px !important; /* Space for the logo */
257
+ border-radius: 0 60px 0px 60px !important; /* Top-left and bottom-right corners */
258
+ width: 200px !important; /* Sidebar width */
259
+ background:none; /* Gradient background */
260
+ /* box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2); /* Shadow effect */
261
+ /* border: 1px solid #FFD700; /* Shiny golden border */
262
+ margin-bottom: 1px !important;
263
+ color:white !important;
264
+
265
+ }}
266
+ /* Style for the upload button */
267
+ [class*="st-key-upload-btn"] {{
268
+ position: absolute;
269
+ top: 50%; /* Position from the top of the inner circle */
270
+ left: 1%; /* Position horizontally at the center */
271
+ padding: 10px 20px;
272
+ color: red;
273
+ border: none;
274
+ border-radius: 20px;
275
+ cursor: pointer;
276
+ font-size: 35px !important;
277
+ width:30px;
278
+ height:20px;
279
+ }}
280
+
281
+ .upload-btn:hover {{
282
+ background-color: rgba(0, 123, 255, 1);
283
+ }}
284
+ div[data-testid="stFileUploader"] label > div > p {{
285
+ display:none;
286
+ color:white !important;
287
+ }}
288
+ section[data-testid="stFileUploaderDropzone"] {{
289
+ width:200px;
290
+ height: 60px;
291
+ background-color: white;
292
+ border-radius: 40px;
293
+ display: flex;
294
+ justify-content: center;
295
+ align-items: center;
296
+ margin-top:-10px;
297
+ box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.3);
298
+ margin:20px;
299
+ background-color: rgba(255, 255, 255, 0.7); /* Transparent blue background */
300
+ color:white;
301
+ }}
302
+ div[data-testid="stFileUploaderDropzoneInstructions"] div > small{{
303
+ color:white !important;
304
+ display:none;
305
+ }}
306
+ div[data-testid="stFileUploaderDropzoneInstructions"] span{{
307
+ margin-left:65px;
308
+ color:orange;
309
+ }}
310
+ div[data-testid="stFileUploaderDropzoneInstructions"] div{{
311
+ display:none;
312
+ }}
313
+ section[data-testid="stFileUploaderDropzone"] button{{
314
+ display:none;
315
+ }}
316
+ div[data-testid="stMarkdownContainer"] p {{
317
+ font-family: "Times New Roman" !important; /* Elegant font for title */
318
+ color:white !important;
319
+ }}
320
+ .highlight {{
321
+ border: 4px solid lime;
322
+ font-weight: bold;
323
+ background: radial-gradient(circle, rgba(0,255,0,0.3) 0%, rgba(0,0,0,0) 70%);
324
+ box-shadow: 0px 0px 30px 10px rgba(0, 255, 0, 0.9),
325
+ 0px 0px 60px 20px rgba(0, 255, 0, 0.6),
326
+ inset 0px 0px 15px rgba(0, 255, 0, 0.8);
327
+ transition: all 0.3s ease-in-out;
328
+
329
+ }}
330
+ .highlight:hover {{
331
+ transform: scale(1.05);
332
+ background: radial-gradient(circle, rgba(0,255,0,0.6) 0%, rgba(0,0,0,0) 80%);
333
+ box-shadow: 0px 0px 40px 15px rgba(0, 255, 0, 1),
334
+ 0px 0px 70px 30px rgba(0, 255, 0, 0.7),
335
+ inset 0px 0px 20px rgba(0, 255, 0, 1);
336
+ }}
337
+ header[data-testid="stHeader"] {{
338
+ /* border-radius: 1px !important;*/
339
+ background: transparent !important; /* Gradient background */
340
+ /*box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2); /* Shadow effect */
341
+ /*: 3px solid #FFD700; /* Shiny golden border */
342
+ /*border-bottom:none !important;*/
343
+ margin-right: 100px !important;
344
+ margin-top: 32px !important;
345
+ z-index: 1 !important; /* Ensure it stays above other elements */
346
+
347
+ }}
348
+ div[data-testid="stDecoration"]{{
349
+ background-image:none;
350
+ }}
351
+ button[data-testid="stBaseButton-secondary"]{{
352
+ background:transparent;
353
+ border:none;
354
+ }}
355
+ div[data-testid="stApp"]{{
356
+ background:#161819;
357
+ height: 98vh; /* Full viewport height */
358
+ width: 98%;
359
+ border-radius: 40px !important;
360
+ margin-left:10px;
361
+ margin-right:10px;
362
+ margin-top:10px;
363
+ box-shadow: 0 4px 30px rgba(0, 0, 0, 0.5);
364
+
365
+ overflow: hidden;
366
+
367
+ }}
368
+ div[data-testid="stMarkdownContainer"] > p {{
369
+ font-family: "Times New Roman " !important; /* Font */
370
+ font-size: 11px !important; /* Font size */
371
+ margin:5px;
372
+ }}
373
+
374
+ [class*="st-key-content_"] {{
375
+
376
+ background: rgba(255, 255, 255, 0.9);
377
+ border-radius: 40px;
378
+ /* box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.1);*/
379
+ width: 83.7%;
380
+ margin-left: 75px;
381
+ /* margin-top: -70px;*/
382
+ margin-bottom: 10px;
383
+ margin-right:10px;
384
+ padding:0;
385
+
386
+ overflow-y: auto; /* Enable vertical scrolling for the content */
387
+ position: fixed; /* Fix the position of the container */
388
+ top: 1.5%; /* Adjust top offset */
389
+ left: 10%; /* Adjust left offset */
390
+ height: 98vh; /* Full viewport height */
391
+
392
+ }}
393
+
394
+ [class*="st-key-center-box"] {{
395
+
396
+ background-color: transparent;
397
+ border-radius: 60px;
398
+ width: 100%;
399
+ margin-top:30px;
400
+ top:20% !important; /* Adjust top offset */
401
+ left: 1%; /* Adjust left offset */
402
+
403
+ }}
404
+ [class*="st-key-side"] {{
405
+
406
+ background-color: transparent;
407
+ border-radius: 60px;
408
+ box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5);
409
+ width: 5%;
410
+ /* margin-top: 100px;*/
411
+ margin-bottom: 10px;
412
+ margin-right:10px;
413
+ padding:30px;
414
+ display: flex;
415
+ justify-content: center;
416
+
417
+ align-items: center;
418
+ overflow-y: auto; /* Enable vertical scrolling for the content */
419
+ position: fixed; /* Fix the position of the container */
420
+ top: 17%; /* Adjust top offset */
421
+ left: 16%; /* Adjust left offset */
422
+ height:50vh; /* Full viewport height */
423
+
424
+ }}
425
+
426
+ [class*="st-key-button_"] .stButton p > img {{
427
+ max-width: 100%;
428
+ vertical-align: top;
429
+ height:130px !important;
430
+ object-fit: cover;
431
+ padding: 10px;
432
+ width:250px !important;
433
+ border-radius:10px !important;
434
+ max-height: 2em !important;
435
+
436
+ }}
437
+ div.stButton > button {{
438
+ background: rgba(255, 255, 255, 0.2);
439
+ color: orange !important; /* White text */
440
+ font-family: "Times New Roman " !important; /* Font */
441
+ font-size: 18px !important; /* Font size */
442
+ font-weight: bold !important; /* Bold text */
443
+ padding: 1px 2px; /* Padding for buttons */
444
+ border: none; /* Remove border */
445
+ border-radius: 5px; /* Rounded corners */
446
+ box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.2); /* Shadow effect */
447
+ transition: all 0.3s ease-in-out; /* Smooth transition */
448
+ display: flex;
449
+ align-items: left;
450
+ justify-content: left;
451
+ margin-left:-15px ;
452
+ width:200px;
453
+ height:50px;
454
+ backdrop-filter: blur(10px);
455
+ z-index:1000;
456
+ text-align: left; /* Align text to the left */
457
+ padding-left: 50px;
458
+
459
+
460
+ }}
461
+ div.stButton > button p{{
462
+ color: white !important; /* White text */
463
+
464
+ }}
465
+ /* Hover effect */
466
+ div.stButton > button:hover {{
467
+ background: rgba(255, 255, 255, 0.2);
468
+ box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.4); /* Enhanced shadow on hover */
469
+ transform: scale(1.05); /* Slightly enlarge button */
470
+ transform: scale(1.1); /* Slight zoom on hover */
471
+ box-shadow: 0px 4px 12px rgba(255, 255, 255, 0.4); /* Glow effect */
472
+ }}
473
+ div.stButton > button:active {{
474
+ background: rgba(199, 107, 26, 0.5);
475
+ box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.4); /* Enhanced shadow on hover */
476
+
477
+ }}
478
+ div.stDownloadButton > button:active,
479
+ div.stDownloadButton > buttonfocus {{
480
+ background-color: transparent !important; /* or set it to the original background color */
481
+ outline: none; /* Remove the focus outline if you want */
482
+ }}
483
+ [class*="st-key-button_"] .stButton p > img {{
484
+ max-width: 100%;
485
+ vertical-align: top;
486
+ height:130px !important;
487
+ object-fit: cover;
488
+ padding: 10px;
489
+ width:250px !important;
490
+ border-radius:10px !important;
491
+ max-height: 2em !important;
492
+
493
+ }}
494
+ div.stDownloadButton > button > div > p {{
495
+ font-size:15px !important;
496
+ font-weight:bold;
497
+ }}
498
+ [class*="st-key-button_"] .stButton p{{
499
+ font-family: "Times New Roman " !important; /* Font */
500
+ font-size:100px !important;
501
+ height:150px !important;
502
+ box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2);
503
+ font-weight: bold;
504
+ margin-top:5px;
505
+ margin-left:5px;
506
+ color:black;
507
+ border-radius:10px;
508
+
509
+ }}
510
+ [class*="st-key-button_"]:hover {{
511
+
512
+ }}
513
+
514
+ [class*="st-key-nav-"] .stButton p{{
515
+ font-family: "Times New Roman " !important; /* Font */
516
+ font-size:1rem !important;
517
+ font-weight: bold;
518
+
519
+ }}
520
+ [class*="st-key-nav-10"]{{
521
+ border: none; /* Remove border */
522
+ background: transparent !important;
523
+ backdrop-filter: blur(10px) !important;
524
+ border-radius:80px !important;
525
+ width:180px !important;
526
+ height:100px; !important;
527
+ margin-top:35px !important;
528
+ }}
529
+ [class*="st-key-nav-6"]{{
530
+
531
+ border: none; /* Remove border */
532
+ background: transparent !important;
533
+ border-radius:80px !important;
534
+ backdrop-filter: blur(10px) !important;
535
+ border-radius:80px !important;
536
+ width:180px !important;
537
+ margin-top:35px !important;
538
+
539
+ }}
540
+ [class*="st-key-nav-6"] {{
541
+
542
+ border: none; /* Remove border */
543
+ background: transparent !important;
544
+ border-radius:80px !important;
545
+ backdrop-filter: blur(10px) !important;
546
+ border-radius:80px !important;
547
+ width:190px !important;
548
+ margin-top:35px !important;
549
+
550
+ }}
551
+ [class*="st-key-nav-12"],[class*="st-key-blur_"]{{
552
+
553
+ border: none; /* Remove border */
554
+ background: transparent !important;
555
+ border-radius:80px !important;
556
+ backdrop-filter: blur(10px) !important;
557
+ border-radius:80px !important;
558
+ width:180px !important;
559
+ margin-top:35px !important;
560
+
561
+ }}
562
+ [class*="st-key-nav-8"]{{
563
+
564
+ border: none; /* Remove border */
565
+ background: transparent !important;
566
+ border-radius:80px !important;
567
+ backdrop-filter: blur(10px) !important;
568
+ border-radius:80px !important;
569
+ width:300px !important;
570
+ height:80px; !important;
571
+ margin-top:35px !important;
572
+
573
+ }}
574
+ [class*="st-key-nav-5"]{{
575
+
576
+ border: none; /* Remove border */
577
+ background: transparent !important;
578
+ border-radius:80px !important;
579
+ backdrop-filter: blur(10px) !important;
580
+ border-radius:80px !important;
581
+ width:200px !important;
582
+ height:80px; !important;
583
+ margin-top:35px !important;
584
+
585
+ }}
586
+ [class*="st-key-nav-"],[class*="st-key-blur_"] {{
587
+ background: rgba(255, 255, 255, 0.2);
588
+ color: black; /* White text */
589
+ font-family: "Times New Roman " !important; /* Font */
590
+ font-size: 18px !important; /* Font size */
591
+ font-weight: bold !important; /* Bold text */
592
+ padding: 10px 20px; /* Padding for buttons */
593
+ border: none; /* Remove border */
594
+ border-radius: 15px; /* Rounded corners */
595
+ box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.2); /* Shadow effect */
596
+ transition: all 0.3s ease-in-out; /* Smooth transition */
597
+ display: flex;
598
+ align-items: center;
599
+ justify-content: center;
600
+ margin: 10px 0;
601
+ width:170px;
602
+ height:60px;
603
+ backdrop-filter: blur(10px);
604
+
605
+ }}
606
+
607
+ /* Hover effect */
608
+
609
+ [class*="st-key-nav-"]:hover {{
610
+ background: rgba(255, 255, 255, 0.2);
611
+ box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.4); /* Enhanced shadow on hover */
612
+ transform: scale(1.05); /* Slightly enlarge button */
613
+ transform: scale(1.1); /* Slight zoom on hover */
614
+ box-shadow: 0px 4px 12px rgba(255, 255, 255, 0.4); /* Glow effect */
615
+ }}
616
+
617
+
618
+ /* Title styling */
619
+ .title {{
620
+ font-family: "Times New Roman" !important; /* Elegant font for title */
621
+ font-size: 1.2rem;
622
+ font-weight: bold;
623
+ margin-left: 37px;
624
+ margin-top:10px;
625
+ margin-bottom:-100px;
626
+ padding: 0;
627
+ color: #333; /* Neutral color for text */
628
+ }}
629
+
630
+
631
+ .content-container {{
632
+ background: rgba(255, 255, 255, 0.05);
633
+ backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
634
+ width: 28%;
635
+ margin-left: 150px;
636
+ /* margin-top: -60px;*/
637
+ margin-bottom: 10px;
638
+ margin-right:10px;
639
+ padding:0;
640
+ /* border-radius:0px 0px 15px 15px ;*/
641
+ border:1px solid transparent;
642
+ overflow-y: auto; /* Enable vertical scrolling for the content */
643
+ position: fixed; /* Fix the position of the container */
644
+ top: 10%; /* Adjust top offset */
645
+ left: 60%; /* Adjust left offset */
646
+ height: 89.5vh; /* Full viewport height */
647
+
648
+ }}
649
+ .content-container-principal img{{
650
+ margin-top:260px;
651
+ margin-left:30px;
652
+ }}
653
+
654
+
655
+ div[data-testid="stText"] {{
656
+ background-color: transparent;
657
+ backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
658
+ width: 132% !important;
659
+ background-color: rgba(173, 216, 230, 0.1); /* Light blue with 50% transparency */
660
+
661
+ margin-top: -36px;
662
+ margin-bottom: 10px;
663
+ margin-left:-220px !important;
664
+ padding:50px;
665
+ padding-bottom:20px;
666
+ padding-top:50px;
667
+ /* border-radius:0px 0px 15px 15px ;*/
668
+ border:1px solid transparent;
669
+ overflow-y: auto; /* Enable vertical scrolling for the content */
670
+ height: 85vh; !important; /* Full viewport height */
671
+
672
+ }}
673
+ .content-container2 {{
674
+ background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
675
+ backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
676
+ width: 90%;
677
+ margin-left: 10px;
678
+ /* margin-top: -10px;*/
679
+ margin-bottom: 160px;
680
+ margin-right:10px;
681
+ padding:0;
682
+ border-radius:1px ;
683
+ border:1px solid transparent;
684
+ overflow-y: auto; /* Enable vertical scrolling for the content */
685
+ position: fixed; /* Fix the position of the container */
686
+ top: 3%; /* Adjust top offset */
687
+ left: 2.5%; /* Adjust left offset */
688
+ height: 78vh; /* Full viewport height */
689
+
690
+ }}
691
+ .content-container4 {{
692
+ background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
693
+ backdrop-filter: blur(10px); /* Adds a slight blur effect */ width: 40%;
694
+ margin-left: 10px;
695
+ margin-bottom: 160px;
696
+ margin-right:10px;
697
+ padding:0;
698
+ overflow-y: auto; /* Enable vertical scrolling for the content */
699
+ position: fixed; /* Fix the position of the container */
700
+ top: 60%; /* Adjust top offset */
701
+ left: 2.5%; /* Adjust left offset */
702
+ height: 10vh; /* Full viewport height */
703
+
704
+ }}
705
+ .content-container4 h3 ,p {{
706
+ font-family: "Times New Roman" !important; /* Elegant font for title */
707
+ font-size: 1rem;
708
+ font-weight: bold;
709
+ text-align:center;
710
+ }}
711
+ .content-container5 h3 ,p {{
712
+ font-family: "Times New Roman" !important; /* Elegant font for title */
713
+ font-size: 1rem;
714
+ font-weight: bold;
715
+ text-align:center;
716
+ }}
717
+ .content-container6 h3 ,p {{
718
+ font-family: "Times New Roman" !important; /* Elegant font for title */
719
+ font-size: 1rem;
720
+ font-weight: bold;
721
+ text-align:center;
722
+ }}
723
+ .content-container7 h3 ,p {{
724
+ font-family: "Times New Roman" !important; /* Elegant font for title */
725
+ font-size: 1rem;
726
+ font-weight: bold;
727
+ text-align:center;
728
+ }}
729
+ .content-container5 {{
730
+ background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
731
+ backdrop-filter: blur(10px); /* Adds a slight blur effect */ width: 40%;
732
+ margin-left: 180px;
733
+ margin-bottom: 130px;
734
+ margin-right:10px;
735
+ padding:0;
736
+ overflow-y: auto; /* Enable vertical scrolling for the content */
737
+ position: fixed; /* Fix the position of the container */
738
+ top: 60%; /* Adjust top offset */
739
+ left: 5.5%; /* Adjust left offset */
740
+ height: 10vh; /* Full viewport height */
741
+
742
+ }}
743
+ .content-container3 {{
744
+ background-color: rgba(216, 216, 230, 0.5); /* Light blue with 50% transparency */
745
+ backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
746
+ width: 92%;
747
+ margin-left: 10px;
748
+ /* margin-top: -10px;*/
749
+ margin-bottom: 160px;
750
+ margin-right:10px;
751
+ padding:0;
752
+ border: 10px solid white;
753
+ overflow-y: auto; /* Enable vertical scrolling for the content */
754
+ position: fixed; /* Fix the position of the container */
755
+ top: 3%; /* Adjust top offset */
756
+ left: 1.5%; /* Adjust left offset */
757
+ height: 40vh; /* Full viewport height */
758
+
759
+ }}
760
+ .content-container6 {{
761
+ background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
762
+ backdrop-filter: blur(10px); /* Adds a slight blur effect */ width: 40%;
763
+ margin-left: 10px;
764
+ margin-bottom: 160px;
765
+ margin-right:10px;
766
+ padding:0;
767
+ overflow-y: auto; /* Enable vertical scrolling for the content */
768
+ position: fixed; /* Fix the position of the container */
769
+ top: 80%; /* Adjust top offset */
770
+ left: 2.5%; /* Adjust left offset */
771
+ height: 10vh; /* Full viewport height */
772
+
773
+ }}
774
+ .content-container7 {{
775
+ background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
776
+ backdrop-filter: blur(10px); /* Adds a slight blur effect */ width: 40%;
777
+ margin-left: 180px;
778
+ margin-bottom: 130px;
779
+ margin-right:10px;
780
+ padding:0;
781
+ overflow-y: auto; /* Enable vertical scrolling for the content */
782
+ position: fixed; /* Fix the position of the container */
783
+ top: 80%; /* Adjust top offset */
784
+ left: 5.5%; /* Adjust left offset */
785
+ height: 10vh; /* Full viewport height */
786
+
787
+ }}
788
+ .content-container2 img {{
789
+ width:99%;
790
+ height:50%;
791
+
792
+ }}
793
+ .content-container3 img {{
794
+ width:100%;
795
+ height:100%;
796
+
797
+ }}
798
+
799
+ .side_box{{
800
+ width: 200px;
801
+ height: 180px;
802
+ background-color: #0175C2;
803
+ margin: 5px;
804
+ border-radius:20px;
805
+ left:-5%;
806
+ }}
807
+ .titles{{
808
+ margin-top:20px !important;
809
+ margin-left: -150px !important;
810
+
811
+ }}
812
+ /* Title styling */
813
+ .titles h1{{
814
+ /*font-family: "Times New Roman" !important; /* Elegant font for title */
815
+ font-size: 2.2rem;
816
+ /*font-weight: bold;*/
817
+ margin-left: 0px;
818
+ margin-top:80px;
819
+ margin-bottom:30px;
820
+ padding: 0;
821
+ color: black; /* Neutral color for text */
822
+ }}
823
+ .titles > div{{
824
+ font-family: "Times New Roman" !important; /* Elegant font for title */
825
+ font-size: 1.2rem;
826
+ margin-left: 200px;
827
+ margin-bottom:1px;
828
+ padding: 0;
829
+ color:black; /* Neutral color for text */
830
+ }}
831
+ </style>
832
+ <div class="logo-text-container">
833
+ <img src="data:image/png;base64,{encoded_logo}" alt="Logo">
834
+ </div>
835
+ """, unsafe_allow_html=True
836
+ )
837
+ loading_html = """
838
+ <style>
839
+ .loader {
840
+ border: 8px solid #f3f3f3;
841
+ border-top: 8px solid #0175C2; /* Blue color */
842
+ border-radius: 50%;
843
+ width: 50px;
844
+ height: 50px;
845
+ animation: spin 1s linear infinite;
846
+ margin: auto;
847
+ }
848
+ @keyframes spin {
849
+ 0% { transform: rotate(0deg); }
850
+ 100% { transform: rotate(360deg); }
851
+ }
852
+
853
+ </style>
854
+ <div class="loader"></div>
855
+ """
856
+ # Sidebar content
857
+ st.markdown(
858
+ """
859
+ <style>
860
+ .sidebar-desc {
861
+ font-family: "Times New Roman" !important; /* Elegant font for title */ font-size: 14px;
862
+ color: #333;
863
+ background-color: transparent;
864
+ padding: 15px;
865
+ border-radius: 20px;
866
+ box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
867
+ width:200px !important;
868
+ margin-left:-15px;
869
+ margin-top:-50px;
870
+ height:70vh;
871
+ }
872
+ .sidebar-desc h3 {
873
+ font-family: "Times New Roman" !important; /* Elegant font for title */ font-size: 14px;
874
+
875
+ font-size: 18px;
876
+ color: #0175C2; /* Light Blue */
877
+ margin-bottom: 10px;
878
+ }
879
+ .sidebar-desc h4 {
880
+ font-size: 16px;
881
+ color: #444;
882
+ margin-bottom: 5px;
883
+ font-family: "Times New Roman" !important; /* Elegant font for title */ font-size: 14px;
884
+
885
+ }
886
+ .sidebar-desc ul {
887
+ list-style-type: square;
888
+ margin: 0;
889
+ padding-left: 20px;
890
+ }
891
+ .sidebar-desc ul li {
892
+ margin-bottom: 5px;
893
+ }
894
+ .sidebar-desc a {
895
+ color: #0175C2;
896
+ text-decoration: none;
897
+ }
898
+ .sidebar-desc a:hover {
899
+ text-decoration: underline;
900
+ }
901
+ </style>
902
+ """,
903
+ unsafe_allow_html=True,
904
+ )
905
+
906
+
907
+ # Use radio buttons for navigation
908
+ # Set the page to "Home"
909
+ page = "Home"
910
+ selected_img =""
911
+
912
+
913
+ st.session_state.page = "Home"
914
+ # Display content based on the selected page
915
+ if st.session_state.page == "Home":
916
+ # Sidebar buttons
917
+ with st.sidebar:
918
+ if st.button("📄 Model Summary"):
919
+ st.session_state.menu ="1" # Store state
920
+ st.rerun()
921
+
922
+ # Add your model description logic here
923
+
924
+ if st.button("📊 Model Results Analysis",key="header"):
925
+ st.session_state.menu ="2"
926
+ st.rerun()
927
+ # Add model analysis logic here
928
+ if st.button("🧪 Model Testing"):
929
+ st.session_state.menu ="3"
930
+ st.rerun()
931
+
932
+
933
+
934
+ table_style = """
935
+ <style>
936
+ table {
937
+ width: 100%;
938
+ border-collapse: collapse;
939
+ border-radius: 2px;
940
+ overflow: hidden;
941
+ box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1);
942
+ background: rgba(255, 255, 255, 0.05);
943
+ backdrop-filter: blur(10px);
944
+ font-family: "Times New Roman", serif;
945
+ margin-left:100px;
946
+ margin-top:30px;
947
+ }
948
+ thead {
949
+ background: rgba(255, 255, 255, 0.2);
950
+ }
951
+ th {
952
+ padding: 12px;
953
+ text-align: left;
954
+ font-weight: bold;
955
+ backdrop-filter: blur(10px);
956
+ }
957
+ td {
958
+ padding: 12px;
959
+ border-bottom: 1px solid rgba(255, 255, 255, 0.1);
960
+ }
961
+ tr:hover {
962
+ background-color: rgba(255, 255, 255, 0.1);
963
+ }
964
+ tbody {
965
+ display: block;
966
+ max-height: 680px; /* Set the fixed height */
967
+ overflow-y: auto;
968
+ width: 100%;
969
+ }
970
+ thead, tbody tr {
971
+ display: table;
972
+ width: 100%;
973
+ table-layout: fixed;
974
+ }
975
+ </style>
976
+ """
977
+ print(test_loader)
978
+
979
+ with st.container(key="content_1"):
980
+ print(type(model)) # Should print <class 'CustomVGG16'> and not OrderedDict
981
+ if st.session_state.show_summary:
982
+ # Load the model
983
+ layers_data = get_layers_data(model) # Get layer information
984
+
985
+ # Convert to HTML table
986
+ table_html = "<table><tr><th>Layer Name</th><th>Type</th><th>Output Shape</th><th>Param #</th></tr>"
987
+ for name, layer_type, shape, params in layers_data:
988
+ table_html += f"<tr><td>{name}</td><td>{layer_type}</td><td>{shape}</td><td>{params}</td></tr>"
989
+ table_html += "</table>"
990
+
991
+
992
+ st.markdown(table_style + table_html, unsafe_allow_html=True)
993
+
994
+ if st.session_state.show_arch:
995
+ model.eval()
996
+
997
+ # Initialize lists to store true labels and predicted labels
998
+ y_true = []
999
+ y_pred = []
1000
+ for image, label in test_dataset: # test_dataset is an instance of ImageFolder or similar
1001
+ image = image.unsqueeze(0) # Add batch dimension and move to device
1002
+ label = label
1003
+
1004
+ with torch.no_grad():
1005
+ output = model(image) # Get model output
1006
+ _, predicted = torch.max(output, 1) # Get predicted class
1007
+
1008
+ y_true.append(label) # Append true label
1009
+ y_pred.append(predicted.item()) # Append predicted label
1010
+
1011
+ # Generate the classification report
1012
+ report_dict = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
1013
+
1014
+ # Convert to DataFrame for better readability
1015
+ report_df = pd.DataFrame(report_dict).transpose().round(2)
1016
+
1017
+ accuracy = report_dict["accuracy"]
1018
+ precision = report_df.loc["weighted avg", "precision"]
1019
+ recall = report_df.loc["weighted avg", "recall"]
1020
+ f1_score = report_df.loc["weighted avg", "f1-score"]
1021
+
1022
+
1023
+ st.markdown("""
1024
+ <style>
1025
+ .kpi-container {
1026
+ display: flex;
1027
+ justify-content: space-between;
1028
+ margin-bottom: 20px;
1029
+ margin-left:100px;
1030
+ margin-top:70px;
1031
+
1032
+ }
1033
+ .kpi-card {
1034
+ width: 23%;
1035
+ padding: 15px;
1036
+ text-align: center;
1037
+ border-radius: 10px;
1038
+ font-size: 22px;
1039
+ font-weight: bold;
1040
+ font-family: "Times New Roman " !important; /* Font */
1041
+ color: #333;
1042
+ background: rgba(255, 255, 255, 0.05);
1043
+ box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
1044
+ border: 5px solid rgba(173, 216, 230, 0.4);
1045
+ }
1046
+ </style>
1047
+ <div class="kpi-container">
1048
+ <div class="kpi-card">Precision<br>""" + f"{precision:.2f}" + """</div>
1049
+ <div class="kpi-card">Recall<br>""" + f"{recall:.2f}" + """</div>
1050
+ <div class="kpi-card">Accuracy<br>""" + f"{accuracy:.3f}" + """</div>
1051
+ <div class="kpi-card">F1-Score<br>""" + f"{f1_score:.3f}" + """</div>
1052
+ </div>
1053
+ """, unsafe_allow_html=True)
1054
+
1055
+
1056
+ # Remove last rows (accuracy/macro avg/weighted avg) and reset index
1057
+ report_df = report_df.iloc[:-3].reset_index()
1058
+ report_df.rename(columns={"index": "Class"}, inplace=True)
1059
+
1060
+ # Custom CSS for Table Styling
1061
+ st.markdown("""
1062
+ <style>
1063
+ .report-container {
1064
+ max-height: 250px;
1065
+ overflow-y: auto;
1066
+ border-radius: 25px;
1067
+ text-align:center;
1068
+ border: 5px solid rgba(173, 216, 230, 0.4);
1069
+ padding: 10px;
1070
+ background: rgba(255, 255, 255, 0.05);
1071
+ box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
1072
+ width:480px;
1073
+ margin-left:100px;
1074
+ margin-top:-20px;
1075
+ }
1076
+ .report-container h4{
1077
+ font-family: "Times New Roman" !important; /* Elegant font for title */
1078
+ font-size: 1rem;
1079
+ margin-left: 5px;
1080
+ margin-bottom:1px;
1081
+ padding: 10px;
1082
+ color:#333;
1083
+
1084
+ }
1085
+ .report-table {
1086
+ width: 100%;
1087
+ border-collapse: collapse;
1088
+ font-family: 'Times New Roman', serif;
1089
+ text-align: center;
1090
+ }
1091
+ .report-table th {
1092
+ background: rgba(255, 255, 255, 0.05);
1093
+ font-size: 16px;
1094
+ padding: 10px;
1095
+ border-bottom: 2px solid #444;
1096
+ }
1097
+ .report-table td {
1098
+ font-size: 12px;
1099
+ padding: 10px;
1100
+ border-bottom: 1px solid #ddd;
1101
+ }
1102
+ </style>
1103
+ """, unsafe_allow_html=True)
1104
+ col1,col2 = st.columns([3,3])
1105
+ with col1:
1106
+ # Convert DataFrame to HTML Table
1107
+ report_html = report_df.to_html(index=False, classes="report-table", escape=False)
1108
+ st.markdown(f'<div class="report-container"><h4>classification report </h4>{report_html}</div>', unsafe_allow_html=True)
1109
+ # Generate Confusion Matrix
1110
+ # Generate Confusion Matrix
1111
+ cm = confusion_matrix(y_true, y_pred)
1112
+
1113
+ # Create Confusion Matrix Heatmap
1114
+ fig, ax = plt.subplots(figsize=(1, 1))
1115
+ fig.patch.set_alpha(0) # Make figure background transparent
1116
+
1117
+ # Seaborn Heatmap (Confusion Matrix)
1118
+ sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
1119
+ xticklabels=class_names, yticklabels=class_names,
1120
+ linewidths=1, linecolor="black",
1121
+ cbar=False, square=True, alpha=0.9,
1122
+ annot_kws={"size": 5, "family": "Times New Roman"})
1123
+ # Change font for tick labels
1124
+ for text in ax.texts:
1125
+ text.set_bbox(dict(facecolor='none', edgecolor='none', alpha=0))
1126
+ plt.xticks(fontsize=4, family="Times New Roman") # X-axis font
1127
+ plt.yticks(fontsize=4, family="Times New Roman") # Y-axis font
1128
+ # Enhance Labels and Title
1129
+
1130
+ plt.title("Confusion Matrix", fontsize=5, family="Times New Roman",color="black", loc='center')
1131
+
1132
+ # Apply transparent background and double border (via Streamlit Markdown)
1133
+ st.markdown("""
1134
+ <style>
1135
+ div[data-testid="stImageContainer"] {
1136
+ max-height: 250px;
1137
+ overflow-y: auto;
1138
+ border-radius: 25px;
1139
+ text-align:center;
1140
+ border: 5px solid rgba(173, 216, 230, 0.4);
1141
+ padding: 10px;
1142
+ background: rgba(255, 255, 255, 0.05);
1143
+ box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
1144
+ width:480px !important;
1145
+ margin-left:100px;
1146
+ margin-top:-20px;
1147
+
1148
+ }
1149
+ div[data-testid="stImageContainer"] img{
1150
+ margin-top:-10px !important;
1151
+ width:400px !important;
1152
+ height:250px !important;
1153
+ }
1154
+ [class*="st-key-roc"] div[data-testid="stImageContainer"] {
1155
+ max-height: 250px;
1156
+ overflow-y: auto;
1157
+ border-radius: 25px;
1158
+ text-align:center;
1159
+ border: 5px solid rgba(173, 216, 230, 0.4);
1160
+ background: rgba(255, 255, 255, 0.05);
1161
+ box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
1162
+ width:480px;
1163
+ margin-left:-35px;
1164
+ margin-top:-15px;
1165
+ }
1166
+ [class*="st-key-roc"] div[data-testid="stImageContainer"] img{
1167
+ width:480px !important;
1168
+ height:200px !important;
1169
+ margin-top:-20px !important;
1170
+
1171
+ }
1172
+ [class*="st-key-precision"] div[data-testid="stImageContainer"] {
1173
+ max-height: 250px;
1174
+ overflow-y: auto;
1175
+ border-radius: 25px;
1176
+ text-align:center;
1177
+ border: 5px solid rgba(173, 216, 230, 0.4);
1178
+ background: rgba(255, 255, 255, 0.05);
1179
+ box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
1180
+ width:480px;
1181
+ margin-left:-35px;
1182
+ margin-top:-5px;
1183
+ }
1184
+ [class*="st-key-precision"] div[data-testid="stImageContainer"] img{
1185
+ width:480px !important;
1186
+ height:250px !important;
1187
+ margin-top:-20px !important;
1188
+
1189
+ }
1190
+ </style>
1191
+ """, unsafe_allow_html=True)
1192
+
1193
+ # Show Plot in Streamlit inside a styled container
1194
+ st.markdown('<div class="confusion-matrix-container">', unsafe_allow_html=True)
1195
+ st.pyplot(fig)
1196
+ st.markdown("</div>", unsafe_allow_html=True)
1197
+
1198
+ with col2:
1199
+ # Convert the lists to numpy arrays
1200
+ # y_true = np.concatenate(y_true)
1201
+ y_pred_probs = np.array(y_pred) # Make sure this is a 2D array: [batch_size, 2]
1202
+ y_true = np.array(y_true) # Ensure y_true is a numpy array
1203
+
1204
+ print(y_pred)
1205
+ print(y_true)
1206
+
1207
+ # Binarize the true labels for multi-class classification
1208
+ y_true_bin = label_binarize(y_true, classes=np.arange(len(class_names)))
1209
+
1210
+ # Initialize dictionaries for storing ROC curve data
1211
+ # Calculating ROC curve and AUC for each class
1212
+ fpr, tpr, roc_auc = {}, {}, {}
1213
+
1214
+ # Calculate ROC curve and AUC for the positive class (class 1)
1215
+ fpr[0], tpr[0], _ = roc_curve(y_true_bin, y_pred_probs) # Use 1D probabilities for class 1
1216
+ roc_auc[0] = auc(fpr[0], tpr[0]) # Calculate AUC for class 1
1217
+
1218
+ # Plotting the ROC curve for each class
1219
+ plt.figure(figsize=(11, 9))
1220
+
1221
+ # Plot ROC curve for the positive class (class 1)
1222
+ plt.plot(fpr[0], tpr[0], lw=2, label=f'Class 1 (AUC = {roc_auc[0]:.2f})')
1223
+
1224
+ # Plot random guess line (diagonal line)
1225
+ plt.plot([0, 1], [0, 1], color='navy', lw=5, linestyle='--')
1226
+
1227
+ # Labels and legend
1228
+ plt.xlim([0.0, 1.0])
1229
+ plt.ylim([0.0, 1.05])
1230
+ plt.xlabel('False Positive Rate', fontsize=28, family="Times New Roman")
1231
+ plt.ylabel('True Positive Rate', fontsize=28, family="Times New Roman")
1232
+ plt.title('ROC Curve (One-vs-Rest) for Each Class', fontsize=30, family="Times New Roman", color="black", loc='center', pad=3)
1233
+ plt.legend(loc='lower right', fontsize=18)
1234
+
1235
+ # Save the plot as an image file
1236
+ plt.savefig('roc_curve.png', transparent=True)
1237
+ plt.close()
1238
+
1239
+ # Display the ROC curve in Streamlit
1240
+ with st.container(key="roc"):
1241
+ st.image('roc_curve.png')
1242
+
1243
+ with st.container(key="precision"):
1244
+ # Compute Precision-Recall curve
1245
+ precision, recall, _ = precision_recall_curve(y_true_bin, y_pred_probs)
1246
+
1247
+ # Calculate AUC for Precision-Recall curve
1248
+ pr_auc = auc(recall, precision)
1249
+
1250
+ # Plot Precision-Recall curve
1251
+ plt.figure(figsize=(11, 9))
1252
+ plt.plot(recall, precision, lw=2, label=f'Precision-Recall curve (AUC = {pr_auc:.2f})')
1253
+
1254
+
1255
+ plt.xlabel('Recall', fontsize=28,family="Times New Roman")
1256
+ plt.ylabel('Precision', fontsize=28,family="Times New Roman")
1257
+ plt.title('Precision-Recall Curve for Each Class', fontsize=30, family="Times New Roman",color="black", loc='center',pad=3)
1258
+ plt.legend(loc='lower left', fontsize=18)
1259
+ plt.grid(True, linestyle='--', alpha=0.7)
1260
+ plt.savefig('precision_recall_curve.png', transparent=True)
1261
+ plt.close()
1262
+
1263
+ st.image('precision_recall_curve.png')
1264
+ if st.session_state.show_desc:
1265
+ # components.html(html_string) # JavaScript works
1266
+ # st.markdown(html_string, unsafe_allow_html=True)
1267
+ image_path = "image.jpg"
1268
+
1269
+ st.container()
1270
+ st.markdown(
1271
+ f"""
1272
+
1273
+ <div class="titles">
1274
+ <h1>Brain Tummor Classfication</br> Using Transfer learning</h1>
1275
+ <div> This web application utilizes transfer learning to classify kidney ultrasound images</br>
1276
+ into four categories: HEALTH and TUMOR and CYST Class.
1277
+ Built with Streamlit and powered by </br>a TensorFlow transfer learning
1278
+ model based on <strong>VGG16</strong>
1279
+ the app provides </br>a simple and efficient way for users
1280
+ to upload brain scans and receive instant predictions.</br> The model analyzes the image
1281
+ and classifies it based </br>on learned patterns, offering a confidence score for better interpretation.
1282
+ </div>
1283
+ </div>
1284
+ """,
1285
+ unsafe_allow_html=True,
1286
+ )
1287
+ uploaded_file = st.file_uploader(
1288
+ "Choose a file", type=["png", "jpg", "jpeg"], key="upload-btn"
1289
+ )
1290
+ if uploaded_file is not None:
1291
+ images = Image.open(uploaded_file)
1292
+ # Rewind file pointer to the beginning
1293
+ uploaded_file.seek(0)
1294
+
1295
+ file_content = uploaded_file.read() # Read file once
1296
+ # Convert to base64 for HTML display
1297
+ encoded_image = base64.b64encode(file_content).decode()
1298
+ # Read and process image
1299
+ pil_image = Image.open(uploaded_file).convert("RGB").resize((224, 224))
1300
+ img_array = np.array(pil_image)
1301
+
1302
+ class0, class1,prediction = predict_image(images)
1303
+ max_index = int(np.argmax(prediction[0]))
1304
+ print(f"max index:{max_index}")
1305
+ max_score = prediction[0][max_index]
1306
+
1307
+ predicted_class = np.argmax(prediction[0])
1308
+ print(f"predicted class is :{predicted_class}")
1309
+
1310
+ cams = generate_gradcam(pil_image, predicted_class)
1311
+ heatmap = cm.jet(cams)[..., :3]
1312
+ heatmap = (heatmap * 255).astype(np.uint8)
1313
+ overlayed_image = cv2.addWeighted(img_array, 0.6, heatmap, 0.4, 0)
1314
+
1315
+ # Convert to PIL
1316
+ overlayed_pil = Image.fromarray(overlayed_image)
1317
+ # Convert to base64
1318
+ orig_b64 = convert_image_to_base64(pil_image)
1319
+ overlay_b64 = convert_image_to_base64(overlayed_pil)
1320
+
1321
+ highlight_class = "highlight" # Special class for the highest confidence score
1322
+
1323
+ # Generate Grad-CAM
1324
+ #cam = generate_gradcam(pil_image, predicted_class)
1325
+
1326
+ # Create overlay
1327
+ #heatmap = cm.jet(cam)[..., :3]
1328
+ #heatmap = (heatmap * 255).astype(np.uint8)
1329
+ #overlayed_image = cv2.addWeighted(img_array, 0.6, heatmap, 0.4, 0)
1330
+
1331
+ # Convert to PIL
1332
+ #overlayed_pil = Image.fromarray(overlayed_image)
1333
+ # Convert to base64
1334
+ orig_b64 = convert_image_to_base64(pil_image)
1335
+ #overlay_b64 = convert_image_to_base64(overlayed_pil)
1336
+ content = f"""
1337
+ <div class="content-container">
1338
+ <!-- Title -->
1339
+ <!-- Recently Viewed Section -->
1340
+ <div class="content-container3">
1341
+ <img src="data:image/png;base64,{orig_b64}" alt="Uploaded Image">
1342
+ </div>
1343
+ <div class="content-container3">
1344
+ <img src="data:image/png;base64,{overlay_b64}" class="result-image">
1345
+ </div>
1346
+ <div class="content-container4 {'highlight' if max_index == 0 else ''}">
1347
+ <h3>{class_labels[0]}</h3>
1348
+ <p>T Score: {class0 :.2f}</p>
1349
+ </div>
1350
+ <div class="content-container5 {'highlight' if max_index == 1 else ''}">
1351
+ <h3> {class_labels[1]}</h3>
1352
+ <p>T Score: {class1 :.2f}</p>
1353
+ </div>
1354
+ """
1355
+
1356
+ # Close the gallery and content div
1357
+
1358
+ # Render the content
1359
+ placeholder = st.empty() # Create a placeholder
1360
+ placeholder.markdown(loading_html, unsafe_allow_html=True)
1361
+ time.sleep(5) # Wait for 5 seconds
1362
+ placeholder.empty()
1363
+ st.markdown(content, unsafe_allow_html=True)
1364
+ else:
1365
+ default_image_path = "image.jpg"
1366
+ with open(image_path, "rb") as image_file:
1367
+ encoded_image = base64.b64encode(image_file.read()).decode()
1368
+
1369
+ st.markdown(
1370
+ f"""
1371
+ <div class="content-container">
1372
+ <!-- Title -->
1373
+ <!-- Recently Viewed Section -->
1374
+ <div class="content-container3">
1375
+ <img src="data:image/png;base64,{encoded_image}" alt="Default Image">
1376
+ </div>
1377
+ </div>
1378
+
1379
+ """,
1380
+ unsafe_allow_html=True,
1381
+ )