import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import numpy as np

# Set page config
st.set_page_config(page_title="ML Models Comparison Dashboard", layout="wide")

# Title and description
st.title("Machine Learning Models Comparison Dashboard")
st.write("Compare performance metrics of different ML models on the CIFAR-10 dataset")
# Pre-computed metrics (replace these with your actual results)
results = {
    'Accuracy': {
        'KNN': 0.345,  # Replace with your actual values
        'Logistic Regression': 0.389,
        'Random Forest': 0.412,
        'Naive Bayes': 0.298,
        'K-Means': 0.275,
        'CNN': 0.456
    },
    'Precision': {
        'KNN': 0.342,
        'Logistic Regression': 0.387,
        'Random Forest': 0.409,
        'Naive Bayes': 0.295,
        'K-Means': 0.271,
        'CNN': 0.453
    },
    'Recall': {
        'KNN': 0.345,
        'Logistic Regression': 0.389,
        'Random Forest': 0.412,
        'Naive Bayes': 0.298,
        'K-Means': 0.275,
        'CNN': 0.456
    },
    'F1': {
        'KNN': 0.343,
        'Logistic Regression': 0.388,
        'Random Forest': 0.410,
        'Naive Bayes': 0.296,
        'K-Means': 0.273,
        'CNN': 0.454
    }
}
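
# A minimal sketch of how the metrics above could be computed from real
# predictions, assuming scikit-learn is installed; `y_true` and `y_pred`
# are hypothetical arrays produced by your own evaluation loop.
def compute_metrics(y_true, y_pred):
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support
    # Macro-averaging weights all 10 CIFAR-10 classes equally.
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='macro', zero_division=0)
    return {'Accuracy': accuracy_score(y_true, y_pred),
            'Precision': precision, 'Recall': recall, 'F1': f1}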
# Pre-computed confusion matrices (replace these with your actual confusion matrices)
confusion_matrices = {
    'KNN': np.random.randint(0, 100, (10, 10)),  # Replace with actual confusion matrices
    'Logistic Regression': np.random.randint(0, 100, (10, 10)),
    'Random Forest': np.random.randint(0, 100, (10, 10)),
    'Naive Bayes': np.random.randint(0, 100, (10, 10)),
    'K-Means': np.random.randint(0, 100, (10, 10)),
    'CNN': np.random.randint(0, 100, (10, 10))
}
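
# CIFAR-10 class names, used to label the confusion-matrix axes below.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# A minimal sketch of how a real matrix could be computed, again assuming
# scikit-learn and hypothetical `y_true`/`y_pred` arrays.
def compute_confusion_matrix(y_true, y_pred):
    from sklearn.metrics import confusion_matrix
    return confusion_matrix(y_true, y_pred)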
# Create tabs for different visualizations
tab1, tab2, tab3 = st.tabs(["Metrics Comparison", "Confusion Matrices", "Radar Plot"])

with tab1:
    st.header("Performance Metrics Comparison")

    # Convert results to DataFrame for plotting
    df_metrics = pd.DataFrame(results)
    df_metrics.index.name = 'Model'
    df_metrics = df_metrics.reset_index()

    # Create grouped bar plot using plotly
    fig = px.bar(df_metrics.melt(id_vars=['Model'],
                                 var_name='Metric',
                                 value_name='Score'),
                 x='Model', y='Score', color='Metric', barmode='group',
                 title='Model Performance Comparison')
    fig.update_layout(xaxis_tickangle=-45)
    st.plotly_chart(fig)

    # Display metrics table
    st.subheader("Metrics Table")
    st.dataframe(df_metrics.set_index('Model').style.format("{:.3f}"))
with tab2:
    st.header("Confusion Matrices")

    # Select model for confusion matrix
    selected_model = st.selectbox("Select Model", list(confusion_matrices.keys()))

    # Plot confusion matrix as a heatmap, labelling both axes with the class names
    fig = px.imshow(confusion_matrices[selected_model],
                    x=class_names, y=class_names,
                    labels=dict(x="Predicted", y="True", color="Count"),
                    title=f"Confusion Matrix - {selected_model}")
    st.plotly_chart(fig)
with tab3:
    st.header("Radar Plot Comparison")

    # Create radar plot using plotly
    fig = go.Figure()
    metrics = list(results.keys())
    models = list(results['Accuracy'].keys())

    for model in models:
        values = [results[metric][model] for metric in metrics]
        values.append(values[0])  # Close the polygon
        fig.add_trace(go.Scatterpolar(
            r=values,
            theta=metrics + [metrics[0]],
            name=model
        ))

    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, 1])),
        showlegend=True,
        title="Model Comparison - All Metrics"
    )
    st.plotly_chart(fig)
# Add download button for metrics
def convert_df_to_csv():
    return df_metrics.to_csv(index=False)

st.sidebar.header("Download Data")
csv = convert_df_to_csv()
st.sidebar.download_button(
    label="Download metrics as CSV",
    data=csv,
    file_name='model_metrics.csv',
    mime='text/csv',
)
# Add explanatory text
st.sidebar.markdown("""
### Dashboard Features:
1. View pre-computed metrics for all models
2. Compare performance across different metrics
3. Examine confusion matrices
4. Download metrics data as CSV
""")
# Footer
st.markdown("---")
st.markdown("Dashboard created with Streamlit for ML Models Comparison")