# Transfer-SN / app.py
import streamlit as st
st.set_page_config(layout="wide")
import streamlit.components.v1 as components
import cv2
from PIL import Image
import base64
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from torch.utils.data import DataLoader
from io import BytesIO
from gradcam import GradCAM # Import the local GradCAM implementation
from sklearn.metrics import classification_report,confusion_matrix, roc_curve, auc,precision_recall_curve, average_precision_score
from sklearn.preprocessing import label_binarize
import seaborn as sns
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
import io
import warnings
warnings.filterwarnings("ignore")
showWarningOnDirectExecution = False # Leftover Streamlit config option; as a bare assignment it has no effect here
# Path to your logo image
logo_path = "pytorch.png"
main_bg_ext = 'png'
# Read and encode the logo image
with open(logo_path, "rb") as image_file:
encoded_logo = base64.b64encode(image_file.read()).decode()
if "framework" not in st.session_state:
st.session_state.framework = "Tensorflow"
if "menu" not in st.session_state:
st.session_state.menu = "3"
if st.session_state.menu =="1":
st.session_state.show_summary = True
st.session_state.show_arch = False
st.session_state.show_desc = False
elif st.session_state.menu =="2":
st.session_state.show_arch = True
st.session_state.show_summary = False
st.session_state.show_desc = False
elif st.session_state.menu =="3":
st.session_state.show_arch = False
st.session_state.show_summary = False
st.session_state.show_desc = True
else:
st.session_state.show_desc = True
def encode_image(image_path):
with open(image_path, "rb") as img_file:
return base64.b64encode(img_file.read()).decode()
#**************************************************
# loading pytorch model
#********************************************
# Define the CustomVGG16 model
class CustomVGG16(nn.Module):
def __init__(self, num_classes=2):
super(CustomVGG16, self).__init__()
base_model = models.vgg16(pretrained=False)
self.features = base_model.features
self.avgpool = nn.AdaptiveAvgPool2d((2, 2))
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(512 * 2 * 2, 512)
self.bn1 = nn.BatchNorm1d(512)
self.dropout = nn.Dropout(0.5)
self.fc2 = nn.Linear(512, num_classes)
self.softmax = nn.Softmax(dim=1)
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.bn1(x)
x = torch.relu(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.softmax(x)
return x
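# Note: forward() ends with nn.Softmax, so the model returns class probabilities rather than raw logits.
# This matters downstream, e.g. in predict_image() below, which applies softmax a second time.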
# Load the model
model = CustomVGG16(num_classes=2)
# Load the state_dict (weights only)
model.load_state_dict(torch.load('brain_model.pth', map_location=torch.device('cpu')))
model.eval() # Set the model to evaluation mode
target_layer = model.features[-1] # Typically last convolutional layer
gradcam = GradCAM(model, target_layer)
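# Caveat: model.features[-1] is the final MaxPool2d of VGG16, not a convolution. Grad-CAM is usually
# hooked on the last Conv2d (features[28] in torchvision's VGG16); targeting the pooling layer still
# produces a map, just a slightly coarser one.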
def preprocess_image(image):
preprocess = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # For pretrained models like VGG16
])
return preprocess(image).unsqueeze(0) # Add batch dimension
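# The mean/std above are the standard ImageNet statistics used with torchvision's VGG16;
# using the same values at training and inference time is what keeps the normalization consistent.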
def generate_gradcam(image, target_class):
# Preprocess the image and convert it to a tensor
input_image = preprocess_image(image)
# Instantiate GradCAM
gradcam = GradCAM(model, target_layer)
# Generate the CAM
cam = gradcam.generate(input_image, target_class)
return cam
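# The overlay code further down assumes the returned CAM is a 2D heatmap already scaled to [0, 1] and
# matching the 224x224 input, since it is passed straight to matplotlib's jet colormap and
# cv2.addWeighted without any further resizing.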
# Function to get layer information
def get_layers_data(model, prefix=""):
layers_data = []
for name, layer in model.named_children(): # Iterate over layers
full_name = f"{prefix}.{name}" if prefix else name # Track hierarchy
try:
shape = str(list(layer.parameters())[0].shape) # Get shape of the first param
except Exception:
shape = "N/A"
param_count = sum(p.numel() for p in layer.parameters()) # Count parameters
layers_data.append((full_name, layer.__class__.__name__, shape, f"{param_count:,}"))
# Recursively get layers inside this layer (for nested structures)
layers_data.extend(get_layers_data(layer, full_name))
return layers_data
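# get_layers_data() walks named_children() recursively, so containers (e.g. the Sequential stored in
# `features`) appear once as a container row and again as their individual children with dotted names
# like "features.0"; parameter counts are therefore listed at both levels.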
def convert_image_to_base64(pil_image):
buffered = BytesIO()
pil_image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()
def predict_image(image):
# Preprocess the image to match the model input requirements
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Standard VGG16 normalization
])
image = transform(image).unsqueeze(0) # Add batch dimension
# The model stays on CPU here, so no device transfer is needed
# Set the model to evaluation mode
model.eval()
with torch.no_grad(): # Disable gradient calculation
outputs = model(image) # Forward pass
# Get predicted probabilities (softmax for multi-class)
if outputs.shape[1] == 1:
probs = torch.sigmoid(outputs) # Apply sigmoid activation for binary classification
prob_class_1 = probs[0].item() # Probability for class 1
prob_class_0 = 1 - prob_class_1 # Probability for class 0
# If the output has two units (binary classification with softmax)
else:
probs = torch.nn.functional.softmax(outputs, dim=1)
prob_class_0 = probs[0, 0].item()
prob_class_1 = probs[0, 1].item()
print("Raw model output:", outputs) # Note: forward() already applies Softmax, so these are probabilities, not logits
return prob_class_0, prob_class_1, probs
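# Caveat: the model's forward() already applies Softmax, so the softmax above runs on probabilities,
# not logits. The argmax (predicted class) is unchanged, but the reported confidence scores are
# compressed toward 0.5 compared with a single softmax.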
#**************************************************
# LOADING TEST DATASET
#**************************************************
test_dir = "test"
BATCH_SIZE = 32
IMG_SIZE = (224, 224)
transform = transforms.Compose([
transforms.Resize(IMG_SIZE),
transforms.ToTensor(),
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
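# Note: this test-time transform skips the ImageNet normalization that preprocess_image()/predict_image()
# apply, so the batch evaluation below and the single-image prediction path see differently scaled inputs.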
class_names = test_dataset.classes
class_labels = class_names
#######################################################
# Custom CSS to style the logo above the sidebar
st.markdown(
f"""
<style>
/* Container for logo and text */
.logo-text-container {{
position: fixed;
top: 30px; /* Adjust vertical position */
left: 50px; /* Align with sidebar */
display: flex;
align-items: center;
gap: 15px;
justify-content: space-between;
width: 100%;
}}
/* Logo styling */
.logo-text-container img {{
width: 100px; /* Adjust logo size */
border-radius: 10px; /* Optional: round edges */
margin-top:10px;
margin-left:20px;
}}
/* Bold text styling */
.logo-text-container h1 {{
font-family: 'Times New Roman', serif;
font-size: 24px;
font-weight: bold;
margin-right: 100px;
text-align: center;
align-items: center;
margin: 0 auto; /* Center the text */
flex-grow:1;
color: #FFD700; /* Golden color for text */
}}
/* Sidebar styling */
section[data-testid="stSidebar"][aria-expanded="true"] {{
margin-top: 100px !important; /* Space for the logo */
border-radius: 0 60px 0px 60px !important; /* Top-left and bottom-right corners */
width: 200px !important; /* Sidebar width */
background:none; /* Gradient background */
/* box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2); /* Shadow effect */
/* border: 1px solid #FFD700; /* Shiny golden border */
margin-bottom: 1px !important;
color:white !important;
}}
/* Style for the upload button */
[class*="st-key-upload-btn"] {{
position: absolute;
top: 50%; /* Position from the top of the inner circle */
left: 1%; /* Position horizontally at the center */
padding: 10px 20px;
color: red;
border: none;
border-radius: 20px;
cursor: pointer;
font-size: 35px !important;
width:30px;
height:20px;
}}
.upload-btn:hover {{
background-color: rgba(0, 123, 255, 1);
}}
div[data-testid="stFileUploader"] label > div > p {{
display:none;
color:white !important;
}}
section[data-testid="stFileUploaderDropzone"] {{
width:200px;
height: 60px;
background-color: white;
border-radius: 40px;
display: flex;
justify-content: center;
align-items: center;
margin-top:-10px;
box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.3);
margin:20px;
background-color: rgba(255, 255, 255, 0.7); /* Transparent blue background */
color:white;
}}
div[data-testid="stFileUploaderDropzoneInstructions"] div > small{{
color:white !important;
display:none;
}}
div[data-testid="stFileUploaderDropzoneInstructions"] span{{
margin-left:65px;
color:orange;
}}
div[data-testid="stFileUploaderDropzoneInstructions"] div{{
display:none;
}}
section[data-testid="stFileUploaderDropzone"] button{{
display:none;
}}
div[data-testid="stMarkdownContainer"] p {{
font-family: "Times New Roman" !important; /* Elegant font for title */
color:white !important;
}}
.highlight {{
border: 4px solid lime;
font-weight: bold;
background: radial-gradient(circle, rgba(0,255,0,0.3) 0%, rgba(0,0,0,0) 70%);
box-shadow: 0px 0px 30px 10px rgba(0, 255, 0, 0.9),
0px 0px 60px 20px rgba(0, 255, 0, 0.6),
inset 0px 0px 15px rgba(0, 255, 0, 0.8);
transition: all 0.3s ease-in-out;
}}
.highlight:hover {{
transform: scale(1.05);
background: radial-gradient(circle, rgba(0,255,0,0.6) 0%, rgba(0,0,0,0) 80%);
box-shadow: 0px 0px 40px 15px rgba(0, 255, 0, 1),
0px 0px 70px 30px rgba(0, 255, 0, 0.7),
inset 0px 0px 20px rgba(0, 255, 0, 1);
}}
header[data-testid="stHeader"] {{
/* border-radius: 1px !important;*/
background: transparent !important; /* Gradient background */
/*box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2); /* Shadow effect */
/*: 3px solid #FFD700; /* Shiny golden border */
/*border-bottom:none !important;*/
margin-right: 100px !important;
margin-top: 32px !important;
z-index: 1 !important; /* Ensure it stays above other elements */
}}
div[data-testid="stDecoration"]{{
background-image:none;
}}
button[data-testid="stBaseButton-secondary"]{{
background:transparent;
border:none;
}}
div[data-testid="stApp"]{{
background:#161819;
height: 98vh; /* Full viewport height */
width: 98%;
border-radius: 40px !important;
margin-left:10px;
margin-right:10px;
margin-top:10px;
box-shadow: 0 4px 30px rgba(0, 0, 0, 0.5);
overflow: hidden;
}}
div[data-testid="stMarkdownContainer"] > p {{
font-family: "Times New Roman " !important; /* Font */
font-size: 11px !important; /* Font size */
margin:5px;
}}
[class*="st-key-content_"] {{
background: rgba(255, 255, 255, 0.9);
border-radius: 40px;
/* box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.1);*/
width: 83.7%;
margin-left: 75px;
/* margin-top: -70px;*/
margin-bottom: 10px;
margin-right:10px;
padding:0;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 1.5%; /* Adjust top offset */
left: 10%; /* Adjust left offset */
height: 98vh; /* Full viewport height */
}}
[class*="st-key-center-box"] {{
background-color: transparent;
border-radius: 60px;
width: 100%;
margin-top:30px;
top:20% !important; /* Adjust top offset */
left: 1%; /* Adjust left offset */
}}
[class*="st-key-side"] {{
background-color: transparent;
border-radius: 60px;
box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5);
width: 5%;
/* margin-top: 100px;*/
margin-bottom: 10px;
margin-right:10px;
padding:30px;
display: flex;
justify-content: center;
align-items: center;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 17%; /* Adjust top offset */
left: 16%; /* Adjust left offset */
height:50vh; /* Full viewport height */
}}
[class*="st-key-button_"] .stButton p > img {{
max-width: 100%;
vertical-align: top;
height:130px !important;
object-fit: cover;
padding: 10px;
width:250px !important;
border-radius:10px !important;
max-height: 2em !important;
}}
div.stButton > button {{
background: rgba(255, 255, 255, 0.2);
color: orange !important; /* White text */
font-family: "Times New Roman " !important; /* Font */
font-size: 18px !important; /* Font size */
font-weight: bold !important; /* Bold text */
padding: 1px 2px; /* Padding for buttons */
border: none; /* Remove border */
border-radius: 5px; /* Rounded corners */
box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.2); /* Shadow effect */
transition: all 0.3s ease-in-out; /* Smooth transition */
display: flex;
align-items: left;
justify-content: left;
margin-left:-15px ;
width:200px;
height:50px;
backdrop-filter: blur(10px);
z-index:1000;
text-align: left; /* Align text to the left */
padding-left: 50px;
}}
div.stButton > button p{{
color: white !important; /* White text */
}}
/* Hover effect */
div.stButton > button:hover {{
background: rgba(255, 255, 255, 0.2);
box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.4); /* Enhanced shadow on hover */
transform: scale(1.05); /* Slightly enlarge button */
transform: scale(1.1); /* Slight zoom on hover */
box-shadow: 0px 4px 12px rgba(255, 255, 255, 0.4); /* Glow effect */
}}
div.stButton > button:active {{
background: rgba(199, 107, 26, 0.5);
box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.4); /* Enhanced shadow on hover */
}}
div.stDownloadButton > button:active,
div.stDownloadButton > button:focus {{
background-color: transparent !important; /* or set it to the original background color */
outline: none; /* Remove the focus outline if you want */
}}
[class*="st-key-button_"] .stButton p > img {{
max-width: 100%;
vertical-align: top;
height:130px !important;
object-fit: cover;
padding: 10px;
width:250px !important;
border-radius:10px !important;
max-height: 2em !important;
}}
div.stDownloadButton > button > div > p {{
font-size:15px !important;
font-weight:bold;
}}
[class*="st-key-button_"] .stButton p{{
font-family: "Times New Roman " !important; /* Font */
font-size:100px !important;
height:150px !important;
box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2);
font-weight: bold;
margin-top:5px;
margin-left:5px;
color:black;
border-radius:10px;
}}
[class*="st-key-button_"]:hover {{
}}
[class*="st-key-nav-"] .stButton p{{
font-family: "Times New Roman " !important; /* Font */
font-size:1rem !important;
font-weight: bold;
}}
[class*="st-key-nav-10"]{{
border: none; /* Remove border */
background: transparent !important;
backdrop-filter: blur(10px) !important;
border-radius:80px !important;
width:180px !important;
height:100px !important;
margin-top:35px !important;
}}
[class*="st-key-nav-6"]{{
border: none; /* Remove border */
background: transparent !important;
border-radius:80px !important;
backdrop-filter: blur(10px) !important;
border-radius:80px !important;
width:180px !important;
margin-top:35px !important;
}}
[class*="st-key-nav-6"] {{
border: none; /* Remove border */
background: transparent !important;
border-radius:80px !important;
backdrop-filter: blur(10px) !important;
border-radius:80px !important;
width:190px !important;
margin-top:35px !important;
}}
[class*="st-key-nav-12"],[class*="st-key-blur_"]{{
border: none; /* Remove border */
background: transparent !important;
border-radius:80px !important;
backdrop-filter: blur(10px) !important;
border-radius:80px !important;
width:180px !important;
margin-top:35px !important;
}}
[class*="st-key-nav-8"]{{
border: none; /* Remove border */
background: transparent !important;
border-radius:80px !important;
backdrop-filter: blur(10px) !important;
border-radius:80px !important;
width:300px !important;
height:80px !important;
margin-top:35px !important;
}}
[class*="st-key-nav-5"]{{
border: none; /* Remove border */
background: transparent !important;
border-radius:80px !important;
backdrop-filter: blur(10px) !important;
border-radius:80px !important;
width:200px !important;
height:80px !important;
margin-top:35px !important;
}}
[class*="st-key-nav-"],[class*="st-key-blur_"] {{
background: rgba(255, 255, 255, 0.2);
color: black; /* White text */
font-family: "Times New Roman " !important; /* Font */
font-size: 18px !important; /* Font size */
font-weight: bold !important; /* Bold text */
padding: 10px 20px; /* Padding for buttons */
border: none; /* Remove border */
border-radius: 15px; /* Rounded corners */
box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.2); /* Shadow effect */
transition: all 0.3s ease-in-out; /* Smooth transition */
display: flex;
align-items: center;
justify-content: center;
margin: 10px 0;
width:170px;
height:60px;
backdrop-filter: blur(10px);
}}
/* Hover effect */
[class*="st-key-nav-"]:hover {{
background: rgba(255, 255, 255, 0.2);
box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.4); /* Enhanced shadow on hover */
transform: scale(1.05); /* Slightly enlarge button */
transform: scale(1.1); /* Slight zoom on hover */
box-shadow: 0px 4px 12px rgba(255, 255, 255, 0.4); /* Glow effect */
}}
/* Title styling */
.title {{
font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1.2rem;
font-weight: bold;
margin-left: 37px;
margin-top:10px;
margin-bottom:-100px;
padding: 0;
color: #333; /* Neutral color for text */
}}
.content-container {{
background: rgba(255, 255, 255, 0.05);
backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
width: 28%;
margin-left: 150px;
/* margin-top: -60px;*/
margin-bottom: 10px;
margin-right:10px;
padding:0;
/* border-radius:0px 0px 15px 15px ;*/
border:1px solid transparent;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 10%; /* Adjust top offset */
left: 60%; /* Adjust left offset */
height: 89.5vh; /* Full viewport height */
}}
.content-container-principal img{{
margin-top:260px;
margin-left:30px;
}}
div[data-testid="stText"] {{
background-color: transparent;
backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
width: 132% !important;
background-color: rgba(173, 216, 230, 0.1); /* Light blue with 50% transparency */
margin-top: -36px;
margin-bottom: 10px;
margin-left:-220px !important;
padding:50px;
padding-bottom:20px;
padding-top:50px;
/* border-radius:0px 0px 15px 15px ;*/
border:1px solid transparent;
overflow-y: auto; /* Enable vertical scrolling for the content */
height: 85vh !important; /* 85% of viewport height */
}}
.content-container2 {{
background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
width: 90%;
margin-left: 10px;
/* margin-top: -10px;*/
margin-bottom: 160px;
margin-right:10px;
padding:0;
border-radius:1px ;
border:1px solid transparent;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 3%; /* Adjust top offset */
left: 2.5%; /* Adjust left offset */
height: 78vh; /* Full viewport height */
}}
.content-container4 {{
background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ width: 40%;
margin-left: 10px;
margin-bottom: 160px;
margin-right:10px;
padding:0;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 60%; /* Adjust top offset */
left: 2.5%; /* Adjust left offset */
height: 10vh; /* Full viewport height */
}}
.content-container4 h3 ,p {{
font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1rem;
font-weight: bold;
text-align:center;
}}
.content-container5 h3 ,p {{
font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1rem;
font-weight: bold;
text-align:center;
}}
.content-container6 h3 ,p {{
font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1rem;
font-weight: bold;
text-align:center;
}}
.content-container7 h3 ,p {{
font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1rem;
font-weight: bold;
text-align:center;
}}
.content-container5 {{
background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ width: 40%;
margin-left: 180px;
margin-bottom: 130px;
margin-right:10px;
padding:0;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 60%; /* Adjust top offset */
left: 5.5%; /* Adjust left offset */
height: 10vh; /* Full viewport height */
}}
.content-container3 {{
background-color: rgba(216, 216, 230, 0.5); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
width: 92%;
margin-left: 10px;
/* margin-top: -10px;*/
margin-bottom: 160px;
margin-right:10px;
padding:0;
border: 10px solid white;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 3%; /* Adjust top offset */
left: 1.5%; /* Adjust left offset */
height: 40vh; /* Full viewport height */
}}
.content-container6 {{
background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ width: 40%;
margin-left: 10px;
margin-bottom: 160px;
margin-right:10px;
padding:0;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 80%; /* Adjust top offset */
left: 2.5%; /* Adjust left offset */
height: 10vh; /* Full viewport height */
}}
.content-container7 {{
background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ width: 40%;
margin-left: 180px;
margin-bottom: 130px;
margin-right:10px;
padding:0;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 80%; /* Adjust top offset */
left: 5.5%; /* Adjust left offset */
height: 10vh; /* Full viewport height */
}}
.content-container2 img {{
width:99%;
height:50%;
}}
.content-container3 img {{
width:100%;
height:100%;
}}
.side_box{{
width: 200px;
height: 180px;
background-color: #0175C2;
margin: 5px;
border-radius:20px;
left:-5%;
}}
.titles{{
margin-top:20px !important;
margin-left: -150px !important;
}}
/* Title styling */
.titles h1{{
/*font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 2.2rem;
/*font-weight: bold;*/
margin-left: 0px;
margin-top:80px;
margin-bottom:30px;
padding: 0;
color: black; /* Neutral color for text */
}}
.titles > div{{
font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1.2rem;
margin-left: 200px;
margin-bottom:1px;
padding: 0;
color:black; /* Neutral color for text */
}}
</style>
<div class="logo-text-container">
<img src="data:image/png;base64,{encoded_logo}" alt="Logo">
</div>
""", unsafe_allow_html=True
)
loading_html = """
<style>
.loader {
border: 8px solid #f3f3f3;
border-top: 8px solid #0175C2; /* Blue color */
border-radius: 50%;
width: 50px;
height: 50px;
animation: spin 1s linear infinite;
margin: auto;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
</style>
<div class="loader"></div>
"""
# Sidebar content
st.markdown(
"""
<style>
.sidebar-desc {
font-family: "Times New Roman" !important; /* Elegant font for title */ font-size: 14px;
color: #333;
background-color: transparent;
padding: 15px;
border-radius: 20px;
box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
width:200px !important;
margin-left:-15px;
margin-top:-50px;
height:70vh;
}
.sidebar-desc h3 {
font-family: "Times New Roman" !important; /* Elegant font for title */ font-size: 14px;
font-size: 18px;
color: #0175C2; /* Light Blue */
margin-bottom: 10px;
}
.sidebar-desc h4 {
font-size: 16px;
color: #444;
margin-bottom: 5px;
font-family: "Times New Roman" !important; /* Elegant font for title */ font-size: 14px;
}
.sidebar-desc ul {
list-style-type: square;
margin: 0;
padding-left: 20px;
}
.sidebar-desc ul li {
margin-bottom: 5px;
}
.sidebar-desc a {
color: #0175C2;
text-decoration: none;
}
.sidebar-desc a:hover {
text-decoration: underline;
}
</style>
""",
unsafe_allow_html=True,
)
# Use radio buttons for navigation
# Set the page to "Home"
page = "Home"
selected_img =""
st.session_state.page = "Home"
# Display content based on the selected page
if st.session_state.page == "Home":
# Sidebar buttons
with st.sidebar:
if st.button("📄 Model Summary"):
st.session_state.menu ="1" # Store state
st.rerun()
# Add your model description logic here
if st.button("📊 Model Results Analysis",key="header"):
st.session_state.menu ="2"
st.rerun()
# Add model analysis logic here
if st.button("🧪 Model Testing"):
st.session_state.menu ="3"
st.rerun()
table_style = """
<style>
table {
width: 100%;
border-collapse: collapse;
border-radius: 2px;
overflow: hidden;
box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1);
background: rgba(255, 255, 255, 0.05);
backdrop-filter: blur(10px);
font-family: "Times New Roman", serif;
margin-left:100px;
margin-top:30px;
}
thead {
background: rgba(255, 255, 255, 0.2);
}
th {
padding: 12px;
text-align: left;
font-weight: bold;
backdrop-filter: blur(10px);
}
td {
padding: 12px;
border-bottom: 1px solid rgba(255, 255, 255, 0.1);
}
tr:hover {
background-color: rgba(255, 255, 255, 0.1);
}
tbody {
display: block;
max-height: 680px; /* Set the fixed height */
overflow-y: auto;
width: 100%;
}
thead, tbody tr {
display: table;
width: 100%;
table-layout: fixed;
}
</style>
"""
print(test_loader)
with st.container(key="content_1"):
print(type(model)) # Should print <class 'CustomVGG16'> and not OrderedDict
if st.session_state.show_summary:
# Load the model
layers_data = get_layers_data(model) # Get layer information
# Convert to HTML table
table_html = "<table><tr><th>Layer Name</th><th>Type</th><th>Output Shape</th><th>Param #</th></tr>"
for name, layer_type, shape, params in layers_data:
table_html += f"<tr><td>{name}</td><td>{layer_type}</td><td>{shape}</td><td>{params}</td></tr>"
table_html += "</table>"
st.markdown(table_style + table_html, unsafe_allow_html=True)
if st.session_state.show_arch:
model.eval()
# Initialize lists to store true labels and predicted labels
y_true = []
y_pred = []
for image, label in test_dataset: # test_dataset is an instance of ImageFolder or similar
image = image.unsqueeze(0) # Add batch dimension (the model runs on CPU, so no device transfer is needed)
with torch.no_grad():
output = model(image) # Get model output
_, predicted = torch.max(output, 1) # Get predicted class
y_true.append(label) # Append true label
y_pred.append(predicted.item()) # Append predicted label
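# A minimal sketch (not wired in) of how class-1 probabilities could also be collected in this loop for
# the ROC / precision-recall plots below, assuming a `y_scores = []` list were initialized next to y_true:
#     probs = torch.softmax(output, dim=1)
#     y_scores.append(probs[0, 1].item())
# Passing y_scores instead of hard 0/1 predictions to roc_curve/precision_recall_curve yields smoother curves.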
# Generate the classification report
report_dict = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
# Convert to DataFrame for better readability
report_df = pd.DataFrame(report_dict).transpose().round(2)
accuracy = report_dict["accuracy"]
precision = report_df.loc["weighted avg", "precision"]
recall = report_df.loc["weighted avg", "recall"]
f1_score = report_df.loc["weighted avg", "f1-score"]
st.markdown("""
<style>
.kpi-container {
display: flex;
justify-content: space-between;
margin-bottom: 20px;
margin-left:100px;
margin-top:70px;
}
.kpi-card {
width: 23%;
padding: 15px;
text-align: center;
border-radius: 10px;
font-size: 22px;
font-weight: bold;
font-family: "Times New Roman " !important; /* Font */
color: #333;
background: rgba(255, 255, 255, 0.05);
box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
border: 5px solid rgba(173, 216, 230, 0.4);
}
</style>
<div class="kpi-container">
<div class="kpi-card">Precision<br>""" + f"{precision:.2f}" + """</div>
<div class="kpi-card">Recall<br>""" + f"{recall:.2f}" + """</div>
<div class="kpi-card">Accuracy<br>""" + f"{accuracy:.3f}" + """</div>
<div class="kpi-card">F1-Score<br>""" + f"{f1_score:.3f}" + """</div>
</div>
""", unsafe_allow_html=True)
# Remove last rows (accuracy/macro avg/weighted avg) and reset index
report_df = report_df.iloc[:-3].reset_index()
report_df.rename(columns={"index": "Class"}, inplace=True)
# Custom CSS for Table Styling
st.markdown("""
<style>
.report-container {
max-height: 250px;
overflow-y: auto;
border-radius: 25px;
text-align:center;
border: 5px solid rgba(173, 216, 230, 0.4);
padding: 10px;
background: rgba(255, 255, 255, 0.05);
box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
width:480px;
margin-left:100px;
margin-top:-20px;
}
.report-container h4{
font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1rem;
margin-left: 5px;
margin-bottom:1px;
padding: 10px;
color:#333;
}
.report-table {
width: 100%;
border-collapse: collapse;
font-family: 'Times New Roman', serif;
text-align: center;
}
.report-table th {
background: rgba(255, 255, 255, 0.05);
font-size: 16px;
padding: 10px;
border-bottom: 2px solid #444;
}
.report-table td {
font-size: 12px;
padding: 10px;
border-bottom: 1px solid #ddd;
}
</style>
""", unsafe_allow_html=True)
col1,col2 = st.columns([3,3])
with col1:
# Convert DataFrame to HTML Table
report_html = report_df.to_html(index=False, classes="report-table", escape=False)
st.markdown(f'<div class="report-container"><h4>classification report </h4>{report_html}</div>', unsafe_allow_html=True)
# Generate Confusion Matrix (note: the name `cm` shadows the matplotlib.cm module imported above)
cm = confusion_matrix(y_true, y_pred)
# Create Confusion Matrix Heatmap
fig, ax = plt.subplots(figsize=(1, 1))
fig.patch.set_alpha(0) # Make figure background transparent
# Seaborn Heatmap (Confusion Matrix)
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
xticklabels=class_names, yticklabels=class_names,
linewidths=1, linecolor="black",
cbar=False, square=True, alpha=0.9,
annot_kws={"size": 5, "family": "Times New Roman"})
# Change font for tick labels
for text in ax.texts:
text.set_bbox(dict(facecolor='none', edgecolor='none', alpha=0))
plt.xticks(fontsize=4, family="Times New Roman") # X-axis font
plt.yticks(fontsize=4, family="Times New Roman") # Y-axis font
# Enhance Labels and Title
plt.title("Confusion Matrix", fontsize=5, family="Times New Roman",color="black", loc='center')
# Apply transparent background and double border (via Streamlit Markdown)
st.markdown("""
<style>
div[data-testid="stImageContainer"] {
max-height: 250px;
overflow-y: auto;
border-radius: 25px;
text-align:center;
border: 5px solid rgba(173, 216, 230, 0.4);
padding: 10px;
background: rgba(255, 255, 255, 0.05);
box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
width:480px !important;
margin-left:100px;
margin-top:-20px;
}
div[data-testid="stImageContainer"] img{
margin-top:-10px !important;
width:400px !important;
height:250px !important;
}
[class*="st-key-roc"] div[data-testid="stImageContainer"] {
max-height: 250px;
overflow-y: auto;
border-radius: 25px;
text-align:center;
border: 5px solid rgba(173, 216, 230, 0.4);
background: rgba(255, 255, 255, 0.05);
box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
width:480px;
margin-left:-35px;
margin-top:-15px;
}
[class*="st-key-roc"] div[data-testid="stImageContainer"] img{
width:480px !important;
height:200px !important;
margin-top:-20px !important;
}
[class*="st-key-precision"] div[data-testid="stImageContainer"] {
max-height: 250px;
overflow-y: auto;
border-radius: 25px;
text-align:center;
border: 5px solid rgba(173, 216, 230, 0.4);
background: rgba(255, 255, 255, 0.05);
box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
width:480px;
margin-left:-35px;
margin-top:-5px;
}
[class*="st-key-precision"] div[data-testid="stImageContainer"] img{
width:480px !important;
height:250px !important;
margin-top:-20px !important;
}
</style>
""", unsafe_allow_html=True)
# Show Plot in Streamlit inside a styled container
st.markdown('<div class="confusion-matrix-container">', unsafe_allow_html=True)
st.pyplot(fig)
st.markdown("</div>", unsafe_allow_html=True)
with col2:
# Convert the lists to numpy arrays
# y_true = np.concatenate(y_true)
y_pred_probs = np.array(y_pred) # 1D array of hard predicted labels (0/1), not softmax probabilities
y_true = np.array(y_true) # Ensure y_true is a numpy array
print(y_pred)
print(y_true)
# Binarize the true labels for multi-class classification
y_true_bin = label_binarize(y_true, classes=np.arange(len(class_names)))
# Initialize dictionaries for storing ROC curve data
# Calculating ROC curve and AUC for each class
fpr, tpr, roc_auc = {}, {}, {}
# Calculate ROC curve and AUC for the positive class (class 1)
fpr[0], tpr[0], _ = roc_curve(y_true_bin, y_pred_probs) # Note: these scores are hard 0/1 predictions, not probabilities
roc_auc[0] = auc(fpr[0], tpr[0]) # Calculate AUC for class 1
# Plotting the ROC curve for each class
plt.figure(figsize=(11, 9))
# Plot ROC curve for the positive class (class 1)
plt.plot(fpr[0], tpr[0], lw=2, label=f'Class 1 (AUC = {roc_auc[0]:.2f})')
# Plot random guess line (diagonal line)
plt.plot([0, 1], [0, 1], color='navy', lw=5, linestyle='--')
# Labels and legend
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=28, family="Times New Roman")
plt.ylabel('True Positive Rate', fontsize=28, family="Times New Roman")
plt.title('ROC Curve (One-vs-Rest) for Each Class', fontsize=30, family="Times New Roman", color="black", loc='center', pad=3)
plt.legend(loc='lower right', fontsize=18)
# Save the plot as an image file
plt.savefig('roc_curve.png', transparent=True)
plt.close()
# Display the ROC curve in Streamlit
with st.container(key="roc"):
st.image('roc_curve.png')
with st.container(key="precision"):
# Compute Precision-Recall curve
precision, recall, _ = precision_recall_curve(y_true_bin, y_pred_probs)
# Calculate AUC for Precision-Recall curve
pr_auc = auc(recall, precision)
# Plot Precision-Recall curve
plt.figure(figsize=(11, 9))
plt.plot(recall, precision, lw=2, label=f'Precision-Recall curve (AUC = {pr_auc:.2f})')
plt.xlabel('Recall', fontsize=28,family="Times New Roman")
plt.ylabel('Precision', fontsize=28,family="Times New Roman")
plt.title('Precision-Recall Curve for Each Class', fontsize=30, family="Times New Roman",color="black", loc='center',pad=3)
plt.legend(loc='lower left', fontsize=18)
plt.grid(True, linestyle='--', alpha=0.7)
plt.savefig('precision_recall_curve.png', transparent=True)
plt.close()
st.image('precision_recall_curve.png')
if st.session_state.show_desc:
# components.html(html_string) # JavaScript works
# st.markdown(html_string, unsafe_allow_html=True)
image_path = "new.jpg"
st.container()
st.markdown(
f"""
<div class="titles">
<h1>Brain Tummor Classfication</br> Using Transfer learning</h1>
<div> This web application utilizes transfer learning to classify kidney ultrasound images</br>
into two categories: HEALTH and TUMOR Class.
Built with Streamlit and powered by </br>a Pytorch transfer learning
model based on <strong>VGG16</strong>
the app provides </br>a simple and efficient way for users
to upload brain scans and receive instant predictions.</br> The model analyzes the image
and classifies it based </br>on learned patterns, offering a confidence score for better interpretation.
</div>
</div>
""",
unsafe_allow_html=True,
)
uploaded_file = st.file_uploader(
"Choose a file", type=["png", "jpg", "jpeg"], key="upload-btn"
)
if uploaded_file is not None:
images = Image.open(uploaded_file)
# Rewind file pointer to the beginning
uploaded_file.seek(0)
file_content = uploaded_file.read() # Read file once
# Convert to base64 for HTML display
encoded_image = base64.b64encode(file_content).decode()
# Read and process image (rewind again, since read() above consumed the buffer)
uploaded_file.seek(0)
pil_image = Image.open(uploaded_file).convert("RGB").resize((224, 224))
img_array = np.array(pil_image)
class0, class1,prediction = predict_image(images)
max_index = int(np.argmax(prediction[0]))
print(f"max index:{max_index}")
max_score = prediction[0][max_index]
predicted_class = np.argmax(prediction[0])
print(f"predicted class is :{predicted_class}")
cams = generate_gradcam(pil_image, predicted_class)
heatmap = cm.jet(cams)[..., :3]
heatmap = (heatmap * 255).astype(np.uint8)
overlayed_image = cv2.addWeighted(img_array, 0.6, heatmap, 0.4, 0)
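# cv2.addWeighted blends the original image (60%) with the colored heatmap (40%); both arrays must have
# the same shape and uint8 dtype, which is why the heatmap is scaled to 0-255 above.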
# Convert to PIL
overlayed_pil = Image.fromarray(overlayed_image)
# Convert to base64
orig_b64 = convert_image_to_base64(pil_image)
overlay_b64 = convert_image_to_base64(overlayed_pil)
highlight_class = "highlight" # Special class for the highest confidence score
content = f"""
<div class="content-container">
<!-- Title -->
<!-- Recently Viewed Section -->
<div class="content-container3">
<img src="data:image/png;base64,{orig_b64}" alt="Uploaded Image">
</div>
<div class="content-container3">
<img src="data:image/png;base64,{overlay_b64}" class="result-image">
</div>
<div class="content-container4 {'highlight' if max_index == 0 else ''}">
<h3>{class_labels[0]}</h3>
<p>T Score: {class0 :.2f}</p>
</div>
<div class="content-container5 {'highlight' if max_index == 1 else ''}">
<h3> {class_labels[1]}</h3>
<p>T Score: {class1 :.2f}</p>
</div>
"""
# Close the gallery and content div
# Render the content
placeholder = st.empty() # Create a placeholder
placeholder.markdown(loading_html, unsafe_allow_html=True)
time.sleep(5) # Wait for 5 seconds
placeholder.empty()
st.markdown(content, unsafe_allow_html=True)
else:
default_image_path = "new.jpg"
with open(default_image_path, "rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode()
st.markdown(
f"""
<div class="content-container">
<!-- Title -->
<!-- Recently Viewed Section -->
<div class="content-container3">
<img src="data:image/png;base64,{encoded_image}" alt="Default Image">
</div>
</div>
""",
unsafe_allow_html=True,
)