"""Streamlit dashboard comparing ML model performance on CIFAR-10."""
import streamlit as st
import pandas as pd
import plotly.express as px
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Page config must be the first Streamlit call in the script.
st.set_page_config(page_title="ML Models Comparison Dashboard", layout="wide")

# Title and description
st.title("Machine Learning Models Comparison Dashboard")
st.write("Compare performance metrics of different ML models on the CIFAR-10 dataset")
# Pre-computed evaluation metrics per model (macro-averaged, on the
# CIFAR-10 test set). Keys: metric name -> {model name -> score in [0, 1]}.
results = {
    'Accuracy': {
        'KNN': 0.331,
        'Logistic Regression': 0.368,
        'Random Forest': 0.466,
        'Naive Bayes': 0.298,
        'K-Means': 0.109,
        'CNN': 0.694
    },
    'Precision': {
        'KNN': 0.342,
        'Logistic Regression': 0.387,
        'Random Forest': 0.409,
        'Naive Bayes': 0.295,
        'K-Means': 0.271,
        'CNN': 0.453
    },
    'Recall': {
        'KNN': 0.345,
        'Logistic Regression': 0.389,
        'Random Forest': 0.412,
        'Naive Bayes': 0.298,
        'K-Means': 0.275,
        'CNN': 0.456
    },
    'F1': {
        'KNN': 0.343,
        'Logistic Regression': 0.388,
        'Random Forest': 0.410,
        'Naive Bayes': 0.296,
        'K-Means': 0.273,
        'CNN': 0.454
    }
}
# Pre-computed 10x10 confusion matrices per model (rows = true class,
# columns = predicted class; CIFAR-10 has 1000 test images per class,
# so each row of a correct matrix sums to 1000).
confusion_matrices = {
    'CNN': np.array([
        [672, 27, 78, 21, 21, 3, 3, 13, 116, 46],
        [20, 807, 3, 15, 8, 2, 7, 1, 22, 115],
        [54, 4, 593, 82, 144, 36, 28, 32, 18, 9],
        [17, 8, 73, 586, 100, 108, 38, 34, 12, 24],
        [19, 0, 53, 69, 720, 14, 15, 90, 15, 5],
        [5, 3, 89, 300, 58, 458, 6, 64, 9, 8],
        [3, 9, 55, 122, 118, 13, 653, 6, 7, 14],
        [17, 3, 30, 74, 70, 36, 0, 754, 1, 15],
        [41, 24, 11, 20, 9, 3, 6, 8, 844, 34],
        [20, 51, 4, 22, 9, 3, 6, 12, 25, 848]
    ]),
    'K-Means': np.array([
        [106, 109, 62, 41, 139, 81, 33, 185, 211, 33],
        [97, 184, 92, 141, 96, 159, 106, 25, 42, 58],
        [54, 69, 252, 160, 42, 46, 120, 109, 84, 64],
        [83, 120, 146, 154, 18, 60, 121, 87, 61, 150],
        [39, 49, 245, 219, 19, 51, 117, 73, 20, 168],
        [72, 185, 163, 103, 30, 35, 88, 132, 38, 154],
        [86, 94, 211, 212, 11, 28, 206, 30, 39, 83],
        [111, 88, 205, 131, 54, 135, 58, 57, 22, 139],
        [31, 167, 46, 28, 328, 195, 33, 91, 39, 42],
        [131, 81, 83, 123, 141, 331, 19, 18, 37, 36]
    ]),
    'KNN': np.array([
        [565, 12, 107, 20, 52, 6, 25, 4, 205, 4],
        [195, 244, 121, 61, 120, 28, 31, 4, 178, 18],
        [150, 7, 463, 58, 201, 24, 48, 10, 39, 0],
        [108, 10, 279, 243, 133, 92, 80, 17, 32, 6],
        [100, 6, 282, 54, 443, 23, 38, 11, 43, 0],
        [89, 4, 254, 178, 143, 208, 68, 10, 41, 5],
        [49, 1, 317, 102, 260, 24, 227, 2, 18, 0],
        [127, 16, 226, 76, 236, 40, 45, 192, 41, 1],
        [181, 31, 56, 45, 52, 11, 8, 4, 607, 5],
        [192, 81, 127, 80, 117, 27, 41, 14, 205, 116]
    ]),
    'Logistic Regression': np.array([
        [424, 51, 58, 50, 25, 41, 17, 56, 202, 76],
        [72, 426, 35, 48, 26, 36, 48, 47, 78, 184],
        [96, 34, 266, 90, 123, 94, 130, 84, 51, 32],
        [43, 52, 115, 235, 72, 194, 131, 56, 43, 59],
        [55, 35, 137, 80, 280, 97, 152, 112, 25, 27],
        [41, 41, 103, 202, 90, 300, 77, 64, 45, 37],
        [14, 47, 97, 147, 94, 94, 417, 42, 23, 25],
        [47, 47, 91, 70, 97, 91, 47, 397, 45, 68],
        [146, 85, 31, 37, 17, 41, 10, 18, 513, 102],
        [78, 180, 26, 36, 30, 29, 45, 59, 98, 419]
    ]),
    'Naive Bayes': np.array([
        [494, 20, 39, 10, 84, 34, 50, 9, 200, 60],
        [141, 166, 24, 31, 66, 72, 192, 19, 121, 168],
        [225, 24, 83, 15, 292, 48, 209, 21, 54, 29],
        [163, 36, 54, 76, 151, 129, 262, 26, 34, 69],
        [86, 8, 57, 26, 417, 38, 265, 22, 50, 31],
        [156, 17, 55, 51, 167, 264, 159, 36, 57, 38],
        [106, 2, 60, 18, 228, 46, 467, 15, 19, 39],
        [134, 24, 36, 41, 228, 94, 102, 131, 72, 138],
        [168, 41, 18, 17, 56, 83, 39, 8, 471, 99],
        [144, 67, 17, 20, 48, 32, 101, 23, 141, 407]
    ]),
    'Random Forest': np.array([
        [559, 36, 62, 21, 28, 20, 20, 30, 165, 59],
        [29, 544, 10, 38, 22, 34, 45, 35, 67, 176],
        [100, 36, 337, 79, 137, 69, 123, 54, 34, 31],
        [53, 46, 76, 282, 88, 173, 132, 62, 23, 65],
        [54, 21, 139, 60, 381, 47, 158, 91, 27, 22],
        [32, 27, 89, 167, 74, 400, 77, 74, 27, 33],
        [11, 33, 87, 78, 99, 53, 567, 31, 6, 35],
        [47, 44, 53, 58, 104, 83, 38, 451, 24, 98],
        [90, 83, 19, 33, 15, 39, 7, 22, 615, 77],
        [50, 176, 18, 38, 20, 24, 27, 43, 80, 524]
    ])
}
# Create tabs for the three visualization views.
tab1, tab2, tab3 = st.tabs(["Overall Performance", "Confusion Matrices", "Individual Metrics"])

