Spaces:
Sleeping
Sleeping
import cv2 | |
import streamlit as st | |
st.set_page_config(layout="wide") | |
import streamlit.components.v1 as components | |
import time | |
import numpy as np | |
import pandas as pd | |
import tensorflow as tf | |
import matplotlib.pyplot as plt | |
import matplotlib.cm as cm | |
from PIL import Image | |
from tf_keras_vis.gradcam import Gradcam | |
from io import BytesIO | |
from sklearn.metrics import classification_report,confusion_matrix, roc_curve, auc,precision_recall_curve, average_precision_score | |
from sklearn.preprocessing import label_binarize | |
import seaborn as sns | |
import torch | |
import torch.nn as nn | |
import torchvision.models as models | |
from torchvision import datasets, transforms | |
import torchvision.transforms as transforms | |
import torch.nn.functional as F | |
from gradcam import GradCAM # Import your GradCAM class | |
# Load the default (Keras) model once per browser session.
if "model" not in st.session_state:
    st.session_state.model = tf.keras.models.load_model(
        "best_model.h5"
    )
# Active framework. The spelling must be exactly "TensorFlow": every
# framework branch in this script compares against that string, so any other
# spelling matches none of them until the toggle forces a reload and rerun.
if "framework" not in st.session_state:
    st.session_state.framework = "TensorFlow"
# Menu entry chosen by the navigation buttons ("1", "2" or "3").
if "menu" not in st.session_state:
    st.session_state.menu = "1"

# Map each menu entry to the session flag that shows its panel. Exactly one
# flag ends up True; unknown menu values fall back to the description panel,
# and the other flags are explicitly cleared so none is ever left undefined.
_MENU_PANELS = {"1": "show_summary", "2": "show_arch", "3": "show_desc"}
_active_panel = _MENU_PANELS.get(st.session_state.menu, "show_desc")
for _panel_flag in _MENU_PANELS.values():
    st.session_state[_panel_flag] = _panel_flag == _active_panel
import base64 | |
import os | |
import tf_keras_vis | |
# ****************************************/ | |
# GRAD CAM | |
# *********************************************# | |
if st.session_state.framework == "TensorFlow":
    # Grad-CAM explainer bound to the Keras model currently in session state.
    gradcam = Gradcam(st.session_state.model, model_modifier=None, clone=False)

    def generate_gradcam(pil_image, target_class, size=(224, 224)):
        """Return a Grad-CAM heatmap for ``target_class`` on ``pil_image``.

        Args:
            pil_image: PIL image fed to the model after VGG16 preprocessing.
            target_class: index of the output unit whose score is explained.
            size: ``(width, height)`` of the returned heatmap (cv2 order);
                defaults to the 224x224 size used by the rest of the app.

        Returns:
            A 2-D float array of shape ``(height, width)``. Negative values
            are clipped to 0, and the map is scaled so its maximum is 1 (left
            unscaled when it is all zeros).
        """
        img_array = np.array(pil_image)
        img_preprocessed = tf.keras.applications.vgg16.preprocess_input(img_array.copy())
        img_tensor = tf.expand_dims(img_preprocessed, axis=0)

        def score(output):
            # Score explained by Grad-CAM: the target class's model output.
            return tf.reduce_mean(output[:, target_class])

        cam = gradcam(score, img_tensor, penultimate_layer=-1)
        if cam.ndim > 2:
            # Drop the batch dimension of the single-image result.
            cam = cam.squeeze()
        cam = np.maximum(cam, 0)
        cam = cv2.resize(cam, size)
        peak = cam.max()
        return cam / peak if peak > 0 else cam
if st.session_state.framework == "PyTorch":
    # Layer whose activations/gradients Grad-CAM uses.
    # NOTE(review): conv3 is the third of KidneyCNN's four conv layers; Grad-CAM
    # conventionally targets the last one (conv4). Kept as-is - confirm intent.
    target_layer = st.session_state.model.conv3

    def preprocess_image(image):
        """Turn an image into a normalized ``(1, 3, 224, 224)`` float tensor.

        Uses the same resize and ImageNet mean/std normalization as the
        PyTorch ``predict_image`` and test dataset, so the heatmap explains
        the input the model actually scores. PIL images are converted to RGB
        first so RGBA/grayscale uploads still yield three channels.
        """
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
        preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        return preprocess(image).unsqueeze(0)  # Add batch dimension

    def generate_gradcams(image, target_class):
        """Return the Grad-CAM map for ``target_class`` on ``image``.

        The map comes from the project ``GradCAM.generate``; its shape and
        scaling are defined there.
        """
        input_image = preprocess_image(image)
        # NOTE(review): a new GradCAM is built per call; if its constructor
        # registers hooks on target_layer without removing them, they accumulate.
        gradcampy = GradCAM(st.session_state.model, target_layer)
        return gradcampy.generate(input_image, target_class)
def convert_image_to_base64(pil_image):
    """Encode ``pil_image`` as PNG and return the bytes as base64 text.

    The returned string is ready to embed in an
    ``<img src="data:image/png;base64,...">`` tag.
    """
    with BytesIO() as png_buffer:
        pil_image.save(png_buffer, format="PNG")
        raw_png = png_buffer.getvalue()
    return base64.b64encode(raw_png).decode()
#------------------------------------------------- | |
#loading pytorch | |
class KidneyCNN(nn.Module):
    """Four-stage CNN classifier for 224x224 RGB kidney images.

    Every stage is 3x3 conv (padding 1) -> batch norm -> ReLU -> 2x2 max pool,
    halving the spatial size, so a 224x224 input leaves the last stage as a
    256x14x14 map. That map is flattened and passed through ``fc1`` (512
    units, ReLU, dropout 0.5) and ``fc2`` to give ``num_classes`` logits.

    Attribute names (conv1..conv4, bn1..bn4, pool, fc1, fc2, dropout) and their
    registration order are kept fixed: pickled checkpoints and state dicts
    refer to them by name.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        widths = (3, 32, 64, 128, 256)
        # All conv layers first, in order, so parameter initialisation draws
        # from the RNG in the same sequence as layer-by-layer construction.
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(
                self,
                f"conv{stage}",
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=1, padding=1),
            )
        for stage, c_out in enumerate(widths[1:], start=1):
            setattr(self, f"bn{stage}", nn.BatchNorm2d(c_out))
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(256 * 14 * 14, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Map a ``(batch, 3, 224, 224)`` tensor to ``(batch, num_classes)`` logits."""
        for stage in range(1, 5):
            conv = getattr(self, f"conv{stage}")
            norm = getattr(self, f"bn{stage}")
            x = self.pool(F.relu(norm(conv(x))))
        flat = x.view(x.size(0), -1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)
