Corey Morris
committed on
Commit
·
c671de9
1
Parent(s):
ed019c6
Added MMLU overall average column. Added a few more charts comparing moral reasoning results and comparing the MMLU overall average to other data.
Browse files
app.py
CHANGED
|
@@ -25,15 +25,16 @@ class MultiURLData:
|
|
| 25 |
data = json.load(f)
|
| 26 |
df = pd.DataFrame(data['results']).T
|
| 27 |
|
| 28 |
-
df = df.rename(columns={'acc': model_name})
|
| 29 |
-
|
| 30 |
-
df.index = df.index.str.replace('hendrycksTest-', '', regex=True)
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
df.index = df.index.str.replace('harness\|', '', regex=True)
|
| 33 |
-
|
| 34 |
# remove |5 from the index
|
| 35 |
df.index = df.index.str.replace('\|5', '', regex=True)
|
| 36 |
|
|
|
|
| 37 |
dataframes.append(df[[model_name]])
|
| 38 |
|
| 39 |
data = pd.concat(dataframes, axis=1)
|
|
@@ -44,7 +45,18 @@ class MultiURLData:
|
|
| 44 |
cols = cols[-1:] + cols[:-1]
|
| 45 |
data = data[cols]
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
return data
|
|
|
|
|
|
|
| 48 |
|
| 49 |
def get_data(self, selected_models):
|
| 50 |
filtered_data = self.data[self.data['Model Name'].isin(selected_models)]
|
|
@@ -75,6 +87,7 @@ selected_models = st.multiselect(
|
|
| 75 |
|
| 76 |
|
| 77 |
# Get the filtered data and display it in a table
|
|
|
|
| 78 |
filtered_data = data_provider.get_data(selected_models)
|
| 79 |
st.dataframe(filtered_data)
|
| 80 |
|
|
@@ -111,11 +124,34 @@ def create_plot(df, model_column, arc_column, moral_column, models=None):
|
|
| 111 |
# models_to_plot = ['Model1', 'Model2', 'Model3']
|
| 112 |
# fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'moral_scenarios|5', models=models_to_plot)
|
| 113 |
|
| 114 |
-
|
| 115 |
-
st.plotly_chart(fig)
|
| 116 |
|
| 117 |
fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'hellaswag|10')
|
| 118 |
st.plotly_chart(fig)
|
| 119 |
|
| 120 |
-
fig = create_plot(filtered_data, 'Model Name', '
|
| 121 |
st.plotly_chart(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
data = json.load(f)
|
| 26 |
df = pd.DataFrame(data['results']).T
|
| 27 |
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
# data cleanup
|
| 30 |
+
df = df.rename(columns={'acc': model_name})
|
| 31 |
+
# Replace 'hendrycksTest-' with a more descriptive column name
|
| 32 |
+
df.index = df.index.str.replace('hendrycksTest-', 'MMLU_', regex=True)
|
| 33 |
df.index = df.index.str.replace('harness\|', '', regex=True)
|
|
|
|
| 34 |
# remove |5 from the index
|
| 35 |
df.index = df.index.str.replace('\|5', '', regex=True)
|
| 36 |
|
| 37 |
+
|
| 38 |
dataframes.append(df[[model_name]])
|
| 39 |
|
| 40 |
data = pd.concat(dataframes, axis=1)
|
|
|
|
| 45 |
cols = cols[-1:] + cols[:-1]
|
| 46 |
data = data[cols]
|
| 47 |
|
| 48 |
+
# create a new column that averages the results from each of the columns whose name contains MMLU
|
| 49 |
+
data['MMLU_average'] = data.filter(regex='MMLU').mean(axis=1)
|
| 50 |
+
|
| 51 |
+
# move the MMLU_average column to the second column in the dataframe
|
| 52 |
+
cols = data.columns.tolist()
|
| 53 |
+
cols = cols[:1] + cols[-1:] + cols[1:-1]
|
| 54 |
+
data = data[cols]
|
| 55 |
+
data
|
| 56 |
+
|
| 57 |
return data
|
| 58 |
+
|
| 59 |
+
|
| 60 |
|
| 61 |
def get_data(self, selected_models):
|
| 62 |
filtered_data = self.data[self.data['Model Name'].isin(selected_models)]
|
|
|
|
| 87 |
|
| 88 |
|
| 89 |
# Get the filtered data and display it in a table
|
| 90 |
+
st.header('Sortable table')
|
| 91 |
filtered_data = data_provider.get_data(selected_models)
|
| 92 |
st.dataframe(filtered_data)
|
| 93 |
|
|
|
|
| 124 |
# models_to_plot = ['Model1', 'Model2', 'Model3']
|
| 125 |
# fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'moral_scenarios|5', models=models_to_plot)
|
| 126 |
|
| 127 |
+
st.header('Overall benchmark comparison')
|
|
|
|
| 128 |
|
| 129 |
fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'hellaswag|10')
|
| 130 |
st.plotly_chart(fig)
|
| 131 |
|
| 132 |
+
fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'MMLU_average')
|
| 133 |
st.plotly_chart(fig)
|
| 134 |
+
|
| 135 |
+
fig = create_plot(filtered_data, 'Model Name', 'hellaswag|10', 'MMLU_average')
|
| 136 |
+
st.plotly_chart(fig)
|
| 137 |
+
|
| 138 |
+
# Add heading to page to say Moral Scenarios
|
| 139 |
+
st.header('Moral Scenarios')
|
| 140 |
+
|
| 141 |
+
fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'MMLU_moral_scenarios')
|
| 142 |
+
st.plotly_chart(fig)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
fig = create_plot(filtered_data, 'Model Name', 'MMLU_moral_disputes', 'MMLU_moral_scenarios')
|
| 146 |
+
st.plotly_chart(fig)
|
| 147 |
+
|
| 148 |
+
fig = create_plot(filtered_data, 'Model Name', 'MMLU_average', 'MMLU_moral_scenarios')
|
| 149 |
+
st.plotly_chart(fig)
|
| 150 |
+
|
| 151 |
+
# create a histogram of moral scenarios
|
| 152 |
+
fig = px.histogram(filtered_data, x="MMLU_moral_scenarios", marginal="rug", hover_data=filtered_data.columns)
|
| 153 |
+
st.plotly_chart(fig)
|
| 154 |
+
|
| 155 |
+
# create a histogram of moral disputes
|
| 156 |
+
fig = px.histogram(filtered_data, x="MMLU_moral_disputes", marginal="rug", hover_data=filtered_data.columns)
|
| 157 |
+
st.plotly_chart(fig)
|