with tab1:
    st.header("Overall Model Performance")

    # Bar plot of overall accuracy, one bar per model.
    fig, ax = plt.subplots(figsize=(12, 6))
    models = list(results['Accuracy'].keys())
    accuracies = list(results['Accuracy'].values())
    colors = ['purple', 'navy', 'teal', 'green', 'lime', 'yellow']
    bars = ax.bar(models, accuracies, color=colors)

    # Customize the plot
    ax.set_title('Overall Model Performance Comparison')
    ax.set_xlabel('Models')
    ax.set_ylabel('Accuracy')
    plt.xticks(rotation=45)

    # Annotate each bar with its accuracy value.
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., height,
                f'{height:.3f}',
                ha='center', va='bottom')
    plt.tight_layout()
    st.pyplot(fig)
with tab2:
    st.header("Confusion Matrices")

    # Let the user pick which model's confusion matrix to display.
    selected_model = st.selectbox("Select Model", list(confusion_matrices.keys()))

    # Render the 10x10 matrix as an annotated seaborn heatmap.
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(confusion_matrices[selected_model],
                annot=True,
                fmt='d',
                cmap='Blues',
                ax=ax)
    plt.title(f'Confusion Matrix - {selected_model}')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.tight_layout()
    st.pyplot(fig)
with tab3:
    st.header("Individual Metrics")

    # One column per metric: Precision, Recall, F1.
    col1, col2, col3 = st.columns(3)
    metrics = ['Precision', 'Recall', 'F1']
    for metric, col in zip(metrics, [col1, col2, col3]):
        with col:
            fig, ax = plt.subplots(figsize=(8, 6))
            models = list(results[metric].keys())
            values = list(results[metric].values())
            ax.bar(models, values)
            ax.set_title(f'Comparison of {metric}')
            plt.xticks(rotation=45, ha='right')
            ax.set_ylabel(metric)
            plt.tight_layout()
            st.pyplot(fig)
# Metrics table in the sidebar, formatted to three decimal places.
st.sidebar.header("Metrics Table")
df_metrics = pd.DataFrame(results)
st.sidebar.dataframe(df_metrics.style.format("{:.3f}"))


def convert_df_to_csv():
    """Return the metrics table serialized as CSV text."""
    return df_metrics.to_csv()


# Offer the metrics table as a CSV download.
csv = convert_df_to_csv()
st.sidebar.download_button(
    label="Download metrics as CSV",
    data=csv,
    file_name='model_metrics.csv',
    mime='text/csv',
)
# Footer
st.markdown("---")
st.markdown("Dashboard created for ML Models Comparison on CIFAR-10 dataset")