if st.session_state.framework == "PyTorch":
    # Reuse a module already held in session state (e.g. loaded by the
    # framework toggle) instead of unpickling the checkpoint on every rerun.
    if not isinstance(st.session_state.model, nn.Module):
        # The checkpoint is a fully pickled nn.Module, not a state_dict, so it
        # needs weights_only=False (torch 2.6 changed the default to True).
        # This unpickles arbitrary objects: only load a trusted local file.
        st.session_state.model = torch.load(
            "kidney_model .pth", map_location=torch.device("cpu"), weights_only=False
        )
    # Inference mode: freezes batch-norm statistics and disables dropout.
    st.session_state.model.eval()
    print(type(st.session_state.model))
#********************************************* | |
# /#*********************************************/ | |
# LOADING TEST DATASET | |
# ************************************************* | |
if st.session_state.framework == "TensorFlow":
    test_dir = "test"
    BATCH_SIZE = 32
    IMG_SIZE = (224, 224)
    # Unshuffled, so batch order matches the on-disk file order.
    test_dataset = tf.keras.utils.image_dataset_from_directory(
        test_dir, shuffle=False, batch_size=BATCH_SIZE, image_size=IMG_SIZE
    )
    # Class names come from the sub-directory names of test_dir.
    class_names = test_dataset.class_names
    class_labels = class_names
    num_classes = len(class_names)

    def one_hot_encode(image, label):
        """Replace an integer label with a one-hot vector of length num_classes."""
        return image, tf.one_hot(label, num_classes)

    test_dataset = test_dataset.map(one_hot_encode)
elif st.session_state.framework == "PyTorch":
    test_dir = "test"
    BATCH_SIZE = 32
    IMG_SIZE = (224, 224)
    # Resize + ImageNet mean/std normalization, matching PyTorch inference.
    transform = transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)
    # Class names come from the sub-directory names of test_dir.
    class_names = test_dataset.classes
    class_labels = class_names
    num_classes = len(class_names)
####################################################### | |
# --------------------------------------------------# | |
# Fixed display names for the four classes. This reassignment replaces any
# directory-derived class_labels set earlier; the result cards read this list.
class_labels = ["Cyst", "Normal", "Stone", "Tumor"]


def load_tensorflow_model():
    """Load the Keras model from ``best_model.h5`` and return it."""
    return tf.keras.models.load_model("best_model.h5")
if st.session_state.framework == "TensorFlow":
    def predict_image(image):
        """Run the Keras model on a single PIL image.

        The image is converted to RGB, because RGBA/grayscale uploads would
        otherwise give the wrong channel count. It is then resized to
        224x224 and wrapped in a batch of one.

        Returns:
            The ``model.predict`` output for that batch; callers read the
            per-class scores from row ``[0]``.
        """
        # No artificial delay here: the caller shows its own loading spinner.
        rgb = image.convert("RGB").resize((224, 224))
        batch = np.expand_dims(np.asarray(rgb), axis=0)
        return st.session_state.model.predict(batch)
if st.session_state.framework == "PyTorch":
    logo_path = "pytorch.png"
    bg_color = "#FF5733"  # Accent colour of the PyTorch theme (warm red/orange)
    bg_color_iv = "orange"  # Background of the framework-switch toggle
    # Name of the framework the toggle switches TO ("SWITCH TO ... MODEL").
    model = "TENSORFLOW"

    # Built once instead of on every call; matches the PyTorch test dataset
    # preprocessing (224x224, ImageNet mean/std normalization).
    _predict_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    def predict_image(image):
        """Classify one image with the PyTorch model in session state.

        PIL images are converted to RGB first, so RGBA/grayscale uploads still
        yield the three channels the normalization expects.

        Returns:
            ``(prob_class_0, prob_class_1, probs)``. For a single-output model,
            ``probs`` is the sigmoid output and ``prob_class_0`` is
            ``1 - prob_class_1``. Otherwise ``probs`` is the ``(1, n)`` softmax
            tensor and the first two entries are its columns 0 and 1.
            NOTE(review): the TensorFlow ``predict_image`` returns the raw
            prediction array instead; callers indexing ``[0]`` get different
            things - confirm before sharing call sites.
        """
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
        batch = _predict_transform(image).unsqueeze(0)  # Add batch dimension
        st.session_state.model.eval()
        with torch.no_grad():  # Inference only: no autograd graph
            outputs = st.session_state.model(batch)
        if outputs.shape[1] == 1:
            # Single logit: binary classifier with a sigmoid head.
            probs = torch.sigmoid(outputs)
            prob_class_1 = probs[0].item()
            prob_class_0 = 1 - prob_class_1
        else:
            probs = torch.nn.functional.softmax(outputs, dim=1)
            prob_class_0 = probs[0, 0].item()
            prob_class_1 = probs[0, 1].item()
        print("Raw model output (logits):", outputs)
        return prob_class_0, prob_class_1, probs
else:
    logo_path = "tensorflow.png"
    bg_color = "orange"  # Accent colour of the TensorFlow theme
    bg_color_iv = "#FF5733"  # Background of the framework-switch toggle
    # Name of the framework the toggle switches TO ("SWITCH TO ... MODEL").
    model = "PYTORCH"
