import streamlit as st
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
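
# The confusion matrices below index classes 0-9. Assuming they follow the
# standard CIFAR-10 label order (an assumption; the original file does not
# name the classes), these labels could be passed to the tab 2 heatmap, e.g.
# sns.heatmap(..., xticklabels=CIFAR10_CLASSES, yticklabels=CIFAR10_CLASSES).
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']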
# Set page config
st.set_page_config(page_title="ML Models Comparison Dashboard", layout="wide")
# Title and description
st.title("Machine Learning Models Comparison Dashboard")
st.write("Compare performance metrics of different ML models on the CIFAR-10 dataset")
# Pre-computed metrics
results = {
    'Accuracy': {
        'KNN': 0.331,
        'Logistic Regression': 0.368,
        'Random Forest': 0.466,
        'Naive Bayes': 0.298,
        'K-Means': 0.109,
        'CNN': 0.694
    },
    'Precision': {
        'KNN': 0.342,
        'Logistic Regression': 0.387,
        'Random Forest': 0.409,
        'Naive Bayes': 0.295,
        'K-Means': 0.271,
        'CNN': 0.453
    },
    'Recall': {
        'KNN': 0.345,
        'Logistic Regression': 0.389,
        'Random Forest': 0.412,
        'Naive Bayes': 0.298,
        'K-Means': 0.275,
        'CNN': 0.456
    },
    'F1': {
        'KNN': 0.343,
        'Logistic Regression': 0.388,
        'Random Forest': 0.410,
        'Naive Bayes': 0.296,
        'K-Means': 0.273,
        'CNN': 0.454
    }
}
# Confusion matrices data
confusion_matrices = {
    'CNN': np.array([
        [672, 27, 78, 21, 21, 3, 3, 13, 116, 46],
        [20, 807, 3, 15, 8, 2, 7, 1, 22, 115],
        [54, 4, 593, 82, 144, 36, 28, 32, 18, 9],
        [17, 8, 73, 586, 100, 108, 38, 34, 12, 24],
        [19, 0, 53, 69, 720, 14, 15, 90, 15, 5],
        [5, 3, 89, 300, 58, 458, 6, 64, 9, 8],
        [3, 9, 55, 122, 118, 13, 653, 6, 7, 14],
        [17, 3, 30, 74, 70, 36, 0, 754, 1, 15],
        [41, 24, 11, 20, 9, 3, 6, 8, 844, 34],
        [20, 51, 4, 22, 9, 3, 6, 12, 25, 848]
    ]),
    'K-Means': np.array([
        [106, 109, 62, 41, 139, 81, 33, 185, 211, 33],
        [97, 184, 92, 141, 96, 159, 106, 25, 42, 58],
        [54, 69, 252, 160, 42, 46, 120, 109, 84, 64],
        [83, 120, 146, 154, 18, 60, 121, 87, 61, 150],
        [39, 49, 245, 219, 19, 51, 117, 73, 20, 168],
        [72, 185, 163, 103, 30, 35, 88, 132, 38, 154],
        [86, 94, 211, 212, 11, 28, 206, 30, 39, 83],
        [111, 88, 205, 131, 54, 135, 58, 57, 22, 139],
        [31, 167, 46, 28, 328, 195, 33, 91, 39, 42],
        [131, 81, 83, 123, 141, 331, 19, 18, 37, 36]
    ]),
    'KNN': np.array([
        [565, 12, 107, 20, 52, 6, 25, 4, 205, 4],
        [195, 244, 121, 61, 120, 28, 31, 4, 178, 18],
        [150, 7, 463, 58, 201, 24, 48, 10, 39, 0],
        [108, 10, 279, 243, 133, 92, 80, 17, 32, 6],
        [100, 6, 282, 54, 443, 23, 38, 11, 43, 0],
        [89, 4, 254, 178, 143, 208, 68, 10, 41, 5],
        [49, 1, 317, 102, 260, 24, 227, 2, 18, 0],
        [127, 16, 226, 76, 236, 40, 45, 192, 41, 1],
        [181, 31, 56, 45, 52, 11, 8, 4, 607, 5],
        [192, 81, 127, 80, 117, 27, 41, 14, 205, 116]
    ]),
    'Logistic Regression': np.array([
        [424, 51, 58, 50, 25, 41, 17, 56, 202, 76],
        [72, 426, 35, 48, 26, 36, 48, 47, 78, 184],
        [96, 34, 266, 90, 123, 94, 130, 84, 51, 32],
        [43, 52, 115, 235, 72, 194, 131, 56, 43, 59],
        [55, 35, 137, 80, 280, 97, 152, 112, 25, 27],
        [41, 41, 103, 202, 90, 300, 77, 64, 45, 37],
        [14, 47, 97, 147, 94, 94, 417, 42, 23, 25],
        [47, 47, 91, 70, 97, 91, 47, 397, 45, 68],
        [146, 85, 31, 37, 17, 41, 10, 18, 513, 102],
        [78, 180, 26, 36, 30, 29, 45, 59, 98, 419]
    ]),
    'Naive Bayes': np.array([
        [494, 20, 39, 10, 84, 34, 50, 9, 200, 60],
        [141, 166, 24, 31, 66, 72, 192, 19, 121, 168],
        [225, 24, 83, 15, 292, 48, 209, 21, 54, 29],
        [163, 36, 54, 76, 151, 129, 262, 26, 34, 69],
        [86, 8, 57, 26, 417, 38, 265, 22, 50, 31],
        [156, 17, 55, 51, 167, 264, 159, 36, 57, 38],
        [106, 2, 60, 18, 228, 46, 467, 15, 19, 39],
        [134, 24, 36, 41, 228, 94, 102, 131, 72, 138],
        [168, 41, 18, 17, 56, 83, 39, 8, 471, 99],
        [144, 67, 17, 20, 48, 32, 101, 23, 141, 407]
    ]),
    'Random Forest': np.array([
        [559, 36, 62, 21, 28, 20, 20, 30, 165, 59],
        [29, 544, 10, 38, 22, 34, 45, 35, 67, 176],
        [100, 36, 337, 79, 137, 69, 123, 54, 34, 31],
        [53, 46, 76, 282, 88, 173, 132, 62, 23, 65],
        [54, 21, 139, 60, 381, 47, 158, 91, 27, 22],
        [32, 27, 89, 167, 74, 400, 77, 74, 27, 33],
        [11, 33, 87, 78, 99, 53, 567, 31, 6, 35],
        [47, 44, 53, 58, 104, 83, 38, 451, 24, 98],
        [90, 83, 19, 33, 15, 39, 7, 22, 615, 77],
        [50, 176, 18, 38, 20, 24, 27, 43, 80, 524]
    ])
}
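
# The metric tables above were pre-computed offline; the training/evaluation
# pipeline is not part of this app. As a minimal sketch (an illustration, not
# the original pipeline), macro-averaged metrics can be recovered from a
# confusion matrix alone:
def metrics_from_confusion_matrix(cm):
    """Derive accuracy and macro-averaged precision/recall/F1 from a confusion matrix."""
    tp = np.diag(cm).astype(float)
    precision = tp / np.maximum(cm.sum(axis=0), 1)  # column sums = predicted counts per class
    recall = tp / np.maximum(cm.sum(axis=1), 1)     # row sums = true counts per class
    f1 = np.divide(2 * precision * recall, precision + recall,
                   out=np.zeros_like(tp), where=(precision + recall) > 0)
    return {'Accuracy': tp.sum() / cm.sum(),
            'Precision': precision.mean(),
            'Recall': recall.mean(),
            'F1': f1.mean()}

# Sanity check: metrics_from_confusion_matrix(confusion_matrices['CNN'])['Accuracy']
# evaluates to 0.6935, consistent with the pre-computed CNN accuracy of 0.694 above.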
# Create tabs for different visualizations
tab1, tab2, tab3 = st.tabs(["Overall Performance", "Confusion Matrices", "Individual Metrics"])
with tab1:
    st.header("Overall Model Performance")

    # Bar plot of overall accuracy per model
    fig, ax = plt.subplots(figsize=(12, 6))
    models = list(results['Accuracy'].keys())
    accuracies = list(results['Accuracy'].values())
    colors = ['purple', 'navy', 'teal', 'green', 'lime', 'yellow']
    bars = ax.bar(models, accuracies, color=colors)

    # Customize the plot
    ax.set_title('Overall Model Performance Comparison')
    ax.set_xlabel('Models')
    ax.set_ylabel('Accuracy')
    plt.xticks(rotation=45)

    # Add value labels on top of bars
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., height,
                f'{height:.3f}',
                ha='center', va='bottom')

    plt.tight_layout()
    st.pyplot(fig)
    plt.close(fig)  # release the figure; Streamlit re-runs the script on each interaction
with tab2:
    st.header("Confusion Matrices")

    # Model selection for confusion matrix
    selected_model = st.selectbox("Select Model", list(confusion_matrices.keys()))

    # Plot the selected model's confusion matrix as a seaborn heatmap
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(confusion_matrices[selected_model],
                annot=True,
                fmt='d',
                cmap='Blues',
                ax=ax)
    ax.set_title(f'Confusion Matrix - {selected_model}')
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    plt.tight_layout()
    st.pyplot(fig)
    plt.close(fig)
with tab3:
    st.header("Individual Metrics")

    # One column per metric, each with its own bar chart
    col1, col2, col3 = st.columns(3)
    metrics = ['Precision', 'Recall', 'F1']
    for metric, col in zip(metrics, [col1, col2, col3]):
        with col:
            fig, ax = plt.subplots(figsize=(8, 6))
            models = list(results[metric].keys())
            values = list(results[metric].values())
            ax.bar(models, values)
            ax.set_title(f'Comparison of {metric}')
            ax.set_ylabel(metric)
            plt.xticks(rotation=45, ha='right')
            plt.tight_layout()
            st.pyplot(fig)
            plt.close(fig)
# Add metrics table to sidebar
st.sidebar.header("Metrics Table")
df_metrics = pd.DataFrame(results)
st.sidebar.dataframe(df_metrics.style.format("{:.3f}"))
# Add download button
@st.cache_data
def convert_df_to_csv(df):
    """Cache the CSV conversion; st.cache_data hashes the DataFrame argument."""
    return df.to_csv()

csv = convert_df_to_csv(df_metrics)
st.sidebar.download_button(
    label="Download metrics as CSV",
    data=csv,
    file_name='model_metrics.csv',
    mime='text/csv',
)
# Footer
st.markdown("---")
st.markdown("Dashboard created for ML Models Comparison on CIFAR-10 dataset")