Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
import plotly.express as px
|
4 |
-
import plotly.graph_objects as go
|
5 |
import numpy as np
|
|
|
|
|
6 |
|
7 |
# Set page config
|
8 |
st.set_page_config(page_title="ML Models Comparison Dashboard", layout="wide")
|
@@ -11,15 +12,15 @@ st.set_page_config(page_title="ML Models Comparison Dashboard", layout="wide")
|
|
11 |
st.title("Machine Learning Models Comparison Dashboard")
|
12 |
st.write("Compare performance metrics of different ML models on the CIFAR-10 dataset")
|
13 |
|
14 |
-
# Pre-computed metrics
|
15 |
results = {
|
16 |
'Accuracy': {
|
17 |
-
'KNN': 0.
|
18 |
-
'Logistic Regression': 0.
|
19 |
-
'Random Forest': 0.
|
20 |
'Naive Bayes': 0.298,
|
21 |
-
'K-Means': 0.
|
22 |
-
'CNN': 0.
|
23 |
},
|
24 |
'Precision': {
|
25 |
'KNN': 0.342,
|
@@ -47,84 +48,163 @@ results = {
|
|
47 |
}
|
48 |
}
|
49 |
|
50 |
-
#
|
51 |
confusion_matrices = {
|
52 |
-
'
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
}
|
59 |
|
60 |
# Create tabs for different visualizations
|
61 |
-
tab1, tab2, tab3 = st.tabs(["
|
62 |
|
63 |
with tab1:
|
64 |
-
st.header("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
-
#
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
70 |
|
71 |
-
#
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
fig.update_layout(xaxis_tickangle=-45)
|
78 |
-
st.plotly_chart(fig)
|
79 |
|
80 |
-
|
81 |
-
st.
|
82 |
-
st.dataframe(df_metrics.set_index('Model').style.format("{:.3f}"))
|
83 |
|
84 |
with tab2:
|
85 |
st.header("Confusion Matrices")
|
86 |
|
87 |
-
#
|
88 |
selected_model = st.selectbox("Select Model", list(confusion_matrices.keys()))
|
89 |
|
90 |
-
#
|
91 |
-
fig =
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
st.header("Radar Plot Comparison")
|
98 |
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
fig.add_trace(go.Scatterpolar(
|
109 |
-
r=values,
|
110 |
-
theta=metrics + [metrics[0]],
|
111 |
-
name=model
|
112 |
-
))
|
113 |
|
114 |
-
|
115 |
-
|
116 |
-
showlegend=True,
|
117 |
-
title="Model Comparison - All Metrics"
|
118 |
-
)
|
119 |
|
120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
-
# Add download button
|
123 |
@st.cache_data
|
124 |
def convert_df_to_csv():
|
125 |
-
return df_metrics.to_csv(
|
126 |
|
127 |
-
st.sidebar.header("Download Data")
|
128 |
csv = convert_df_to_csv()
|
129 |
st.sidebar.download_button(
|
130 |
label="Download metrics as CSV",
|
@@ -133,15 +213,6 @@ st.sidebar.download_button(
|
|
133 |
mime='text/csv',
|
134 |
)
|
135 |
|
136 |
-
# Add explanatory text
|
137 |
-
st.sidebar.markdown("""
|
138 |
-
### Dashboard Features:
|
139 |
-
1. View pre-computed metrics for all models
|
140 |
-
2. Compare performance across different metrics
|
141 |
-
3. Examine confusion matrices
|
142 |
-
4. Download metrics data as CSV
|
143 |
-
""")
|
144 |
-
|
145 |
# Footer
|
146 |
st.markdown("---")
|
147 |
-
st.markdown("Dashboard created
|
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
import plotly.express as px
|
|
|
4 |
import numpy as np
|
5 |
+
import seaborn as sns
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
|
8 |
# Set page config
|
9 |
st.set_page_config(page_title="ML Models Comparison Dashboard", layout="wide")
|
|
|
12 |
st.title("Machine Learning Models Comparison Dashboard")
|
13 |
st.write("Compare performance metrics of different ML models on the CIFAR-10 dataset")
|
14 |
|
15 |
+
# Pre-computed metrics
|
16 |
results = {
|
17 |
'Accuracy': {
|
18 |
+
'KNN': 0.331,
|
19 |
+
'Logistic Regression': 0.368,
|
20 |
+
'Random Forest': 0.466,
|
21 |
'Naive Bayes': 0.298,
|
22 |
+
'K-Means': 0.109,
|
23 |
+
'CNN': 0.694
|
24 |
},
|
25 |
'Precision': {
|
26 |
'KNN': 0.342,
|
|
|
48 |
}
|
49 |
}
|
50 |
|
51 |
+
# Confusion matrices data
|
52 |
confusion_matrices = {
|
53 |
+
'CNN': np.array([
|
54 |
+
[672, 27, 78, 21, 21, 3, 3, 13, 116, 46],
|
55 |
+
[20, 807, 3, 15, 8, 2, 7, 1, 22, 115],
|
56 |
+
[54, 4, 593, 82, 144, 36, 28, 32, 18, 9],
|
57 |
+
[17, 8, 73, 586, 100, 108, 38, 34, 12, 24],
|
58 |
+
[19, 0, 53, 69, 720, 14, 15, 90, 15, 5],
|
59 |
+
[5, 3, 89, 300, 58, 458, 6, 64, 9, 8],
|
60 |
+
[3, 9, 55, 122, 118, 13, 653, 6, 7, 14],
|
61 |
+
[17, 3, 30, 74, 70, 36, 0, 754, 1, 15],
|
62 |
+
[41, 24, 11, 20, 9, 3, 6, 8, 844, 34],
|
63 |
+
[20, 51, 4, 22, 9, 3, 6, 12, 25, 848]
|
64 |
+
]),
|
65 |
+
'K-Means': np.array([
|
66 |
+
[106, 109, 62, 41, 139, 81, 33, 185, 211, 33],
|
67 |
+
[97, 184, 92, 141, 96, 159, 106, 25, 42, 58],
|
68 |
+
[54, 69, 252, 160, 42, 46, 120, 109, 84, 64],
|
69 |
+
[83, 120, 146, 154, 18, 60, 121, 87, 61, 150],
|
70 |
+
[39, 49, 245, 219, 19, 51, 117, 73, 20, 168],
|
71 |
+
[72, 185, 163, 103, 30, 35, 88, 132, 38, 154],
|
72 |
+
[86, 94, 211, 212, 11, 28, 206, 30, 39, 83],
|
73 |
+
[111, 88, 205, 131, 54, 135, 58, 57, 22, 139],
|
74 |
+
[31, 167, 46, 28, 328, 195, 33, 91, 39, 42],
|
75 |
+
[131, 81, 83, 123, 141, 331, 19, 18, 37, 36]
|
76 |
+
]),
|
77 |
+
'KNN': np.array([
|
78 |
+
[565, 12, 107, 20, 52, 6, 25, 4, 205, 4],
|
79 |
+
[195, 244, 121, 61, 120, 28, 31, 4, 178, 18],
|
80 |
+
[150, 7, 463, 58, 201, 24, 48, 10, 39, 0],
|
81 |
+
[108, 10, 279, 243, 133, 92, 80, 17, 32, 6],
|
82 |
+
[100, 6, 282, 54, 443, 23, 38, 11, 43, 0],
|
83 |
+
[89, 4, 254, 178, 143, 208, 68, 10, 41, 5],
|
84 |
+
[49, 1, 317, 102, 260, 24, 227, 2, 18, 0],
|
85 |
+
[127, 16, 226, 76, 236, 40, 45, 192, 41, 1],
|
86 |
+
[181, 31, 56, 45, 52, 11, 8, 4, 607, 5],
|
87 |
+
[192, 81, 127, 80, 117, 27, 41, 14, 205, 116]
|
88 |
+
]),
|
89 |
+
'Logistic Regression': np.array([
|
90 |
+
[424, 51, 58, 50, 25, 41, 17, 56, 202, 76],
|
91 |
+
[72, 426, 35, 48, 26, 36, 48, 47, 78, 184],
|
92 |
+
[96, 34, 266, 90, 123, 94, 130, 84, 51, 32],
|
93 |
+
[43, 52, 115, 235, 72, 194, 131, 56, 43, 59],
|
94 |
+
[55, 35, 137, 80, 280, 97, 152, 112, 25, 27],
|
95 |
+
[41, 41, 103, 202, 90, 300, 77, 64, 45, 37],
|
96 |
+
[14, 47, 97, 147, 94, 94, 417, 42, 23, 25],
|
97 |
+
[47, 47, 91, 70, 97, 91, 47, 397, 45, 68],
|
98 |
+
[146, 85, 31, 37, 17, 41, 10, 18, 513, 102],
|
99 |
+
[78, 180, 26, 36, 30, 29, 45, 59, 98, 419]
|
100 |
+
]),
|
101 |
+
'Naive Bayes': np.array([
|
102 |
+
[494, 20, 39, 10, 84, 34, 50, 9, 200, 60],
|
103 |
+
[141, 166, 24, 31, 66, 72, 192, 19, 121, 168],
|
104 |
+
[225, 24, 83, 15, 292, 48, 209, 21, 54, 29],
|
105 |
+
[163, 36, 54, 76, 151, 129, 262, 26, 34, 69],
|
106 |
+
[86, 8, 57, 26, 417, 38, 265, 22, 50, 31],
|
107 |
+
[156, 17, 55, 51, 167, 264, 159, 36, 57, 38],
|
108 |
+
[106, 2, 60, 18, 228, 46, 467, 15, 19, 39],
|
109 |
+
[134, 24, 36, 41, 228, 94, 102, 131, 72, 138],
|
110 |
+
[168, 41, 18, 17, 56, 83, 39, 8, 471, 99],
|
111 |
+
[144, 67, 17, 20, 48, 32, 101, 23, 141, 407]
|
112 |
+
]),
|
113 |
+
'Random Forest': np.array([
|
114 |
+
[559, 36, 62, 21, 28, 20, 20, 30, 165, 59],
|
115 |
+
[29, 544, 10, 38, 22, 34, 45, 35, 67, 176],
|
116 |
+
[100, 36, 337, 79, 137, 69, 123, 54, 34, 31],
|
117 |
+
[53, 46, 76, 282, 88, 173, 132, 62, 23, 65],
|
118 |
+
[54, 21, 139, 60, 381, 47, 158, 91, 27, 22],
|
119 |
+
[32, 27, 89, 167, 74, 400, 77, 74, 27, 33],
|
120 |
+
[11, 33, 87, 78, 99, 53, 567, 31, 6, 35],
|
121 |
+
[47, 44, 53, 58, 104, 83, 38, 451, 24, 98],
|
122 |
+
[90, 83, 19, 33, 15, 39, 7, 22, 615, 77],
|
123 |
+
[50, 176, 18, 38, 20, 24, 27, 43, 80, 524]
|
124 |
+
])
|
125 |
}
|
126 |
|
127 |
# Create tabs for different visualizations
|
128 |
+
tab1, tab2, tab3 = st.tabs(["Overall Performance", "Confusion Matrices", "Individual Metrics"])
|
129 |
|
130 |
with tab1:
|
131 |
+
st.header("Overall Model Performance")
|
132 |
+
|
133 |
+
# Create bar plot for overall accuracy
|
134 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
135 |
+
models = list(results['Accuracy'].keys())
|
136 |
+
accuracies = list(results['Accuracy'].values())
|
137 |
+
|
138 |
+
colors = ['purple', 'navy', 'teal', 'green', 'lime', 'yellow']
|
139 |
+
bars = ax.bar(models, accuracies, color=colors)
|
140 |
|
141 |
+
# Customize the plot
|
142 |
+
ax.set_title('Overall Model Performance Comparison')
|
143 |
+
ax.set_xlabel('Models')
|
144 |
+
ax.set_ylabel('Accuracy')
|
145 |
+
plt.xticks(rotation=45)
|
146 |
|
147 |
+
# Add value labels on top of bars
|
148 |
+
for bar in bars:
|
149 |
+
height = bar.get_height()
|
150 |
+
ax.text(bar.get_x() + bar.get_width()/2., height,
|
151 |
+
f'{height:.3f}',
|
152 |
+
ha='center', va='bottom')
|
|
|
|
|
153 |
|
154 |
+
plt.tight_layout()
|
155 |
+
st.pyplot(fig)
|
|
|
156 |
|
157 |
with tab2:
|
158 |
st.header("Confusion Matrices")
|
159 |
|
160 |
+
# Model selection for confusion matrix
|
161 |
selected_model = st.selectbox("Select Model", list(confusion_matrices.keys()))
|
162 |
|
163 |
+
# Create confusion matrix plot using seaborn
|
164 |
+
fig, ax = plt.subplots(figsize=(10, 8))
|
165 |
+
sns.heatmap(confusion_matrices[selected_model],
|
166 |
+
annot=True,
|
167 |
+
fmt='d',
|
168 |
+
cmap='Blues',
|
169 |
+
ax=ax)
|
|
|
170 |
|
171 |
+
plt.title(f'Confusion Matrix - {selected_model}')
|
172 |
+
plt.xlabel('Predicted')
|
173 |
+
plt.ylabel('True')
|
174 |
+
plt.tight_layout()
|
175 |
|
176 |
+
st.pyplot(fig)
|
177 |
+
|
178 |
+
with tab3:
|
179 |
+
st.header("Individual Metrics")
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
+
col1, col2, col3 = st.columns(3)
|
182 |
+
metrics = ['Precision', 'Recall', 'F1']
|
|
|
|
|
|
|
183 |
|
184 |
+
for metric, col in zip(metrics, [col1, col2, col3]):
|
185 |
+
with col:
|
186 |
+
fig, ax = plt.subplots(figsize=(8, 6))
|
187 |
+
models = list(results[metric].keys())
|
188 |
+
values = list(results[metric].values())
|
189 |
+
|
190 |
+
ax.bar(models, values)
|
191 |
+
ax.set_title(f'Comparison of {metric}')
|
192 |
+
plt.xticks(rotation=45, ha='right')
|
193 |
+
ax.set_ylabel(metric)
|
194 |
+
|
195 |
+
plt.tight_layout()
|
196 |
+
st.pyplot(fig)
|
197 |
+
|
198 |
+
# Add metrics table to sidebar
|
199 |
+
st.sidebar.header("Metrics Table")
|
200 |
+
df_metrics = pd.DataFrame(results)
|
201 |
+
st.sidebar.dataframe(df_metrics.style.format("{:.3f}"))
|
202 |
|
203 |
+
# Add download button
|
204 |
@st.cache_data
|
205 |
def convert_df_to_csv():
|
206 |
+
return df_metrics.to_csv()
|
207 |
|
|
|
208 |
csv = convert_df_to_csv()
|
209 |
st.sidebar.download_button(
|
210 |
label="Download metrics as CSV",
|
|
|
213 |
mime='text/csv',
|
214 |
)
|
215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
# Footer
|
217 |
st.markdown("---")
|
218 |
+
st.markdown("Dashboard created for ML Models Comparison on CIFAR-10 dataset")
|