#/*******************loading pytorch summary | |
def get_layers_data(model, prefix=""):
    """Flatten a PyTorch module tree into summary rows.

    Walks ``model.named_children()`` depth-first, emitting each child before
    its own descendants.

    Args:
        model: the ``torch.nn.Module`` to describe.
        prefix: dotted name of ``model`` itself; child names are appended to
            it (``"prefix.child"``). Empty for the root.

    Returns:
        A list of ``(name, layer_type, first_param_shape, param_count)`` string
        tuples. ``first_param_shape`` is ``str(shape)`` of the layer's first
        parameter, including nested ones, or ``"N/A"`` if it has none.
        ``param_count`` counts every parameter of the layer and its
        descendants, formatted with thousands separators.
    """
    layers_data = []
    for name, layer in model.named_children():
        full_name = f"{prefix}.{name}" if prefix else name  # Track hierarchy
        # Peek at the first parameter instead of materialising them all.
        first_param = next(layer.parameters(), None)
        shape = str(first_param.shape) if first_param is not None else "N/A"
        param_count = sum(p.numel() for p in layer.parameters())
        layers_data.append((full_name, type(layer).__name__, shape, f"{param_count:,}"))
        # Recurse so nested containers (e.g. Sequential) are listed too.
        layers_data.extend(get_layers_data(layer, full_name))
    return layers_data
########################################### | |
# Page background: main_bg is a JPEG, so its data URI must say image/jpeg.
main_bg_ext = "jpeg"
main_bg = "bg1.jpg"
# Read and encode the logo image
with open(logo_path, "rb") as image_file:
    encoded_logo = base64.b64encode(image_file.read()).decode()
# Read the background once, with a context manager so the file handle is
# closed (an inline open() inside the f-string below would leak it).
with open(main_bg, "rb") as bg_file:
    encoded_bg = base64.b64encode(bg_file.read()).decode()
