cisemh committed · verified
Commit 7a7d2bf · Parent: e76d46c

Update app.py

Files changed (1)
  1. app.py  +143  -72
app.py CHANGED
@@ -1,8 +1,9 @@
  import streamlit as st
  import pandas as pd
  import plotly.express as px
- import plotly.graph_objects as go
  import numpy as np

  # Set page config
  st.set_page_config(page_title="ML Models Comparison Dashboard", layout="wide")
@@ -11,15 +12,15 @@ st.set_page_config(page_title="ML Models Comparison Dashboard", layout="wide")
  st.title("Machine Learning Models Comparison Dashboard")
  st.write("Compare performance metrics of different ML models on the CIFAR-10 dataset")

- # Pre-computed metrics (replace these with your actual results)
  results = {
      'Accuracy': {
-         'KNN': 0.345, # Replace with your actual values
-         'Logistic Regression': 0.389,
-         'Random Forest': 0.412,
          'Naive Bayes': 0.298,
-         'K-Means': 0.275,
-         'CNN': 0.456
      },
      'Precision': {
          'KNN': 0.342,
@@ -47,84 +48,163 @@ results = {
      }
  }

- # Pre-computed confusion matrices (replace these with your actual confusion matrices)
  confusion_matrices = {
-     'KNN': np.random.randint(0, 100, (10, 10)), # Replace with actual confusion matrices
-     'Logistic Regression': np.random.randint(0, 100, (10, 10)),
-     'Random Forest': np.random.randint(0, 100, (10, 10)),
-     'Naive Bayes': np.random.randint(0, 100, (10, 10)),
-     'K-Means': np.random.randint(0, 100, (10, 10)),
-     'CNN': np.random.randint(0, 100, (10, 10))
  }

  # Create tabs for different visualizations
- tab1, tab2, tab3 = st.tabs(["Metrics Comparison", "Confusion Matrices", "Radar Plot"])

  with tab1:
-     st.header("Performance Metrics Comparison")

-     # Convert results to DataFrame for plotting
-     df_metrics = pd.DataFrame(results)
-     df_metrics.index.name = 'Model'
-     df_metrics = df_metrics.reset_index()

-     # Create bar plot using plotly
-     fig = px.bar(df_metrics.melt(id_vars=['Model'],
-                                  var_name='Metric',
-                                  value_name='Score'),
-                  x='Model', y='Score', color='Metric', barmode='group',
-                  title='Model Performance Comparison')
-     fig.update_layout(xaxis_tickangle=-45)
-     st.plotly_chart(fig)

-     # Display metrics table
-     st.subheader("Metrics Table")
-     st.dataframe(df_metrics.set_index('Model').style.format("{:.3f}"))

  with tab2:
      st.header("Confusion Matrices")

-     # Select model for confusion matrix
      selected_model = st.selectbox("Select Model", list(confusion_matrices.keys()))

-     # Plot confusion matrix using plotly
-     fig = px.imshow(confusion_matrices[selected_model],
-                     labels=dict(x="Predicted", y="True"),
-                     title=f"Confusion Matrix - {selected_model}")
-     st.plotly_chart(fig)
-
- with tab3:
-     st.header("Radar Plot Comparison")

-     # Create radar plot using plotly
-     fig = go.Figure()
-     metrics = list(results.keys())
-     models = list(results['Accuracy'].keys())

-     for model in models:
-         values = [results[metric][model] for metric in metrics]
-         values.append(values[0]) # Complete the circle
-
-         fig.add_trace(go.Scatterpolar(
-             r=values,
-             theta=metrics + [metrics[0]],
-             name=model
-         ))

-     fig.update_layout(
-         polar=dict(radialaxis=dict(visible=True, range=[0, 1])),
-         showlegend=True,
-         title="Model Comparison - All Metrics"
-     )

-     st.plotly_chart(fig)

- # Add download button for metrics
  @st.cache_data
  def convert_df_to_csv():
-     return df_metrics.to_csv(index=False)

- st.sidebar.header("Download Data")
  csv = convert_df_to_csv()
  st.sidebar.download_button(
      label="Download metrics as CSV",
@@ -133,15 +213,6 @@ st.sidebar.download_button(
      mime='text/csv',
  )

- # Add explanatory text
- st.sidebar.markdown("""
- ### Dashboard Features:
- 1. View pre-computed metrics for all models
- 2. Compare performance across different metrics
- 3. Examine confusion matrices
- 4. Download metrics data as CSV
- """)
-
  # Footer
  st.markdown("---")
- st.markdown("Dashboard created with Streamlit for ML Models Comparison")
 
app.py (after change)
  import streamlit as st
  import pandas as pd
  import plotly.express as px
  import numpy as np
+ import seaborn as sns
+ import matplotlib.pyplot as plt

  # Set page config
  st.set_page_config(page_title="ML Models Comparison Dashboard", layout="wide")

  st.title("Machine Learning Models Comparison Dashboard")
  st.write("Compare performance metrics of different ML models on the CIFAR-10 dataset")

+ # Pre-computed metrics
  results = {
      'Accuracy': {
+         'KNN': 0.331,
+         'Logistic Regression': 0.368,
+         'Random Forest': 0.466,
          'Naive Bayes': 0.298,
+         'K-Means': 0.109,
+         'CNN': 0.694
      },
      'Precision': {
          'KNN': 0.342,

      }
  }

+ # Confusion matrices data
  confusion_matrices = {
+     'CNN': np.array([
+         [672, 27, 78, 21, 21, 3, 3, 13, 116, 46],
+         [20, 807, 3, 15, 8, 2, 7, 1, 22, 115],
+         [54, 4, 593, 82, 144, 36, 28, 32, 18, 9],
+         [17, 8, 73, 586, 100, 108, 38, 34, 12, 24],
+         [19, 0, 53, 69, 720, 14, 15, 90, 15, 5],
+         [5, 3, 89, 300, 58, 458, 6, 64, 9, 8],
+         [3, 9, 55, 122, 118, 13, 653, 6, 7, 14],
+         [17, 3, 30, 74, 70, 36, 0, 754, 1, 15],
+         [41, 24, 11, 20, 9, 3, 6, 8, 844, 34],
+         [20, 51, 4, 22, 9, 3, 6, 12, 25, 848]
+     ]),
+     'K-Means': np.array([
+         [106, 109, 62, 41, 139, 81, 33, 185, 211, 33],
+         [97, 184, 92, 141, 96, 159, 106, 25, 42, 58],
+         [54, 69, 252, 160, 42, 46, 120, 109, 84, 64],
+         [83, 120, 146, 154, 18, 60, 121, 87, 61, 150],
+         [39, 49, 245, 219, 19, 51, 117, 73, 20, 168],
+         [72, 185, 163, 103, 30, 35, 88, 132, 38, 154],
+         [86, 94, 211, 212, 11, 28, 206, 30, 39, 83],
+         [111, 88, 205, 131, 54, 135, 58, 57, 22, 139],
+         [31, 167, 46, 28, 328, 195, 33, 91, 39, 42],
+         [131, 81, 83, 123, 141, 331, 19, 18, 37, 36]
+     ]),
+     'KNN': np.array([
+         [565, 12, 107, 20, 52, 6, 25, 4, 205, 4],
+         [195, 244, 121, 61, 120, 28, 31, 4, 178, 18],
+         [150, 7, 463, 58, 201, 24, 48, 10, 39, 0],
+         [108, 10, 279, 243, 133, 92, 80, 17, 32, 6],
+         [100, 6, 282, 54, 443, 23, 38, 11, 43, 0],
+         [89, 4, 254, 178, 143, 208, 68, 10, 41, 5],
+         [49, 1, 317, 102, 260, 24, 227, 2, 18, 0],
+         [127, 16, 226, 76, 236, 40, 45, 192, 41, 1],
+         [181, 31, 56, 45, 52, 11, 8, 4, 607, 5],
+         [192, 81, 127, 80, 117, 27, 41, 14, 205, 116]
+     ]),
+     'Logistic Regression': np.array([
+         [424, 51, 58, 50, 25, 41, 17, 56, 202, 76],
+         [72, 426, 35, 48, 26, 36, 48, 47, 78, 184],
+         [96, 34, 266, 90, 123, 94, 130, 84, 51, 32],
+         [43, 52, 115, 235, 72, 194, 131, 56, 43, 59],
+         [55, 35, 137, 80, 280, 97, 152, 112, 25, 27],
+         [41, 41, 103, 202, 90, 300, 77, 64, 45, 37],
+         [14, 47, 97, 147, 94, 94, 417, 42, 23, 25],
+         [47, 47, 91, 70, 97, 91, 47, 397, 45, 68],
+         [146, 85, 31, 37, 17, 41, 10, 18, 513, 102],
+         [78, 180, 26, 36, 30, 29, 45, 59, 98, 419]
+     ]),
+     'Naive Bayes': np.array([
+         [494, 20, 39, 10, 84, 34, 50, 9, 200, 60],
+         [141, 166, 24, 31, 66, 72, 192, 19, 121, 168],
+         [225, 24, 83, 15, 292, 48, 209, 21, 54, 29],
+         [163, 36, 54, 76, 151, 129, 262, 26, 34, 69],
+         [86, 8, 57, 26, 417, 38, 265, 22, 50, 31],
+         [156, 17, 55, 51, 167, 264, 159, 36, 57, 38],
+         [106, 2, 60, 18, 228, 46, 467, 15, 19, 39],
+         [134, 24, 36, 41, 228, 94, 102, 131, 72, 138],
+         [168, 41, 18, 17, 56, 83, 39, 8, 471, 99],
+         [144, 67, 17, 20, 48, 32, 101, 23, 141, 407]
+     ]),
+     'Random Forest': np.array([
+         [559, 36, 62, 21, 28, 20, 20, 30, 165, 59],
+         [29, 544, 10, 38, 22, 34, 45, 35, 67, 176],
+         [100, 36, 337, 79, 137, 69, 123, 54, 34, 31],
+         [53, 46, 76, 282, 88, 173, 132, 62, 23, 65],
+         [54, 21, 139, 60, 381, 47, 158, 91, 27, 22],
+         [32, 27, 89, 167, 74, 400, 77, 74, 27, 33],
+         [11, 33, 87, 78, 99, 53, 567, 31, 6, 35],
+         [47, 44, 53, 58, 104, 83, 38, 451, 24, 98],
+         [90, 83, 19, 33, 15, 39, 7, 22, 615, 77],
+         [50, 176, 18, 38, 20, 24, 27, 43, 80, 524]
+     ])
  }

  # Create tabs for different visualizations
+ tab1, tab2, tab3 = st.tabs(["Overall Performance", "Confusion Matrices", "Individual Metrics"])

  with tab1:
+     st.header("Overall Model Performance")
+
+     # Create bar plot for overall accuracy
+     fig, ax = plt.subplots(figsize=(12, 6))
+     models = list(results['Accuracy'].keys())
+     accuracies = list(results['Accuracy'].values())
+
+     colors = ['purple', 'navy', 'teal', 'green', 'lime', 'yellow']
+     bars = ax.bar(models, accuracies, color=colors)

+     # Customize the plot
+     ax.set_title('Overall Model Performance Comparison')
+     ax.set_xlabel('Models')
+     ax.set_ylabel('Accuracy')
+     plt.xticks(rotation=45)

+     # Add value labels on top of bars
+     for bar in bars:
+         height = bar.get_height()
+         ax.text(bar.get_x() + bar.get_width()/2., height,
+                 f'{height:.3f}',
+                 ha='center', va='bottom')

+     plt.tight_layout()
+     st.pyplot(fig)

  with tab2:
      st.header("Confusion Matrices")

+     # Model selection for confusion matrix
      selected_model = st.selectbox("Select Model", list(confusion_matrices.keys()))

+     # Create confusion matrix plot using seaborn
+     fig, ax = plt.subplots(figsize=(10, 8))
+     sns.heatmap(confusion_matrices[selected_model],
+                 annot=True,
+                 fmt='d',
+                 cmap='Blues',
+                 ax=ax)

+     plt.title(f'Confusion Matrix - {selected_model}')
+     plt.xlabel('Predicted')
+     plt.ylabel('True')
+     plt.tight_layout()

+     st.pyplot(fig)
+
+ with tab3:
+     st.header("Individual Metrics")

+     col1, col2, col3 = st.columns(3)
+     metrics = ['Precision', 'Recall', 'F1']

+     for metric, col in zip(metrics, [col1, col2, col3]):
+         with col:
+             fig, ax = plt.subplots(figsize=(8, 6))
+             models = list(results[metric].keys())
+             values = list(results[metric].values())
+
+             ax.bar(models, values)
+             ax.set_title(f'Comparison of {metric}')
+             plt.xticks(rotation=45, ha='right')
+             ax.set_ylabel(metric)
+
+             plt.tight_layout()
+             st.pyplot(fig)
+
+ # Add metrics table to sidebar
+ st.sidebar.header("Metrics Table")
+ df_metrics = pd.DataFrame(results)
+ st.sidebar.dataframe(df_metrics.style.format("{:.3f}"))

+ # Add download button
  @st.cache_data
  def convert_df_to_csv():
+     return df_metrics.to_csv()

  csv = convert_df_to_csv()
  st.sidebar.download_button(
      label="Download metrics as CSV",

      mime='text/csv',
  )

  # Footer
  st.markdown("---")
+ st.markdown("Dashboard created for ML Models Comparison on CIFAR-10 dataset")