import cv2
import streamlit as st
st.set_page_config(layout="wide")
import streamlit.components.v1 as components
import time
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from PIL import Image
from tf_keras_vis.gradcam import Gradcam
from io import BytesIO
from sklearn.metrics import classification_report,confusion_matrix, roc_curve, auc,precision_recall_curve, average_precision_score
from sklearn.preprocessing import label_binarize
import seaborn as sns
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
import torch.nn.functional as F
import base64
import os
from gradcam import GradCAM # local GradCAM implementation used for the PyTorch model
if "model" not in st.session_state:
st.session_state.model = tf.keras.models.load_model(
"best_model.h5"
)
if "framework" not in st.session_state:
st.session_state.framework = "Tensorflow"
if "menu" not in st.session_state:
st.session_state.menu = "1"
if st.session_state.menu =="1":
st.session_state.show_summary = True
st.session_state.show_arch = False
st.session_state.show_desc = False
elif st.session_state.menu =="2":
st.session_state.show_arch = True
st.session_state.show_summary = False
st.session_state.show_desc = False
elif st.session_state.menu =="3":
st.session_state.show_arch = False
st.session_state.show_summary = False
st.session_state.show_desc = True
else:
st.session_state.show_summary = False
st.session_state.show_arch = False
st.session_state.show_desc = True
# ****************************************/
# GRAD CAM
# *********************************************#
if st.session_state.framework == "TensorFlow":
gradcam = Gradcam(st.session_state.model, model_modifier=None, clone=False)
def generate_gradcam(pil_image, target_class):
# Convert PIL to array and preprocess
img_array = np.array(pil_image)
img_preprocessed = tf.keras.applications.vgg16.preprocess_input(img_array.copy())
img_tensor = tf.expand_dims(img_preprocessed, axis=0)
# Generate heatmap
loss = lambda output: tf.reduce_mean(output[:, target_class])
cam = gradcam(loss, img_tensor, penultimate_layer=-1)
# Post-process the heatmap: squeeze, clip negatives, resize, scale to [0, 1]
if cam.ndim > 2:
cam = cam.squeeze()
cam = np.maximum(cam, 0)
cam = cv2.resize(cam, (224, 224))
cam = cam / cam.max() if cam.max() > 0 else cam
return cam
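# Usage sketch (illustrative, not executed here): for an uploaded scan,
#   cam = generate_gradcam(pil_image, predicted_class)
# returns a (224, 224) float heatmap in [0, 1]; the prediction pages below
# colorize it with matplotlib's cm.jet and blend it over the input via cv2.addWeighted.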
if st.session_state.framework == "PyTorch":
target_layer = st.session_state.model.conv3 # Grad-CAM target layer (note conv4 is the final conv block)
#gradcam = GradCAM(st.session_state.model, target_layer)
def preprocess_image(image):
preprocess = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor()
])
return preprocess(image).unsqueeze(0) # Add batch dimension
def generate_gradcams(image, target_class):
# Preprocess the image and convert it to a tensor
input_image = preprocess_image(image)
# Instantiate GradCAM
gradcampy = GradCAM(st.session_state.model, target_layer)
# Generate the CAM
cam = gradcampy.generate(input_image, target_class)
return cam
def convert_image_to_base64(pil_image):
buffered = BytesIO()
pil_image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()
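# Usage sketch (the filename is a placeholder):
#   convert_image_to_base64(Image.open("scan.png").convert("RGB"))
# returns a string that drops straight into an
# <img src="data:image/png;base64,..."> tag, which is how every image in the
# HTML fragments below is embedded.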
# -------------------------------------------------
# PyTorch model definition
class KidneyCNN(nn.Module):
def __init__(self, num_classes=4):
super(KidneyCNN, self).__init__()
# Convolutional layers
self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
# Batch normalization layers
self.bn1 = nn.BatchNorm2d(32)
self.bn2 = nn.BatchNorm2d(64)
self.bn3 = nn.BatchNorm2d(128)
self.bn4 = nn.BatchNorm2d(256)
# Max pooling layers
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
# Fully connected layers
self.fc1 = nn.Linear(256 * 14 * 14, 512)
self.fc2 = nn.Linear(512, num_classes)
# Dropout for regularization
self.dropout = nn.Dropout(0.5)
def forward(self, x):
# Conv block 1
x = self.pool(F.relu(self.bn1(self.conv1(x))))
# Conv block 2
x = self.pool(F.relu(self.bn2(self.conv2(x))))
# Conv block 3
x = self.pool(F.relu(self.bn3(self.conv3(x))))
# Conv block 4
x = self.pool(F.relu(self.bn4(self.conv4(x))))
x = x.view(x.size(0), -1)
# Fully connected layers
x = self.dropout(F.relu(self.fc1(x)))
x = self.fc2(x)
return x
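# A minimal shape sanity check (illustrative helper; defined but never called by
# the app). With a 224x224 input, each conv+pool block halves the spatial size
# (224 -> 112 -> 56 -> 28 -> 14), so the flattened vector is 256 * 14 * 14,
# matching fc1 above, and the head emits one logit per class.
def _kidney_cnn_shape_check():
    dummy = torch.randn(1, 3, 224, 224)  # a fake batch with one RGB image
    with torch.no_grad():
        logits = KidneyCNN(num_classes=4)(dummy)
    assert logits.shape == (1, 4), logits.shape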
if st.session_state.framework =="PyTorch":
st.session_state.model = torch.load('kidney_model .pth', map_location=torch.device('cpu'))
st.session_state.model.eval()
print(type(st.session_state.model))
#*********************************************
# /#*********************************************/
# LOADING TEST DATASET
# *************************************************
if st.session_state.framework == "TensorFlow":
test_dir = "test"
BATCH_SIZE = 32
IMG_SIZE = (224, 224)
test_dataset = tf.keras.utils.image_dataset_from_directory(
test_dir, shuffle=False, batch_size=BATCH_SIZE, image_size=IMG_SIZE
)
class_names = test_dataset.class_names
def one_hot_encode(image, label):
label = tf.one_hot(label, num_classes)
return image, label
class_labels = class_names
num_classes = len(class_names)
test_dataset = test_dataset.map(one_hot_encode)
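# For reference: tf.one_hot turns an integer class index into a one-hot vector,
# e.g. with num_classes = 4, label 2 ("Stone") becomes [0., 0., 1., 0.];
# the evaluation code below undoes this with np.argmax to recover class indices.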
elif st.session_state.framework == "PyTorch":
test_dir = "test"
BATCH_SIZE = 32
IMG_SIZE = (224, 224)
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
test_dataset = datasets.ImageFolder(root='test', transform=transform)
class_names = test_dataset.classes
class_labels = class_names
num_classes = len(class_names)
#######################################################
# --------------------------------------------------#
class_labels = ["Cyst", "Normal", "Stone", "Tumor"]
def load_tensorflow_model():
tf_model = tf.keras.models.load_model("best_model.h5")
return tf_model
if st.session_state.framework =="TensorFlow":
def predict_image(image):
time.sleep(2)
image = image.resize((224, 224))
image = np.expand_dims(image, axis=0)
predictions = st.session_state.model.predict(image)
return predictions
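# Note (assumption): predict_image feeds raw [0, 255] pixels to the network,
# while generate_gradcam above applies vgg16.preprocess_input first. This is
# consistent only if best_model.h5 embeds its own preprocessing (e.g. a
# Rescaling layer); otherwise the two paths see differently scaled inputs.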
if st.session_state.framework == "PyTorch":
logo_path = "pytorch.png"
bg_color = "#FF5733" # For example, a warm red/orange
bg_color_iv = "orange" # For example, a warm red/orange
model = "TENSORFLOW"
def predict_image(image):
# Preprocess the image to match the model's training input format
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # ImageNet normalization stats
])
image = transform(image).unsqueeze(0) # Add batch dimension
# Set the model to evaluation mode and run a forward pass
st.session_state.model.eval()
with torch.no_grad(): # Disable gradient calculation
outputs = st.session_state.model(image)
# Softmax over the four class logits gives the prediction probabilities
probs = F.softmax(outputs, dim=1)
prob_class_0 = probs[0, 0].item()
prob_class_1 = probs[0, 1].item()
print("Raw model output (logits):", outputs)
return prob_class_0, prob_class_1, probs
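# Usage sketch: class0, class1, probs = predict_image(pil_image); probs has
# shape (1, num_classes), so probs[0][i] is the softmax score for
# class_labels[i], which is exactly how the PyTorch prediction page below
# indexes it.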
else:
logo_path = "tensorflow.png"
bg_color = "orange" # For example, a warm red/orange
bg_color_iv = "#FF5733" # For example, a warm red/orange
model = "PYTORCH"
# *********************************************#
# PYTORCH MODEL SUMMARY HELPERS
# *********************************************#
def get_layers_data(model, prefix=""):
layers_data = []
for name, layer in model.named_children(): # Iterate over layers
full_name = f"{prefix}.{name}" if prefix else name # Track hierarchy
try:
shape = str(list(layer.parameters())[0].shape) # Get shape of the first param
except Exception:
shape = "N/A"
param_count = sum(p.numel() for p in layer.parameters()) # Count parameters
layers_data.append((full_name, layer.__class__.__name__, shape, f"{param_count:,}"))
# Recursively get layers inside this layer (for nested structures)
layers_data.extend(get_layers_data(layer, full_name))
return layers_data
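# Usage sketch: get_layers_data(st.session_state.model) walks named_children()
# recursively and returns rows such as
#   ("conv1", "Conv2d", "torch.Size([32, 3, 3, 3])", "896")
# (896 = 32*3*3*3 weights + 32 biases), which the summary table below renders as HTML.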
###########################################
main_bg_ext = "png"
main_bg = "bg1.jpg"
# Read and encode the logo image
with open(logo_path, "rb") as image_file:
encoded_logo = base64.b64encode(image_file.read()).decode()
# Custom CSS to style the logo above the sidebar
st.markdown(
f"""
<style>
/* Container for logo and text */
.logo-text-container {{
position: fixed;
top: 20px; /* Adjust vertical position */
left: 30px; /* Align with sidebar */
display: flex;
align-items: center;
gap: 5px;
width: 70%;
z-index:1000;
}}
/* Logo styling */
.logo-text-container img {{
width: 50px; /* Adjust logo size */
border-radius: 10px; /* Optional: round edges */
margin-top:-10px;
margin-left:-5px;
}}
/* Bold text styling */
.logo-text-container h1 {{
font-family: Nunito;
color: #0175C2;
font-size: 28px;
font-weight: bold;
margin-right :100px;
padding:0px;
}}
.logo-text-container i{{
font-family: Nunito;
color: {bg_color};
font-size: 15px;
margin-right :10px;
padding:0px;
margin-left:-18.5%;
margin-top:1%;
}}
/* Sidebar styling */
section[data-testid="stSidebar"][aria-expanded="true"] {{
margin-top: 100px !important; /* Space for the logo */
border-radius: 0 60px 0px 60px !important; /* Rounded top-right and bottom-left corners */
width: 200px !important; /* Sidebar width */
background:none; /* Gradient background */
/* box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2); /* Shadow effect */
/* border: 1px solid #FFD700; /* Shiny golden border */
margin-bottom: 1px !important;
color:white !important;
}}
[class*="st-key-header"]{{
}}
header[data-testid="stHeader"] {{
/*background: transparent !important;*/
background: rgba(255, 255, 255, 0.05);
backdrop-filter: blur(10px);
/*margin-right: 10px !important;*/
margin-top: 0.5px !important;
z-index: 1 !important;
color: orange; /* White text */
font-family: "Times New Roman " !important; /* Font */
font-size: 18px !important; /* Font size */
font-weight: bold !important; /* Bold text */
padding: 10px 20px; /* Padding for buttons */
border: none; /* Remove border */
border-radius: 1px; /* Rounded corners */
box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.2); /* Shadow effect */
transition: all 0.3s ease-in-out; /* Smooth transition */
align-items: left;
justify-content: center;
/*margin: 10px 0;*/
width:100%;
height:80px;
backdrop-filter: blur(10px);
border: 2px solid rgba(255, 255, 255, 0.4); /* Light border */
}}
div[data-testid="stDecoration"]{{
background-image:none;
}}
div[data-testid="stApp"]{{
/*background: grey;*/
background: rgba(255, 255, 255, 0.5); /* Semi-transparent white background */
height: 100vh; /* Full viewport height */
width: 99.5%;
border-radius: 2px !important;
margin-left:5px;
margin-right:5px;
margin-top:0px;
/* box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.2); /* Shadow effect */
background: url(data:image/{main_bg_ext};base64,{base64.b64encode(open(main_bg, "rb").read()).decode()});
background-size: cover; /* Ensure the image covers the full page */
background-position: center;
overflow: hidden;
}}
.content-container {{
background: rgba(255, 255, 255, 0.05);
backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
width: 28%;
margin-left: 150px;
/* margin-top: -60px;*/
margin-bottom: 10px;
margin-right:10px;
padding:0;
/* border-radius:0px 0px 15px 15px ;*/
border:1px solid transparent;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 10%; /* Adjust top offset */
left: 60%; /* Adjust left offset */
height: 89.5vh; /* Full viewport height */
}}
.content-container-principal img{{
margin-top:260px;
margin-left:30px;
}}
.content-container-principal
{{
background-color: rgba(173, 216, 230, 0.5); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
width: 20%;
/* margin-top: -60px;*/
margin-bottom: 10px;
margin-right:10px;
margin:10px;
/* border-radius:0px 0px 15px 15px ;*/
border:1px solid transparent;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 7%; /* Adjust top offset */
/*left: 2%; Adjust left offset */
height: 84vh; /* Full viewport height */
}}
.content-container-principal-in
{{
background-color: rgba(173, 216, 230, 0.1); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
width: 100%;
/* margin-top: -60px;*/
margin:1px;
/* border-radius:0px 0px 15px 15px ;*/
border:1px solid transparent;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
height: 100.5vh; /* Full viewport height */
left:0%;
top:5%;
}}
div[data-testid="stText"] {{
background-color: transparent;
backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
width: 132% !important;
background-color: rgba(173, 216, 230, 0.1); /* Light blue with 50% transparency */
margin-top: -36px;
margin-bottom: 10px;
margin-left:-220px !important;
padding:50px;
padding-bottom:20px;
padding-top:50px;
/* border-radius:0px 0px 15px 15px ;*/
border:1px solid transparent;
overflow-y: auto; /* Enable vertical scrolling for the content */
height: 85vh !important; /* Full viewport height */
}}
.content-container2 {{
background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
width: 90%;
margin-left: 10px;
/* margin-top: -10px;*/
margin-bottom: 160px;
margin-right:10px;
padding:0;
border-radius:1px ;
border:1px solid transparent;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 3%; /* Adjust top offset */
left: 2.5%; /* Adjust left offset */
height: 78vh; /* Full viewport height */
}}
.content-container4 {{
background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ width: 40%;
margin-left: 10px;
margin-bottom: 160px;
margin-right:10px;
padding:0;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 60%; /* Adjust top offset */
left: 2.5%; /* Adjust left offset */
height: 10vh; /* Full viewport height */
}}
.content-container4 h3 ,p {{
font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1rem;
font-weight: bold;
text-align:center;
}}
.content-container5 h3 ,p {{
font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1rem;
font-weight: bold;
text-align:center;
}}
.content-container6 h3 ,p {{
font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1rem;
font-weight: bold;
text-align:center;
}}
.content-container7 h3 ,p {{
font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1rem;
font-weight: bold;
text-align:center;
}}
.content-container5 {{
background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ width: 40%;
margin-left: 180px;
margin-bottom: 130px;
margin-right:10px;
padding:0;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 60%; /* Adjust top offset */
left: 5.5%; /* Adjust left offset */
height: 10vh; /* Full viewport height */
}}
.content-container3 {{
background-color: rgba(216, 216, 230, 0.5); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
width: 92%;
margin-left: 10px;
/* margin-top: -10px;*/
margin-bottom: 160px;
margin-right:10px;
padding:0;
border: 10px solid white;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 3%; /* Adjust top offset */
left: 1.5%; /* Adjust left offset */
height: 40vh; /* Full viewport height */
}}
.content-container6 {{
background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ width: 40%;
margin-left: 10px;
margin-bottom: 160px;
margin-right:10px;
padding:0;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 80%; /* Adjust top offset */
left: 2.5%; /* Adjust left offset */
height: 10vh; /* Full viewport height */
}}
.content-container7 {{
background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ width: 40%;
margin-left: 180px;
margin-bottom: 130px;
margin-right:10px;
padding:0;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 80%; /* Adjust top offset */
left: 5.5%; /* Adjust left offset */
height: 10vh; /* Full viewport height */
}}
.content-container2 img {{
width:99%;
height:50%;
}}
.content-container3 img {{
width:100%;
height:100%;
}}
div.stButton > button {{
background: rgba(255, 255, 255, 0.2);
color: orange !important; /* White text */
font-family: "Times New Roman " !important; /* Font */
font-size: 18px !important; /* Font size */
font-weight: bold !important; /* Bold text */
padding: 1px 2px; /* Padding for buttons */
border: none; /* Remove border */
border-radius: 5px; /* Rounded corners */
box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.2); /* Shadow effect */
transition: all 0.3s ease-in-out; /* Smooth transition */
display: flex;
align-items: left;
justify-content: left;
margin-left:-50px ;
width:250px;
height:50px;
backdrop-filter: blur(10px);
z-index:1000;
text-align: left; /* Align text to the left */
padding-left: 50px;
}}
div.stButton > button p{{
color: {bg_color} !important; /* White text */
}}
/* Hover effect */
div.stButton > button:hover {{
background: rgba(255, 255, 255, 0.2);
box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.4); /* Enhanced shadow on hover */
transform: scale(1.05); /* Slightly enlarge button */
transform: scale(1.1); /* Slight zoom on hover */
box-shadow: 0px 4px 12px rgba(255, 255, 255, 0.4); /* Glow effect */
}}
div.stButton > button:active {{
background: rgba(199, 107, 26, 0.5);
box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.4); /* Enhanced shadow on hover */
}}
.titles{{
margin-top:20px !important;
margin-left: -150px !important;
}}
/* Title styling */
.titles h1{{
/*font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1.9rem;
/*font-weight: bold;*/
margin-left: 5px;
/* margin-top:-50px;*/
margin-bottom:50px;
padding: 0;
color: black; /* Neutral color for text */
}}
.titles > div{{
font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1.01rem;
margin-left: -50px;
margin-bottom:1px;
padding: 0;
color:black; /* Neutral color for text */
}}
/* Recently viewed section */
.recently-viewed {{
display: flex;
align-items: center;
justify-content: flex-start; /* Align items to the extreme left */
margin-bottom: 10px;
margin-top: 20px;
gap: 10px; /* Add spacing between the elements */
padding-left: 20px; /* Add some padding if needed */
margin-left:35px;
height:100px;
}}
/* Style for the upload button */
[class*="st-key-upload-btn"] {{
position: absolute;
top: 100%; /* Position from the top of the inner circle */
left: -26%; /* Position horizontally at the center */
padding: 10px 20px;
color: red;
border: none;
border-radius: 20px;
cursor: pointer;
font-size: 35px !important;
width:30px;
height:20px;
}}
.upload-btn:hover {{
background-color: rgba(0, 123, 255, 1);
}}
div[data-testid="stFileUploader"] label > div > p {{
display:none;
color:white !important;
}}
section[data-testid="stFileUploaderDropzone"] {{
width:200px;
height: 60px;
background-color: white;
border-radius: 40px;
display: flex;
justify-content: center;
align-items: center;
margin-top:-10px;
box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.3);
margin:20px;
background-color: rgba(255, 255, 255, 0.7); /* Transparent blue background */
color:white;
}}
div[data-testid="stFileUploaderDropzoneInstructions"] div > small{{
color:white !important;
display:none;
}}
div[data-testid="stFileUploaderDropzoneInstructions"] span{{
margin-left:65px;
color:{bg_color};
}}
div[data-testid="stFileUploaderDropzoneInstructions"] div{{
display:none;
}}
section[data-testid="stFileUploaderDropzone"] button{{
display:none;
}}
div[data-testid="stMarkdownContainer"] p {{
font-family: "Times New Roman" !important; /* Elegant font for title */
color:white !important;
}}
.highlight {{
border: 4px solid lime;
font-weight: bold;
background: radial-gradient(circle, rgba(0,255,0,0.3) 0%, rgba(0,0,0,0) 70%);
box-shadow: 0px 0px 30px 10px rgba(0, 255, 0, 0.9),
0px 0px 60px 20px rgba(0, 255, 0, 0.6),
inset 0px 0px 15px rgba(0, 255, 0, 0.8);
transition: all 0.3s ease-in-out;
}}
.highlight:hover {{
transform: scale(1.05);
background: radial-gradient(circle, rgba(0,255,0,0.6) 0%, rgba(0,0,0,0) 80%);
box-shadow: 0px 0px 40px 15px rgba(0, 255, 0, 1),
0px 0px 70px 30px rgba(0, 255, 0, 0.7),
inset 0px 0px 20px rgba(0, 255, 0, 1);
}}
.stCheckbox > label > div{{
width:303px !important;
height:3rem;
margin-top:270px;
margin-left:-72px;
border-radius:1px !important;
}}
.st-b1 {{
width:1.75rem;
height:1.75rem;
display:none;
}}
.stCheckbox > label > div:after {{
content: "SWITCH TO {model} MODEL";
display: block;
font-family: "Times New Roman", serif;
margin-top: 0.5em;
margin-left:20px;
font-weight:bold;
}}
.st-bj{{
display:none;
}}
.stCheckbox label{{
height:0px;
}}
.stCheckbox > label > div {{
background:{bg_color_iv} !important;
}}
</style>
<div class="logo-text-container">
<img src="data:image/png;base64,{encoded_logo}" alt="Logo">
<h1>KidneyScan AI<br>
</h1>
<i>Empowering Early Diagnosis with AI</i>
</div>
""",
unsafe_allow_html=True,
)
loading_html = """
<style>
.loader {
border: 8px solid #f3f3f3;
border-top: 8px solid #0175C2; /* Blue color */
border-radius: 50%;
width: 50px;
height: 50px;
animation: spin 1s linear infinite;
margin: auto;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
</style>
<div class="loader"></div>
"""
# Sidebar content
# Use radio buttons for navigation
page = "pome"
# Sidebar buttons
# Display content based on the selected page
# Define the page content dynamically
if page == "Home":
# components.html(html_string) # JavaScript works
# st.markdown(html_string, unsafe_allow_html=True)
image_path = "image.jpg"
st.container()
st.markdown(
f"""
<div class="titles">
<h1>Kidney Disease Classification<br> Using Transfer Learning</h1>
<div> This web application uses deep learning to classify kidney ultrasound images<br>
into four categories: Normal, Cyst, Tumor, and Stone.
Built with Streamlit and powered by <br>a TensorFlow transfer-learning
model based on <strong>VGG16</strong>,
the app provides a simple and efficient way for users <br>
to upload kidney scans and receive instant predictions. The model analyzes the image
and classifies it based <br>on learned patterns, offering a confidence score for better interpretation.
</div>
</div>
""",
unsafe_allow_html=True,
)
uploaded_file = st.file_uploader(
"Choose a file", type=["png", "jpg", "jpeg"], key="upload-btn"
)
if uploaded_file is not None:
images = Image.open(uploaded_file)
# Rewind file pointer to the beginning
uploaded_file.seek(0)
file_content = uploaded_file.read() # Read file once
# Convert to base64 for HTML display
encoded_image = base64.b64encode(file_content).decode()
# Rewind again before re-opening, since read() left the pointer at EOF
uploaded_file.seek(0)
pil_image = Image.open(uploaded_file).convert("RGB").resize((224, 224))
img_array = np.array(pil_image)
prediction = predict_image(images)
max_index = int(np.argmax(prediction[0]))
print(f"max index:{max_index}")
max_score = prediction[0][max_index]
predicted_class = np.argmax(prediction[0])
highlight_class = "highlight" # Special class for the highest confidence score
# Generate Grad-CAM
cam = generate_gradcam(pil_image, predicted_class)
# Create overlay
heatmap = cm.jet(cam)[..., :3]
heatmap = (heatmap * 255).astype(np.uint8)
overlayed_image = cv2.addWeighted(img_array, 0.6, heatmap, 0.4, 0)
# Convert to PIL
overlayed_pil = Image.fromarray(overlayed_image)
# Convert to base64
orig_b64 = convert_image_to_base64(pil_image)
overlay_b64 = convert_image_to_base64(overlayed_pil)
content = f"""
<div class="content-container">
<!-- Title -->
<!-- Recently Viewed Section -->
<div class="content-container2">
<div class="content-container3">
<img src="data:image/png;base64,{orig_b64}" alt="Uploaded Image">
</div>
<div class="content-container3">
<img src="data:image/png;base64,{overlay_b64}" class="result-image">
</div>
<div class="content-container4 {'highlight' if max_index == 0 else ''}">
<h3>{class_labels[0]}</h3>
<p>T Score: {prediction[0][0]:.2f}</p>
</div>
<div class="content-container5 {'highlight' if max_index == 1 else ''}">
<h3> {class_labels[1]}</h3>
<p>T Score: {prediction[0][1]:.2f}</p>
</div>
<div class="content-container6 {'highlight' if max_index == 2 else ''}">
<h3> {class_labels[2]}</h3>
<p>T Score: {prediction[0][2]:.2f}</p>
</div>
<div class="content-container7 {'highlight' if max_index == 3 else ''}">
<h3>{class_labels[3]}</h3>
<p>T Score: {prediction[0][3]:.2f}</p>
</div>
"""
# Close the gallery and content div
# Render the content
placeholder = st.empty() # Create a placeholder
placeholder.markdown(loading_html, unsafe_allow_html=True)
time.sleep(5) # Wait for 5 seconds
placeholder.empty()
st.markdown(content, unsafe_allow_html=True)
else:
default_image_path = "image.jpg"
with open(default_image_path, "rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode()
st.markdown(
f"""
<div class="content-container">
<!-- Title -->
<!-- Recently Viewed Section -->
<div class="content-container2">
<div class="content-container3">
<img src="data:image/png;base64,{encoded_image}" alt="Default Image">
</div>
</div>
""",
unsafe_allow_html=True,
)
if page == "pome":
gif_path = "bg3.gif"
with open(gif_path, "rb") as image_file:
encode_image = base64.b64encode(image_file.read()).decode()
st.markdown(
f"""
<div class="content-container-principal-in">
<div class="content-container-principal">
<img src="data:image/png;base64,{encode_image}" alt="Default Image">
</div>
</div>
""",
unsafe_allow_html=True,
)
col1, col2 = st.columns([1, 2]) # Adjust column widths
with col1:
if st.button("📄 Model Summary"):
st.session_state.menu ="1" # Store state
st.rerun()
# Add your model description logic here
if st.button("📊 Model Results Analysis",key="header"):
st.session_state.menu ="2"
st.rerun()
# Add model analysis logic here
if st.button("🧪 Model Testing"):
st.session_state.menu ="3"
st.rerun()
# Toggle switch UI
def framework_toggle():
toggle = st.toggle("Enable PyTorch", value=(st.session_state.framework == "PyTorch"))
if toggle and st.session_state.framework != "PyTorch":
st.session_state.framework = "PyTorch"
st.session_state.model = torch.load('kidney_model .pth', map_location=torch.device('cpu'))
st.rerun()
elif not toggle and st.session_state.framework != "TensorFlow":
st.session_state.framework = "TensorFlow"
st.session_state.model = tf.keras.models.load_model(
"best_model.h5"
)
st.rerun()
print(st.session_state.framework)
framework_toggle()
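# Flipping the toggle swaps both the session model and the theme variables on
# the next rerun. Note the PyTorch checkpoint is loaded with torch.load on the
# whole object, so the file must contain a full pickled model (with KidneyCNN
# importable), not just a state_dict.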
# Custom CSS for table styling
table_style = """
<style>
table {
width: 110%;
border-collapse: collapse;
border-radius: 2px;
overflow: hidden;
box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.4);
background: rgba(255, 255, 255, 0.05);
backdrop-filter: blur(10px);
font-family: "Times New Roman", serif;
margin-left:-100px;
margin-top:10px;
}
thead {
background: rgba(255, 255, 255, 0.2);
}
th {
padding: 12px;
text-align: left;
font-weight: bold;
backdrop-filter: blur(10px);
}
td {
padding: 12px;
border-bottom: 1px solid rgba(255, 255, 255, 0.1);
}
tr:hover {
background-color: rgba(255, 255, 255, 0.1);
}
tbody {
display: block;
max-height: 580px; /* Set the fixed height */
overflow-y: auto;
width: 100%;
}
thead, tbody tr {
display: table;
width: 100%;
table-layout: fixed;
}
</style>
"""
with col2:
if st.session_state.show_summary:
layers_data = []
print(st.session_state)
if st.session_state.framework == "TensorFlow":
for layer in st.session_state.model.layers:
try:
# layer.output.shape directly; the original {...} wrapper built a one-element set
shape = str(layer.output.shape)
except Exception:
shape = "N/A"
param_count = f"{layer.count_params():,}"
layers_data.append(
(layer.name, layer.__class__.__name__, shape, param_count)
)
print(layers_data)
elif st.session_state.framework == "PyTorch":
layers_data = get_layers_data(st.session_state.model) # Get layer information
# Convert to HTML table
table_html = "<table><tr><th>Layer Name</th><th>Type</th><th>Output Shape</th><th>Param #</th></tr>"
for name, layer_type, shape, params in layers_data:
table_html += f"<tr><td>{name}</td><td>{layer_type}</td><td>{shape}</td><td>{params}</td></tr>"
table_html += "</table>"
# Render table with custom styling
st.markdown(table_style + table_html, unsafe_allow_html=True)
if st.session_state.show_arch:
if st.session_state.framework == "TensorFlow":
y_true = np.concatenate([y.numpy() for _, y in test_dataset])
# Get model predictions
y_pred_probs = st.session_state.model.predict(test_dataset)
y_pred = np.argmax(y_pred_probs, axis=1)
# Convert one-hot true labels to class indices
y_true = np.argmax(y_true, axis=1)
# Class names (modify for your dataset)
class_names = ["Cyst", "Normal", "Stone", "Tumor"]
# Generate classification report as a dictionary
report_dict = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
# Convert to DataFrame
report_df = pd.DataFrame(report_dict).transpose().round(2)
accuracy = report_dict["accuracy"]
precision = report_df.loc["weighted avg", "precision"]
recall = report_df.loc["weighted avg", "recall"]
f1_score = report_df.loc["weighted avg", "f1-score"]
elif st.session_state.framework == "PyTorch":
y_true = []
y_pred = []
for image, label in test_dataset: # iterate the ImageFolder samples one by one
image = image.unsqueeze(0) # Add batch dimension
with torch.no_grad():
output = st.session_state.model(image) # Get model output
_, predicted = torch.max(output, 1) # Get predicted class index
y_true.append(label) # Append true label
y_pred.append(predicted.item()) # Append predicted label
# Generate the classification report
report_dict = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
# Convert to DataFrame for better readability
report_df = pd.DataFrame(report_dict).transpose().round(2)
accuracy = report_dict["accuracy"]
precision = report_df.loc["weighted avg", "precision"]
recall = report_df.loc["weighted avg", "recall"]
f1_score = report_df.loc["weighted avg", "f1-score"]
st.markdown("""
<style>
.kpi-container {
display: flex;
justify-content: space-between;
margin-bottom: 20px;
margin-left:-80px;
margin-top:-30px;
}
.kpi-card {
width: 23%;
padding: 15px;
text-align: center;
border-radius: 10px;
font-size: 22px;
font-weight: bold;
font-family: "Times New Roman " !important; /* Font */
color: #333;
background: rgba(255, 255, 255, 0.05);
box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
border: 5px solid rgba(173, 216, 230, 0.4);
}
</style>
<div class="kpi-container">
<div class="kpi-card">Precision<br>""" + f"{precision:.2f}" + """</div>
<div class="kpi-card">Recall<br>""" + f"{recall:.2f}" + """</div>
<div class="kpi-card">Accuracy<br>""" + f"{accuracy:.2f}" + """</div>
<div class="kpi-card">F1-Score<br>""" + f"{f1_score:.2f}" + """</div>
</div>
""", unsafe_allow_html=True)
# Remove last rows (accuracy/macro avg/weighted avg) and reset index
report_df = report_df.iloc[:-3].reset_index()
report_df.rename(columns={"index": "Class"}, inplace=True)
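# report_df now holds one row per class (precision / recall / f1-score / support);
# the trailing accuracy, macro-avg and weighted-avg rows were dropped above
# because they are already surfaced in the KPI cards.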
# Custom CSS for Table Styling
st.markdown("""
<style>
.report-container {
max-height: 250px;
overflow-y: auto;
border-radius: 25px;
text-align:center;
border: 5px solid rgba(173, 216, 230, 0.4);
padding: 10px;
background: rgba(255, 255, 255, 0.05);
box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
width:480px;
margin-left:-80px;
margin-top:-20px;
}
.report-container h4{
font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1rem;
margin-left: 5px;
margin-bottom:1px;
padding: 10px;
color:#333;
}
.report-table {
width: 100%;
border-collapse: collapse;
font-family: 'Times New Roman', serif;
text-align: center;
}
.report-table th {
background: rgba(255, 255, 255, 0.05);
font-size: 16px;
padding: 10px;
border-bottom: 2px solid #444;
}
.report-table td {
font-size: 12px;
padding: 10px;
border-bottom: 1px solid #ddd;
}
</style>
""", unsafe_allow_html=True)
col1,col2 = st.columns([3,3])
with col1:
# Convert DataFrame to HTML Table
report_html = report_df.to_html(index=False, classes="report-table", escape=False)
st.markdown(f'<div class="report-container"><h4>classification report </h4>{report_html}</div>', unsafe_allow_html=True)
# Generate the confusion matrix (named conf_mat so it does not shadow matplotlib.cm, imported as cm)
conf_mat = confusion_matrix(y_true, y_pred)
# Create Confusion Matrix Heatmap
fig, ax = plt.subplots(figsize=(1, 1))
fig.patch.set_alpha(0) # Make figure background transparent
# Seaborn Heatmap (Confusion Matrix)
sns.heatmap(conf_mat, annot=True, fmt="d", cmap="Blues",
xticklabels=class_names, yticklabels=class_names,
linewidths=1, linecolor="black",
cbar=False, square=True, alpha=0.9,
annot_kws={"size": 5, "family": "Times New Roman"})
# Change font for tick labels
for text in ax.texts:
text.set_bbox(dict(facecolor='none', edgecolor='none', alpha=0))
plt.xticks(fontsize=4, family="Times New Roman") # X-axis font
plt.yticks(fontsize=4, family="Times New Roman") # Y-axis font
# Enhance Labels and Title
plt.title("Confusion Matrix", fontsize=5, family="Times New Roman",color="black", loc='center')
# Apply transparent background and double border (via Streamlit Markdown)
st.markdown("""
<style>
div[data-testid="stImageContainer"] {
max-height: 250px;
overflow-y: auto;
border-radius: 25px;
text-align:center;
border: 5px solid rgba(173, 216, 230, 0.4);
padding: 10px;
background: rgba(255, 255, 255, 0.05);
box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
width:480px !important;
margin-left:-80px;
margin-top:-20px;
}
div[data-testid="stImageContainer"] img{
margin-top:-10px !important;
width:400px !important;
height:250px !important;
}
[class*="st-key-roc"] div[data-testid="stImageContainer"] {
max-height: 250px;
overflow-y: auto;
border-radius: 25px;
text-align:center;
border: 5px solid rgba(173, 216, 230, 0.4);
background: rgba(255, 255, 255, 0.05);
box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
width:480px;
margin-left:-35px;
margin-top:-15px;
}
[class*="st-key-roc"] div[data-testid="stImageContainer"] img{
width:480px !important;
height:250px !important;
margin-top:-20px !important;
}
[class*="st-key-precision"] div[data-testid="stImageContainer"] {
max-height: 250px;
overflow-y: auto;
border-radius: 25px;
text-align:center;
border: 5px solid rgba(173, 216, 230, 0.4);
background: rgba(255, 255, 255, 0.05);
box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
width:480px;
margin-left:-35px;
margin-top:-5px;
}
[class*="st-key-precision"] div[data-testid="stImageContainer"] img{
width:480px !important;
height:250px !important;
margin-top:-20px !important;
}
</style>
""", unsafe_allow_html=True)
# Show Plot in Streamlit inside a styled container
st.markdown('<div class="confusion-matrix-container">', unsafe_allow_html=True)
st.pyplot(fig)
st.markdown("</div>", unsafe_allow_html=True)
with col2:
if st.session_state.framework == "TensorFlow":
# Binarizing the true labels for multi-class classification
y_true_bin = label_binarize(y_true, classes=np.arange(len(class_names)))
# Calculating ROC curve and AUC for each class
fpr, tpr, roc_auc = {}, {}, {}
for i in range(len(class_names)):
fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_pred_probs[:, i])
roc_auc[i] = auc(fpr[i], tpr[i])
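# One-vs-rest ROC: label_binarize expands each true label into a 4-way indicator
# matrix, so roc_curve can score class i's probability column against
# "class i vs. everything else"; auc then reduces each curve to a single number.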
# Plotting ROC curve for each class
plt.figure(figsize=(11, 9))
for i in range(len(class_names)):
plt.plot(fpr[i], tpr[i], lw=2, label=f'{class_names[i]} (AUC = {roc_auc[i]:.2f})')
# Plot random guess line
plt.plot([0, 1], [0, 1], color='navy', lw=5, linestyle='--')
# Labels and legend
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate',fontsize=28,family="Times New Roman")
plt.ylabel('True Positive Rate',fontsize=28,family="Times New Roman")
plt.title('ROC Curve (One-vs-Rest) for Each Class',fontsize=30, family="Times New Roman",color="black", loc='center',pad=3)
plt.legend(loc='lower right',fontsize=18)
# Save the plot as an image
plt.savefig('roc_curve.png', transparent=True)
plt.close()
# Display the plot in Streamlit
with st.container(key="roc"):
st.image('roc_curve.png')
elif st.session_state.framework == "PyTorch":
# Display the ROC curve in Streamlit
with st.container(key="roc"):
st.image('roc-py.png')
with st.container(key="precision"):
st.image('precision_recall_curve.png')
if st.session_state.show_desc:
# components.html(html_string) # JavaScript works
# st.markdown(html_string, unsafe_allow_html=True)
image_path = "image.jpg"
st.container()
st.markdown(
f"""
<div class="titles">
<h1>Kidney Disease Classification<br> Using Deep Learning</h1>
<div> This web application uses deep learning to classify kidney ultrasound images<br>
into four categories: Normal, Cyst, Tumor, and Stone.
Built with Streamlit and powered by <br>a custom convolutional neural network
(<strong>CNN</strong>),
the app provides a simple and efficient way for users <br>
to upload kidney scans and receive instant predictions. The model analyzes the image
and classifies it based <br>on learned patterns, offering a confidence score for better interpretation.
</div>
</div>
""",
unsafe_allow_html=True,
)
uploaded_file = st.file_uploader(
"Choose a file", type=["png", "jpg", "jpeg"], key="upload-btn"
)
if uploaded_file is not None:
images = Image.open(uploaded_file)
# Rewind file pointer to the beginning
uploaded_file.seek(0)
file_content = uploaded_file.read() # Read file once
# Convert to base64 for HTML display
encoded_image = base64.b64encode(file_content).decode()
# Rewind again before re-opening, since read() left the pointer at EOF
uploaded_file.seek(0)
pil_image = Image.open(uploaded_file).convert("RGB").resize((224, 224))
img_array = np.array(pil_image)
if st.session_state.framework == "TensorFlow":
prediction = predict_image(images)
max_index = int(np.argmax(prediction[0]))
print(f"max index:{max_index}")
max_score = prediction[0][max_index]
predicted_class = np.argmax(prediction[0])
highlight_class = "highlight" # Special class for the highest confidence score
# Generate Grad-CAM
cam = generate_gradcam(pil_image, predicted_class)
# Create overlay
heatmap = cm.jet(cam)[..., :3]
heatmap = (heatmap * 255).astype(np.uint8)
overlayed_image = cv2.addWeighted(img_array, 0.6, heatmap, 0.4, 0)
# Convert to PIL
overlayed_pil = Image.fromarray(overlayed_image)
# Convert to base64
orig_b64 = convert_image_to_base64(pil_image)
overlay_b64 = convert_image_to_base64(overlayed_pil)
content = f"""
<div class="content-container">
<!-- Title -->
<!-- Recently Viewed Section -->
<div class="content-container3">
<img src="data:image/png;base64,{orig_b64}" alt="Uploaded Image">
</div>
<div class="content-container3">
<img src="data:image/png;base64,{overlay_b64}" class="result-image">
</div>
<div class="content-container4 {'highlight' if max_index == 0 else ''}">
<h3>{class_labels[0]}</h3>
<p>T Score: {prediction[0][0]:.2f}</p>
</div>
<div class="content-container5 {'highlight' if max_index == 1 else ''}">
<h3> {class_labels[1]}</h3>
<p>T Score: {prediction[0][1]:.2f}</p>
</div>
<div class="content-container6 {'highlight' if max_index == 2 else ''}">
<h3> {class_labels[2]}</h3>
<p>T Score: {prediction[0][2]:.2f}</p>
</div>
<div class="content-container7 {'highlight' if max_index == 3 else ''}">
<h3>{class_labels[3]}</h3>
<p>T Score: {prediction[0][3]:.2f}</p>
</div>
"""
elif st.session_state.framework == "PyTorch":
class0, class1, prediction = predict_image(images)
max_index = int(np.argmax(prediction[0]))
print(f"max index:{max_index}")
max_score = prediction[0][max_index]
predicted_class = np.argmax(prediction[0])
print(f"predicted class is :{predicted_class}")
# Grad-CAM for the PyTorch model is currently disabled;
# generate_gradcams(pil_image, predicted_class) above would produce the heatmap
# and could be blended over the input exactly as in the TensorFlow branch.
highlight_class = "highlight" # Special class for the highest confidence score
orig_b64 = convert_image_to_base64(pil_image)
content = f"""
<div class="content-container">
<!-- Title -->
<!-- Recently Viewed Section -->
<div class="content-container3">
<img src="data:image/png;base64,{orig_b64}" alt="Uploaded Image">
</div>
<div class="content-container4 {'highlight' if max_index == 0 else ''}">
<h3>{class_labels[0]}</h3>
<p>T Score: {prediction[0][0]:.2f}</p>
</div>
<div class="content-container5 {'highlight' if max_index == 1 else ''}">
<h3> {class_labels[1]}</h3>
<p>T Score: {prediction[0][1]:.2f}</p>
</div>
<div class="content-container6 {'highlight' if max_index == 2 else ''}">
<h3> {class_labels[2]}</h3>
<p>T Score: {prediction[0][2]:.2f}</p>
</div>
<div class="content-container7 {'highlight' if max_index == 3 else ''}">
<h3>{class_labels[3]}</h3>
<p>T Score: {prediction[0][3]:.2f}</p>
</div>
"""
# Render the content
placeholder = st.empty() # Create a placeholder
placeholder.markdown(loading_html, unsafe_allow_html=True)
time.sleep(5) # Wait for 5 seconds
placeholder.empty()
st.markdown(content, unsafe_allow_html=True)
else:
default_image_path = "image.jpg"
with open(default_image_path, "rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode()
st.markdown(
f"""
<div class="content-container">
<!-- Title -->
<!-- Recently Viewed Section -->
<div class="content-container3">
<img src="data:image/png;base64,{encoded_image}" alt="Default Image">
</div>
</div>
""",
unsafe_allow_html=True,
)