# Global stylesheet plus the fixed logo/tagline header. The literal braces in
# the CSS are doubled because the string is an f-string; bg_color,
# bg_color_iv and model are interpolated from the active framework theme.
# NOTE(review): selectors such as ".content-container4 h3 ,p" also match every
# <p> on the page (the bare "p" is not scoped); confirm before narrowing them.
st.markdown(
    f"""
<style>
/* Container for logo and text */
.logo-text-container {{
position: fixed;
top: 20px; /* Adjust vertical position */
left: 30px; /* Align with sidebar */
display: flex;
align-items: center;
gap: 5px;
width: 70%;
z-index:1000;
}}
/* Logo styling */
.logo-text-container img {{
width: 50px; /* Adjust logo size */
border-radius: 10px; /* Optional: round edges */
margin-top:-10px;
margin-left:-5px;
}}
/* Bold text styling */
.logo-text-container h1 {{
font-family: Nunito;
color: #0175C2;
font-size: 28px;
font-weight: bold;
margin-right :100px;
padding:0px;
}}
.logo-text-container i{{
font-family: Nunito;
color: {bg_color};
font-size: 15px;
margin-right :10px;
padding:0px;
margin-left:-18.5%;
margin-top:1%;
}}
/* Sidebar styling */
section[data-testid="stSidebar"][aria-expanded="true"] {{
margin-top: 100px !important; /* Space for the logo */
border-radius: 0 60px 0px 60px !important; /* Top-left and bottom-right corners */
width: 200px !important; /* Sidebar width */
background:none; /* Gradient background */
/* box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2); /* Shadow effect */
/* border: 1px solid #FFD700; /* Shiny golden border */
margin-bottom: 1px !important;
color:white !important;
}}
[class*="st-key-header"]{{
}}
header[data-testid="stHeader"] {{
/*background: transparent !important;*/
background: rgba(255, 255, 255, 0.05);
backdrop-filter: blur(10px);
/*margin-right: 10px !important;*/
margin-top: 0.5px !important;
z-index: 1 !important;
color: orange; /* White text */
font-family: "Times New Roman " !important; /* Font */
font-size: 18px !important; /* Font size */
font-weight: bold !important; /* Bold text */
padding: 10px 20px; /* Padding for buttons */
border: none; /* Remove border */
border-radius: 1px; /* Rounded corners */
box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.2); /* Shadow effect */
transition: all 0.3s ease-in-out; /* Smooth transition */
align-items: left;
justify-content: center;
/*margin: 10px 0;*/
width:100%;
height:80px;
backdrop-filter: blur(10px);
border: 2px solid rgba(255, 255, 255, 0.4); /* Light border */
}}
div[data-testid="stDecoration"]{{
background-image:none;
}}
div[data-testid="stApp"]{{
/*background: grey;*/
background: rgba(255, 255, 255, 0.5); /* Semi-transparent white background */
height: 100vh; /* Full viewport height */
width: 99.5%;
border-radius: 2px !important;
margin-left:5px;
margin-right:5px;
margin-top:0px;
/* box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.2); /* Shadow effect */
background: url(data:image/{main_bg_ext};base64,{encoded_bg});
background-size: cover; /* Ensure the image covers the full page */
background-position: center;
overflow: hidden;
}}
.content-container {{
background: rgba(255, 255, 255, 0.05);
backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
width: 28%;
margin-left: 150px;
/* margin-top: -60px;*/
margin-bottom: 10px;
margin-right:10px;
padding:0;
/* border-radius:0px 0px 15px 15px ;*/
border:1px solid transparent;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 10%; /* Adjust top offset */
left: 60%; /* Adjust left offset */
height: 89.5vh; /* Full viewport height */
}}
.content-container-principal img{{
margin-top:260px;
margin-left:30px;
}}
.content-container-principal
{{
background-color: rgba(173, 216, 230, 0.5); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
width: 20%;
/* margin-top: -60px;*/
margin-bottom: 10px;
margin-right:10px;
margin:10px;
/* border-radius:0px 0px 15px 15px ;*/
border:1px solid transparent;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 7%; /* Adjust top offset */
/*left: 2%; Adjust left offset */
height: 84vh; /* Full viewport height */
}}
.content-container-principal-in
{{
background-color: rgba(173, 216, 230, 0.1); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
width: 100%;
/* margin-top: -60px;*/
margin:1px;
/* border-radius:0px 0px 15px 15px ;*/
border:1px solid transparent;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
height: 100.5vh; /* Full viewport height */
left:0%;
top:5%;
}}
div[data-testid="stText"] {{
background-color: transparent;
backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
width: 132% !important;
background-color: rgba(173, 216, 230, 0.1); /* Light blue with 50% transparency */
margin-top: -36px;
margin-bottom: 10px;
margin-left:-220px !important;
padding:50px;
padding-bottom:20px;
padding-top:50px;
/* border-radius:0px 0px 15px 15px ;*/
border:1px solid transparent;
overflow-y: auto; /* Enable vertical scrolling for the content */
height: 85vh !important; /* Full viewport height */
}}
.content-container2 {{
background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
width: 90%;
margin-left: 10px;
/* margin-top: -10px;*/
margin-bottom: 160px;
margin-right:10px;
padding:0;
border-radius:1px ;
border:1px solid transparent;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 3%; /* Adjust top offset */
left: 2.5%; /* Adjust left offset */
height: 78vh; /* Full viewport height */
}}
.content-container4 {{
background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ width: 40%;
margin-left: 10px;
margin-bottom: 160px;
margin-right:10px;
padding:0;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 60%; /* Adjust top offset */
left: 2.5%; /* Adjust left offset */
height: 10vh; /* Full viewport height */
}}
.content-container4 h3 ,p {{
font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1rem;
font-weight: bold;
text-align:center;
}}
.content-container5 h3 ,p {{
font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1rem;
font-weight: bold;
text-align:center;
}}
.content-container6 h3 ,p {{
font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1rem;
font-weight: bold;
text-align:center;
}}
.content-container7 h3 ,p {{
font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1rem;
font-weight: bold;
text-align:center;
}}
.content-container5 {{
background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ width: 40%;
margin-left: 180px;
margin-bottom: 130px;
margin-right:10px;
padding:0;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 60%; /* Adjust top offset */
left: 5.5%; /* Adjust left offset */
height: 10vh; /* Full viewport height */
}}
.content-container3 {{
background-color: rgba(216, 216, 230, 0.5); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ border-radius: 1px;
width: 92%;
margin-left: 10px;
/* margin-top: -10px;*/
margin-bottom: 160px;
margin-right:10px;
padding:0;
border: 10px solid white;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 3%; /* Adjust top offset */
left: 1.5%; /* Adjust left offset */
height: 40vh; /* Full viewport height */
}}
.content-container6 {{
background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ width: 40%;
margin-left: 10px;
margin-bottom: 160px;
margin-right:10px;
padding:0;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 80%; /* Adjust top offset */
left: 2.5%; /* Adjust left offset */
height: 10vh; /* Full viewport height */
}}
.content-container7 {{
background-color: rgba(0, 0, 0, 0.1); /* Light blue with 50% transparency */
backdrop-filter: blur(10px); /* Adds a slight blur effect */ width: 40%;
margin-left: 180px;
margin-bottom: 130px;
margin-right:10px;
padding:0;
overflow-y: auto; /* Enable vertical scrolling for the content */
position: fixed; /* Fix the position of the container */
top: 80%; /* Adjust top offset */
left: 5.5%; /* Adjust left offset */
height: 10vh; /* Full viewport height */
}}
.content-container2 img {{
width:99%;
height:50%;
}}
.content-container3 img {{
width:100%;
height:100%;
}}
div.stButton > button {{
background: rgba(255, 255, 255, 0.2);
color: orange !important; /* White text */
font-family: "Times New Roman " !important; /* Font */
font-size: 18px !important; /* Font size */
font-weight: bold !important; /* Bold text */
padding: 1px 2px; /* Padding for buttons */
border: none; /* Remove border */
border-radius: 5px; /* Rounded corners */
box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.2); /* Shadow effect */
transition: all 0.3s ease-in-out; /* Smooth transition */
display: flex;
align-items: left;
justify-content: left;
margin-left:-50px ;
width:250px;
height:50px;
backdrop-filter: blur(10px);
z-index:1000;
text-align: left; /* Align text to the left */
padding-left: 50px;
}}
div.stButton > button p{{
color: {bg_color} !important; /* White text */
}}
/* Hover effect */
div.stButton > button:hover {{
background: rgba(255, 255, 255, 0.2);
box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.4); /* Enhanced shadow on hover */
transform: scale(1.05); /* Slightly enlarge button */
transform: scale(1.1); /* Slight zoom on hover */
box-shadow: 0px 4px 12px rgba(255, 255, 255, 0.4); /* Glow effect */
}}
div.stButton > button:active {{
background: rgba(199, 107, 26, 0.5);
box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.4); /* Enhanced shadow on hover */
}}
.titles{{
margin-top:20px !important;
margin-left: -150px !important;
}}
/* Title styling */
.titles h1{{
/*font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1.9rem;
/*font-weight: bold;*/
margin-left: 5px;
/* margin-top:-50px;*/
margin-bottom:50px;
padding: 0;
color: black; /* Neutral color for text */
}}
.titles > div{{
font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1.01rem;
margin-left: -50px;
margin-bottom:1px;
padding: 0;
color:black; /* Neutral color for text */
}}
/* Recently viewed section */
.recently-viewed {{
display: flex;
align-items: center;
justify-content: flex-start; /* Align items to the extreme left */
margin-bottom: 10px;
margin-top: 20px;
gap: 10px; /* Add spacing between the elements */
padding-left: 20px; /* Add some padding if needed */
margin-left:35px;
height:100px;
}}
/* Style for the upload button */
[class*="st-key-upload-btn"] {{
position: absolute;
top: 100%; /* Position from the top of the inner circle */
left: -26%; /* Position horizontally at the center */
padding: 10px 20px;
color: red;
border: none;
border-radius: 20px;
cursor: pointer;
font-size: 35px !important;
width:30px;
height:20px;
}}
.upload-btn:hover {{
background-color: rgba(0, 123, 255, 1);
}}
div[data-testid="stFileUploader"] label > div > p {{
display:none;
color:white !important;
}}
section[data-testid="stFileUploaderDropzone"] {{
width:200px;
height: 60px;
background-color: white;
border-radius: 40px;
display: flex;
justify-content: center;
align-items: center;
margin-top:-10px;
box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.3);
margin:20px;
background-color: rgba(255, 255, 255, 0.7); /* Transparent blue background */
color:white;
}}
div[data-testid="stFileUploaderDropzoneInstructions"] div > small{{
color:white !important;
display:none;
}}
div[data-testid="stFileUploaderDropzoneInstructions"] span{{
margin-left:65px;
color:{bg_color};
}}
div[data-testid="stFileUploaderDropzoneInstructions"] div{{
display:none;
}}
section[data-testid="stFileUploaderDropzone"] button{{
display:none;
}}
div[data-testid="stMarkdownContainer"] p {{
font-family: "Times New Roman" !important; /* Elegant font for title */
color:white !important;
}}
.highlight {{
border: 4px solid lime;
font-weight: bold;
background: radial-gradient(circle, rgba(0,255,0,0.3) 0%, rgba(0,0,0,0) 70%);
box-shadow: 0px 0px 30px 10px rgba(0, 255, 0, 0.9),
0px 0px 60px 20px rgba(0, 255, 0, 0.6),
inset 0px 0px 15px rgba(0, 255, 0, 0.8);
transition: all 0.3s ease-in-out;
}}
.highlight:hover {{
transform: scale(1.05);
background: radial-gradient(circle, rgba(0,255,0,0.6) 0%, rgba(0,0,0,0) 80%);
box-shadow: 0px 0px 40px 15px rgba(0, 255, 0, 1),
0px 0px 70px 30px rgba(0, 255, 0, 0.7),
inset 0px 0px 20px rgba(0, 255, 0, 1);
}}
.stCheckbox > label > div{{
width:303px !important;
height:3rem;
margin-top:270px;
margin-left:-72px;
border-radius:1px !important;
}}
.st-b1 {{
width:1.75rem;
height:1.75rem;
display:none;
}}
.stCheckbox > label > div:after {{
content: "SWITCH TO {model} MODEL";
display: block;
font-family: "Times New Roman", serif;
margin-top: 0.5em;
margin-left:20px;
font-weight:bold;
}}
.st-bj{{
display:none;
}}
.stCheckbox label{{
height:0px;
}}
.stCheckbox > label > div {{
background:{bg_color_iv} !important;
}}
</style>
<div class="logo-text-container">
<img src="data:image/png;base64,{encoded_logo}" alt="Logo">
<h1>KidneyScan AI<br>
</h1>
<i>Empowering Early Diagnosis with AI</i>
</div>
""",
    unsafe_allow_html=True,
)
# HTML/CSS for a CSS-only loading spinner: a 50px circle with a light-grey
# border whose top edge is blue (#0175C2), rotated continuously by the
# "spin" keyframes (one turn per second). Plain string, not an f-string, so
# the CSS braces are single.
loading_html = """
<style>
.loader {
border: 8px solid #f3f3f3;
border-top: 8px solid #0175C2; /* Blue color */
border-radius: 50%;
width: 50px;
height: 50px;
animation: spin 1s linear infinite;
margin: auto;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
</style>
<div class="loader"></div>
"""
# Sidebar content
# Use radio buttons for navigation
# NOTE(review): `page` is hard-coded to "pome", so the "Home" branch below is
# currently never executed; it is kept consistent for when navigation returns.
page = "pome"
# Sidebar buttons
# Display content based on the selected page
# Define the page content dynamically

