import streamlit as st
import pandas as pd
import plotly.express as px
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
# Set page config
st.set_page_config(page_title="ML Models Comparison Dashboard", layout="wide")
# Title and description
st.title("Machine Learning Models Comparison Dashboard")
st.write("Compare performance metrics of different ML models on the CIFAR-10 dataset")
# Pre-computed metrics
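# (Presumably computed offline on the CIFAR-10 test set: each confusion-matrix
# row below sums to 1,000, matching the 1,000 test images per class.)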
results = {
    'Accuracy': {
        'KNN': 0.331,
        'Logistic Regression': 0.368,
        'Random Forest': 0.466,
        'Naive Bayes': 0.298,
        'K-Means': 0.109,
        'CNN': 0.694
    },
    'Precision': {
        'KNN': 0.342,
        'Logistic Regression': 0.387,
        'Random Forest': 0.409,
        'Naive Bayes': 0.295,
        'K-Means': 0.271,
        'CNN': 0.453
    },
    'Recall': {
        'KNN': 0.345,
        'Logistic Regression': 0.389,
        'Random Forest': 0.412,
        'Naive Bayes': 0.298,
        'K-Means': 0.275,
        'CNN': 0.456
    },
    'F1': {
        'KNN': 0.343,
        'Logistic Regression': 0.388,
        'Random Forest': 0.410,
        'Naive Bayes': 0.296,
        'K-Means': 0.273,
        'CNN': 0.454
    }
}
# Confusion matrices data
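# Rows are true classes and columns are predicted classes (matching the heatmap
# axis labels in the "Confusion Matrices" tab). The class order is assumed to be
# the standard CIFAR-10 order: airplane, automobile, bird, cat, deer, dog, frog,
# horse, ship, truck.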
confusion_matrices = {
    'CNN': np.array([
        [672, 27, 78, 21, 21, 3, 3, 13, 116, 46],
        [20, 807, 3, 15, 8, 2, 7, 1, 22, 115],
        [54, 4, 593, 82, 144, 36, 28, 32, 18, 9],
        [17, 8, 73, 586, 100, 108, 38, 34, 12, 24],
        [19, 0, 53, 69, 720, 14, 15, 90, 15, 5],
        [5, 3, 89, 300, 58, 458, 6, 64, 9, 8],
        [3, 9, 55, 122, 118, 13, 653, 6, 7, 14],
        [17, 3, 30, 74, 70, 36, 0, 754, 1, 15],
        [41, 24, 11, 20, 9, 3, 6, 8, 844, 34],
        [20, 51, 4, 22, 9, 3, 6, 12, 25, 848]
    ]),
    'K-Means': np.array([
        [106, 109, 62, 41, 139, 81, 33, 185, 211, 33],
        [97, 184, 92, 141, 96, 159, 106, 25, 42, 58],
        [54, 69, 252, 160, 42, 46, 120, 109, 84, 64],
        [83, 120, 146, 154, 18, 60, 121, 87, 61, 150],
        [39, 49, 245, 219, 19, 51, 117, 73, 20, 168],
        [72, 185, 163, 103, 30, 35, 88, 132, 38, 154],
        [86, 94, 211, 212, 11, 28, 206, 30, 39, 83],
        [111, 88, 205, 131, 54, 135, 58, 57, 22, 139],
        [31, 167, 46, 28, 328, 195, 33, 91, 39, 42],
        [131, 81, 83, 123, 141, 331, 19, 18, 37, 36]
    ]),
    'KNN': np.array([
        [565, 12, 107, 20, 52, 6, 25, 4, 205, 4],
        [195, 244, 121, 61, 120, 28, 31, 4, 178, 18],
        [150, 7, 463, 58, 201, 24, 48, 10, 39, 0],
        [108, 10, 279, 243, 133, 92, 80, 17, 32, 6],
        [100, 6, 282, 54, 443, 23, 38, 11, 43, 0],
        [89, 4, 254, 178, 143, 208, 68, 10, 41, 5],
        [49, 1, 317, 102, 260, 24, 227, 2, 18, 0],
        [127, 16, 226, 76, 236, 40, 45, 192, 41, 1],
        [181, 31, 56, 45, 52, 11, 8, 4, 607, 5],
        [192, 81, 127, 80, 117, 27, 41, 14, 205, 116]
    ]),
    'Logistic Regression': np.array([
        [424, 51, 58, 50, 25, 41, 17, 56, 202, 76],
        [72, 426, 35, 48, 26, 36, 48, 47, 78, 184],
        [96, 34, 266, 90, 123, 94, 130, 84, 51, 32],
        [43, 52, 115, 235, 72, 194, 131, 56, 43, 59],
        [55, 35, 137, 80, 280, 97, 152, 112, 25, 27],
        [41, 41, 103, 202, 90, 300, 77, 64, 45, 37],
        [14, 47, 97, 147, 94, 94, 417, 42, 23, 25],
        [47, 47, 91, 70, 97, 91, 47, 397, 45, 68],
        [146, 85, 31, 37, 17, 41, 10, 18, 513, 102],
        [78, 180, 26, 36, 30, 29, 45, 59, 98, 419]
    ]),
    'Naive Bayes': np.array([
        [494, 20, 39, 10, 84, 34, 50, 9, 200, 60],
        [141, 166, 24, 31, 66, 72, 192, 19, 121, 168],
        [225, 24, 83, 15, 292, 48, 209, 21, 54, 29],
        [163, 36, 54, 76, 151, 129, 262, 26, 34, 69],
        [86, 8, 57, 26, 417, 38, 265, 22, 50, 31],
        [156, 17, 55, 51, 167, 264, 159, 36, 57, 38],
        [106, 2, 60, 18, 228, 46, 467, 15, 19, 39],
        [134, 24, 36, 41, 228, 94, 102, 131, 72, 138],
        [168, 41, 18, 17, 56, 83, 39, 8, 471, 99],
        [144, 67, 17, 20, 48, 32, 101, 23, 141, 407]
    ]),
    'Random Forest': np.array([
        [559, 36, 62, 21, 28, 20, 20, 30, 165, 59],
        [29, 544, 10, 38, 22, 34, 45, 35, 67, 176],
        [100, 36, 337, 79, 137, 69, 123, 54, 34, 31],
        [53, 46, 76, 282, 88, 173, 132, 62, 23, 65],
        [54, 21, 139, 60, 381, 47, 158, 91, 27, 22],
        [32, 27, 89, 167, 74, 400, 77, 74, 27, 33],
        [11, 33, 87, 78, 99, 53, 567, 31, 6, 35],
        [47, 44, 53, 58, 104, 83, 38, 451, 24, 98],
        [90, 83, 19, 33, 15, 39, 7, 22, 615, 77],
        [50, 176, 18, 38, 20, 24, 27, 43, 80, 524]
    ])
}
# Create tabs for different visualizations
tab1, tab2, tab3 = st.tabs(["Overall Performance", "Confusion Matrices", "Individual Metrics"])
with tab1:
    st.header("Overall Model Performance")
    # Create bar plot for overall accuracy
    fig, ax = plt.subplots(figsize=(12, 6))
    models = list(results['Accuracy'].keys())
    accuracies = list(results['Accuracy'].values())
    colors = ['purple', 'navy', 'teal', 'green', 'lime', 'yellow']
    bars = ax.bar(models, accuracies, color=colors)
    # Customize the plot
    ax.set_title('Overall Model Performance Comparison')
    ax.set_xlabel('Models')
    ax.set_ylabel('Accuracy')
    plt.xticks(rotation=45)
    # Add value labels on top of bars
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.3f}',
                ha='center', va='bottom')
    plt.tight_layout()
    st.pyplot(fig)
with tab2:
    st.header("Confusion Matrices")
    # Model selection for confusion matrix
    selected_model = st.selectbox("Select Model", list(confusion_matrices.keys()))
    # Create confusion matrix plot using seaborn
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(confusion_matrices[selected_model],
                annot=True,
                fmt='d',
                cmap='Blues',
                ax=ax)
    plt.title(f'Confusion Matrix - {selected_model}')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.tight_layout()
    st.pyplot(fig)
with tab3:
    st.header("Individual Metrics")
    col1, col2, col3 = st.columns(3)
    metrics = ['Precision', 'Recall', 'F1']
    for metric, col in zip(metrics, [col1, col2, col3]):
        with col:
            fig, ax = plt.subplots(figsize=(8, 6))
            models = list(results[metric].keys())
            values = list(results[metric].values())
            ax.bar(models, values)
            ax.set_title(f'Comparison of {metric}')
            plt.xticks(rotation=45, ha='right')
            ax.set_ylabel(metric)
            plt.tight_layout()
            st.pyplot(fig)
# Add metrics table to sidebar
st.sidebar.header("Metrics Table")
df_metrics = pd.DataFrame(results)
st.sidebar.dataframe(df_metrics.style.format("{:.3f}"))
# Add download button
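# st.cache_data caches the generated CSV string so it is not rebuilt on every rerun.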
@st.cache_data
def convert_df_to_csv():
    return df_metrics.to_csv()

csv = convert_df_to_csv()
st.sidebar.download_button(
    label="Download metrics as CSV",
    data=csv,
    file_name='model_metrics.csv',
    mime='text/csv',
)
# Footer
st.markdown("---")
st.markdown("Dashboard created for ML Models Comparison on CIFAR-10 dataset")