# Positional CSS classes of the four score cards, in class_labels order.
_SCORE_CARD_CLASSES = (
    "content-container4",
    "content-container5",
    "content-container6",
    "content-container7",
)

if page == "Home":
    image_path = "image.jpg"
    st.markdown(
        f"""
<div class="titles">
<h1>Kidney Disease Classification</br> Using Transfer learning</h1>
<div> This web application utilizes deep learning to classify kidney ultrasound images</br>
into four categories: Normal, Cyst, Tumor, and Stone Class.
Built with Streamlit and powered by </br>a TensorFlow transfer learning
model based on <strong>VGG16</strong>
the app provides a simple and efficient way for users </br>
to upload kidney scans and receive instant predictions. The model analyzes the image
and classifies it based </br>on learned patterns, offering a confidence score for better interpretation.
</div>
</div>
""",
        unsafe_allow_html=True,
    )
    uploaded_file = st.file_uploader(
        "Choose a file", type=["png", "jpg", "jpeg"], key="upload-btn"
    )
    if uploaded_file is not None:
        # Show the spinner while prediction and Grad-CAM actually run, rather
        # than sleeping for a fixed time after the work is already done.
        placeholder = st.empty()
        placeholder.markdown(loading_html, unsafe_allow_html=True)

        # One normalised 224x224 RGB image feeds the model, Grad-CAM and the
        # overlay alike; RGBA/grayscale uploads would otherwise mismatch the
        # 3-channel model input and the heatmap used in the overlay.
        pil_image = Image.open(uploaded_file).convert("RGB").resize((224, 224))
        img_array = np.array(pil_image)

        # Assumes the TensorFlow predict_image: per-class scores in row [0].
        prediction = predict_image(pil_image)
        scores = prediction[0]
        max_index = int(np.argmax(scores))
        print(f"max index:{max_index}")

        # Grad-CAM heatmap, coloured with the jet colormap and blended 60/40
        # with the input image.
        cam = generate_gradcam(pil_image, max_index)
        heatmap = (cm.jet(cam)[..., :3] * 255).astype(np.uint8)
        overlayed_pil = Image.fromarray(cv2.addWeighted(img_array, 0.6, heatmap, 0.4, 0))

        orig_b64 = convert_image_to_base64(pil_image)
        overlay_b64 = convert_image_to_base64(overlayed_pil)

        # One card per class; the highest-scoring class gets the glow style.
        cards = []
        for idx, (css_class, label) in enumerate(zip(_SCORE_CARD_CLASSES, class_labels)):
            highlight = "highlight" if idx == max_index else ""
            cards.append(
                f'<div class="{css_class} {highlight}">\n'
                f"<h3>{label}</h3>\n"
                f"<p>T Score: {scores[idx]:.2f}</p>\n"
                "</div>"
            )
        cards_html = "\n".join(cards)

        content = f"""
<div class="content-container">
<div class="content-container2">
<div class="content-container3">
<img src="data:image/png;base64,{orig_b64}" alt="Uploaded Image">
</div>
<div class="content-container3">
<img src="data:image/png;base64,{overlay_b64}" class="result-image" alt="Grad-CAM overlay">
</div>
{cards_html}
</div>
</div>
"""
        placeholder.empty()
        st.markdown(content, unsafe_allow_html=True)
    else:
        # No upload yet: show the default illustration.
        default_image_path = "image.jpg"
        with open(default_image_path, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode()
        st.markdown(
            f"""
<div class="content-container">
<div class="content-container2">
<div class="content-container3">
<img src="data:image/png;base64,{encoded_image}" alt="Default Image">
</div>
</div>
</div>
""",
            unsafe_allow_html=True,
        )
if page == "pome":
    # Animated illustration shown in the left panel of the main page.
    gif_path = "bg3.gif"
    with open(gif_path, "rb") as image_file:
        encode_image = base64.b64encode(image_file.read()).decode()
    # The payload is a GIF, so the data URI declares image/gif (not image/png).
    st.markdown(
        f"""
<div class="content-container-principal-in">
<div class="content-container-principal">
<img src="data:image/gif;base64,{encode_image}" alt="Default Image">
</div>
</div>
""",
        unsafe_allow_html=True,
    )
col1, col2 = st.columns([1, 2]) # Adjust column widths | |
with col1: | |
if st.button("📄 Model Summary"): | |
st.session_state.menu ="1" # Store state | |
st.rerun() | |
# Add your model description logic here | |
if st.button("📊 Model Results Analysis",key="header"): | |
st.session_state.menu ="2" | |
st.rerun() | |
# Add model analysis logic here | |
if st.button("🧪 Model Testing"): | |
st.session_state.menu ="3" | |
st.rerun() | |
# Toggle switch UI | |
def framework_toggle():
    """Render the "Enable PyTorch" switch and swap the session model on change.

    When the switch disagrees with ``st.session_state.framework``, the
    framework name is updated, the matching model is loaded into
    ``st.session_state.model`` and ``st.rerun()`` restarts the script so the
    page renders against the new model.
    """
    use_pytorch = st.toggle(
        "Enable PyTorch", value=(st.session_state.framework == "PyTorch")
    )
    current = st.session_state.framework
    if use_pytorch:
        if current != "PyTorch":
            st.session_state.framework = "PyTorch"
            # NOTE(review): torch.load unpickles the stored object; the app later
            # calls it as a module, so this is presumably a full nn.Module. On
            # torch >= 2.6 the default weights_only=True rejects that — confirm
            # the pinned torch version. The space in the filename is kept
            # as-is; verify it matches the file on disk.
            st.session_state.model = torch.load(
                'kidney_model .pth', map_location=torch.device('cpu')
            )
            st.rerun()
    elif current != "TensorFlow":
        # Also taken on the very first run: the initial state spells the name
        # "Tensorflow", so this branch normalises it (reloading the .h5 model).
        st.session_state.framework = "TensorFlow"
        st.session_state.model = tf.keras.models.load_model("best_model.h5")
        st.rerun()
    print(st.session_state.framework)


framework_toggle()
# Custom CSS for table styling
table_style = """
<style>
table {
width: 110%;
border-collapse: collapse;
border-radius: 2px;
overflow: hidden;
box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.4);
background: rgba(255, 255, 255, 0.05);
backdrop-filter: blur(10px);
font-family: "Times New Roman", serif;
margin-left:-100px;
margin-top:10px;
}
thead {
background: rgba(255, 255, 255, 0.2);
}
th {
padding: 12px;
text-align: left;
font-weight: bold;
backdrop-filter: blur(10px);
}
td {
padding: 12px;
border-bottom: 1px solid rgba(255, 255, 255, 0.1);
}
tr:hover {
background-color: rgba(255, 255, 255, 0.1);
}
tbody {
display: block;
max-height: 580px; /* Set the fixed height */
overflow-y: auto;
width: 100%;
}
thead, tbody tr {
display: table;
width: 100%;
table-layout: fixed;
}
</style>
"""


def _keras_output_shape(layer):
    """Return a printable output shape for a Keras ``layer``.

    Layers whose ``output`` is a list/tuple of tensors get every shape,
    comma separated. Any failure to read ``layer.output`` is shown as "N/A",
    matching the previous behaviour.
    """
    try:
        output = layer.output
    except Exception:
        return "N/A"
    if isinstance(output, (list, tuple)):
        return ", ".join(str(getattr(t, "shape", t)) for t in output)
    shape = getattr(output, "shape", None)
    return "N/A" if shape is None else str(shape)


with col2:
    if st.session_state.show_summary:
        if st.session_state.framework == "TensorFlow":
            # The shape used to be wrapped in a set literal ({...}), so the
            # table printed "{TensorShape([...])}" instead of the shape.
            layers_data = [
                (
                    layer.name,
                    layer.__class__.__name__,
                    _keras_output_shape(layer),
                    f"{layer.count_params():,}",
                )
                for layer in st.session_state.model.layers
            ]
        elif st.session_state.framework == "PyTorch":
            # get_layers_data is defined outside this chunk.
            layers_data = get_layers_data(st.session_state.model)
        else:
            layers_data = []

        rows_html = "".join(
            f"<tr><td>{name}</td><td>{layer_type}</td><td>{shape}</td><td>{params}</td></tr>"
            for name, layer_type, shape, params in layers_data
        )
        # Explicit <thead>/<tbody> so the stylesheet's thead / tbody rules
        # (fixed header, scrolling body) actually match the markup.
        table_html = (
            "<table><thead><tr><th>Layer Name</th><th>Type</th>"
            "<th>Output Shape</th><th>Param #</th></tr></thead>"
            f"<tbody>{rows_html}</tbody></table>"
        )
        # Render table with custom styling
        st.markdown(table_style + table_html, unsafe_allow_html=True)
if st.session_state.show_arch:
    # class_names[i] labels class index i in y_true / y_pred. Defined here,
    # outside the framework branches, because the PyTorch branch, the heatmap
    # and the ROC plot all use it (it used to be set only in the TF branch).
    class_names = ["Cyst", "Normal", "Stone", "Tumor"]
    class_ids = np.arange(len(class_names))

    if st.session_state.framework == "TensorFlow":
        # Collect labels and predictions from the SAME pass over test_dataset.
        # Reading the labels in one pass and calling model.predict(test_dataset)
        # in a second pass misaligns them if the dataset reshuffles per pass.
        label_batches, prob_batches = [], []
        for batch_images, batch_labels in test_dataset:
            label_batches.append(batch_labels.numpy())
            prob_batches.append(
                np.asarray(st.session_state.model.predict_on_batch(batch_images))
            )
        y_pred_probs = np.concatenate(prob_batches)
        y_pred = np.argmax(y_pred_probs, axis=1)
        # Labels are one-hot encoded; convert them to class indices.
        y_true = np.argmax(np.concatenate(label_batches), axis=1)
    elif st.session_state.framework == "PyTorch":
        # NOTE(review): in this branch test_dataset yields unbatched
        # (image, label) pairs (ImageFolder-style) — confirm. The model is
        # called as stored in the session (no .eval()); confirm it was saved
        # in eval mode.
        y_true, y_pred = [], []
        with torch.no_grad():
            for image, label in test_dataset:
                output = st.session_state.model(image.unsqueeze(0))  # add batch dim
                _, predicted = torch.max(output, 1)
                y_true.append(label)
                y_pred.append(predicted.item())

    # Pin the label set so the report and the confusion matrix always cover
    # all four classes, even if the test split happens to lack one of them.
    report_dict = classification_report(
        y_true, y_pred, labels=class_ids, target_names=class_names, output_dict=True
    )
    report_df = pd.DataFrame(report_dict).transpose().round(2)
    accuracy = report_dict["accuracy"]
    precision = report_df.loc["weighted avg", "precision"]
    recall = report_df.loc["weighted avg", "recall"]
    f1_score = report_df.loc["weighted avg", "f1-score"]

    kpi_cards = "\n".join(
        f'<div class="kpi-card">{title}<br>{value:.2f}</div>'
        for title, value in (
            ("Precision", precision),
            ("Recall", recall),
            ("Accuracy", accuracy),
            ("F1-Score", f1_score),
        )
    )
    st.markdown(
        """
<style>
.kpi-container {
display: flex;
justify-content: space-between;
margin-bottom: 20px;
margin-left:-80px;
margin-top:-30px;
}
.kpi-card {
width: 23%;
padding: 15px;
text-align: center;
border-radius: 10px;
font-size: 22px;
font-weight: bold;
font-family: "Times New Roman" !important; /* Font */
color: #333;
background: rgba(255, 255, 255, 0.05);
box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
border: 5px solid rgba(173, 216, 230, 0.4);
}
</style>
<div class="kpi-container">
"""
        + kpi_cards
        + "\n</div>\n",
        unsafe_allow_html=True,
    )

    # Drop the trailing accuracy / macro avg / weighted avg rows, keep per-class rows.
    report_df = report_df.iloc[:-3].reset_index()
    report_df.rename(columns={"index": "Class"}, inplace=True)

    # Custom CSS for Table Styling
    st.markdown("""
<style>
.report-container {
max-height: 250px;
overflow-y: auto;
border-radius: 25px;
text-align:center;
border: 5px solid rgba(173, 216, 230, 0.4);
padding: 10px;
background: rgba(255, 255, 255, 0.05);
box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
width:480px;
margin-left:-80px;
margin-top:-20px;
}
.report-container h4{
font-family: "Times New Roman" !important; /* Elegant font for title */
font-size: 1rem;
margin-left: 5px;
margin-bottom:1px;
padding: 10px;
color:#333;
}
.report-table {
width: 100%;
border-collapse: collapse;
font-family: 'Times New Roman', serif;
text-align: center;
}
.report-table th {
background: rgba(255, 255, 255, 0.05);
font-size: 16px;
padding: 10px;
border-bottom: 2px solid #444;
}
.report-table td {
font-size: 12px;
padding: 10px;
border-bottom: 1px solid #ddd;
}
</style>
""", unsafe_allow_html=True)

    col1, col2 = st.columns([3, 3])
    with col1:
        report_html = report_df.to_html(index=False, classes="report-table", escape=False)
        st.markdown(f'<div class="report-container"><h4>classification report </h4>{report_html}</div>', unsafe_allow_html=True)

        # Named conf_mat, not cm: "cm" is the matplotlib.cm module imported at
        # the top of the file and is used for the Grad-CAM colormap.
        conf_mat = confusion_matrix(y_true, y_pred, labels=class_ids)
        fig, ax = plt.subplots(figsize=(1, 1))
        fig.patch.set_alpha(0)  # transparent figure background
        sns.heatmap(conf_mat, ax=ax, annot=True, fmt="d", cmap="Blues",
                    xticklabels=class_names, yticklabels=class_names,
                    linewidths=1, linecolor="black",
                    cbar=False, square=True, alpha=0.9,
                    annot_kws={"size": 5, "family": "Times New Roman"})
        for text in ax.texts:
            text.set_bbox(dict(facecolor='none', edgecolor='none', alpha=0))
        # Style the tick labels on this axes explicitly instead of via the
        # global "current axes" pyplot state.
        for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
            tick_label.set_fontsize(4)
            tick_label.set_family("Times New Roman")
        ax.set_title("Confusion Matrix", fontsize=5, family="Times New Roman", color="black", loc='center')

        # Transparent background and border for every image container.
        st.markdown("""
<style>
div[data-testid="stImageContainer"] {
max-height: 250px;
overflow-y: auto;
border-radius: 25px;
text-align:center;
border: 5px solid rgba(173, 216, 230, 0.4);
padding: 10px;
background: rgba(255, 255, 255, 0.05);
box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
width:480px !important;
margin-left:-80px;
margin-top:-20px;
}
div[data-testid="stImageContainer"] img{
margin-top:-10px !important;
width:400px !important;
height:250px !important;
}
[class*="st-key-roc"] div[data-testid="stImageContainer"] {
max-height: 250px;
overflow-y: auto;
border-radius: 25px;
text-align:center;
border: 5px solid rgba(173, 216, 230, 0.4);
background: rgba(255, 255, 255, 0.05);
box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
width:480px;
margin-left:-35px;
margin-top:-15px;
}
[class*="st-key-roc"] div[data-testid="stImageContainer"] img{
width:480px !important;
height:250px !important;
margin-top:-20px !important;
}
[class*="st-key-precision"] div[data-testid="stImageContainer"] {
max-height: 250px;
overflow-y: auto;
border-radius: 25px;
text-align:center;
border: 5px solid rgba(173, 216, 230, 0.4);
background: rgba(255, 255, 255, 0.05);
box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.4);
width:480px;
margin-left:-35px;
margin-top:-5px;
}
[class*="st-key-precision"] div[data-testid="stImageContainer"] img{
width:480px !important;
height:250px !important;
margin-top:-20px !important;
}
</style>
""", unsafe_allow_html=True)
        # NOTE(review): each st.markdown call is its own element, so these
        # open/close divs do not actually wrap the plot; kept for layout parity.
        st.markdown('<div class="confusion-matrix-container">', unsafe_allow_html=True)
        st.pyplot(fig)
        st.markdown("</div>", unsafe_allow_html=True)
        plt.close(fig)  # pyplot keeps every figure alive until it is closed

    with col2:
        if st.session_state.framework == "TensorFlow":
            # One-vs-rest ROC curve per class from the predicted probabilities.
            y_true_bin = label_binarize(y_true, classes=class_ids)
            roc_fig, roc_ax = plt.subplots(figsize=(11, 9))
            for i, name in enumerate(class_names):
                fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_pred_probs[:, i])
                roc_ax.plot(fpr, tpr, lw=2, label=f'{name} (AUC = {auc(fpr, tpr):.2f})')
            roc_ax.plot([0, 1], [0, 1], color='navy', lw=5, linestyle='--')  # random guess
            roc_ax.set_xlim([0.0, 1.0])
            roc_ax.set_ylim([0.0, 1.05])
            roc_ax.set_xlabel('False Positive Rate', fontsize=28, family="Times New Roman")
            roc_ax.set_ylabel('True Positive Rate', fontsize=28, family="Times New Roman")
            roc_ax.set_title('ROC Curve (One-vs-Rest) for Each Class', fontsize=30, family="Times New Roman", color="black", loc='center', pad=3)
            roc_ax.legend(loc='lower right', fontsize=18)
            # Render in memory rather than to a shared roc_curve.png, which
            # concurrent sessions could overwrite while another is reading it.
            roc_png = BytesIO()
            roc_fig.savefig(roc_png, format="png", transparent=True)
            plt.close(roc_fig)
            roc_png.seek(0)
            with st.container(key="roc"):
                st.image(roc_png)
        elif st.session_state.framework == "PyTorch":
            # Pre-rendered ROC curve for the PyTorch model.
            with st.container(key="roc"):
                st.image('roc-py.png')
        with st.container(key="precision"):
            st.image('precision_recall_curve.png')
def _prediction_html(image_tags, prediction, max_index):
    """Return the markup for the prediction results panel.

    image_tags: ``<img>`` tags, each wrapped in its own content-container3 tile.
    prediction: per-class scores, read as ``prediction[0][i]`` for class ``i``.
    max_index: index of the card that receives the ``highlight`` CSS class.
    Card titles come from the module-level ``class_labels``.
    """
    parts = ['<div class="content-container">']
    for tag in image_tags:
        parts += ['<div class="content-container3">', tag, "</div>"]
    # One card per class, using the container classes content-container4..7
    # exactly as the original hand-written markup did.
    for idx in range(4):
        highlight = "highlight" if idx == max_index else ""
        parts += [
            f'<div class="content-container{4 + idx} {highlight}">',
            f"<h3>{class_labels[idx]}</h3>",
            f"<p>T Score: {prediction[0][idx]:.2f}</p>",
            "</div>",
        ]
    parts.append("</div>")  # the outer container used to be left unclosed
    # No blank lines: a blank line would end the markdown HTML block early.
    return "\n".join(parts)


if st.session_state.show_desc:
    image_path = "image.jpg"
    # NOTE(review): creates an empty container and discards it; kept so the
    # rendered element sequence (and thus spacing) is unchanged.
    st.container()
    st.markdown(
        """
<div class="titles">
<h1>Kidney Disease Classification<br> Using Deep Learning</h1>
<div> This web application utilizes deep learning to classify kidney ultrasound images<br>
into four categories: Normal, Cyst, Tumor, and Stone.
Built with Streamlit and powered by <br>a TensorFlow transfer learning
model based on a <strong>CNN</strong>,
the app provides a simple and efficient way for users <br>
to upload kidney scans and receive instant predictions. The model analyzes the image
and classifies it based <br>on learned patterns, offering a confidence score for better interpretation.
</div>
</div>
""",
        unsafe_allow_html=True,
    )
    uploaded_file = st.file_uploader(
        "Choose a file", type=["png", "jpg", "jpeg"], key="upload-btn"
    )
    if uploaded_file is not None:
        # Decode the upload once; the 224x224 RGB copy feeds Grad-CAM and display.
        images = Image.open(uploaded_file)
        images.load()
        pil_image = images.convert("RGB").resize((224, 224))
        img_array = np.array(pil_image)
        orig_tag = (
            f'<img src="data:image/png;base64,{convert_image_to_base64(pil_image)}"'
            ' alt="Uploaded Image">'
        )

        framework = st.session_state.framework
        if framework == "TensorFlow":
            prediction = predict_image(images)
            predicted_class = np.argmax(prediction[0])
            max_index = int(predicted_class)
            cam = generate_gradcam(pil_image, predicted_class)
            # plt.cm is matplotlib.cm; going through plt keeps this working even
            # if the bare name ``cm`` has been rebound elsewhere in the script.
            heatmap = (plt.cm.jet(cam)[..., :3] * 255).astype(np.uint8)
            overlayed_pil = Image.fromarray(
                cv2.addWeighted(img_array, 0.6, heatmap, 0.4, 0)
            )
            image_tags = [
                orig_tag,
                f'<img src="data:image/png;base64,{convert_image_to_base64(overlayed_pil)}"'
                ' class="result-image">',
            ]
        elif framework == "PyTorch":
            # For the PyTorch model predict_image returns a 3-tuple; only the
            # scores are used. It is called once (it used to run twice).
            _, _, prediction = predict_image(images)
            max_index = int(np.argmax(prediction[0]))
            # No Grad-CAM overlay for the PyTorch model yet.
            image_tags = [orig_tag]
        else:
            st.error(f"Unsupported framework: {framework}")
            st.stop()

        content = _prediction_html(image_tags, prediction, max_index)
        placeholder = st.empty()
        placeholder.markdown(loading_html, unsafe_allow_html=True)
        time.sleep(5)  # fixed delay, presumably so the loading view is seen
        placeholder.empty()
        st.markdown(content, unsafe_allow_html=True)
    else:
        with open(image_path, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode()
        st.markdown(
            f"""
<div class="content-container">
<!-- Title -->
<!-- Recently Viewed Section -->
<div class="content-container3">
<img src="data:image/png;base64,{encoded_image}" alt="Default Image">
</div>
</div>
""",
            unsafe_allow_html=True,
        )