Spaces:
Runtime error
Runtime error
interactive legend
Browse files
app.py
CHANGED
|
@@ -80,8 +80,8 @@ def data_comparison(df):
|
|
| 80 |
).interactive()
|
| 81 |
|
| 82 |
legend = alt.Chart(df).mark_point(size=100, filled=True).encode(
|
| 83 |
-
x=alt.X("label"),
|
| 84 |
-
y=alt.Y('cluster:N', axis=alt.Axis(orient='right'), title=''),
|
| 85 |
shape=alt.Shape('label:N', scale=alt.Scale(
|
| 86 |
range=['circle', 'diamond']), legend=None),
|
| 87 |
color=color,
|
|
@@ -247,6 +247,22 @@ if __name__ == "__main__":
|
|
| 247 |
data_df['slice'] = 'high-loss'
|
| 248 |
data_df['slice'] = data_df['slice'].where(data_df['loss'] > high_loss, 'low-loss')
|
| 249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
with rcol:
|
| 251 |
with st.spinner(text='loading...'):
|
| 252 |
st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
|
|
@@ -264,20 +280,6 @@ if __name__ == "__main__":
|
|
| 264 |
if run_kmeans == 'True':
|
| 265 |
with st.spinner(text='running kmeans...'):
|
| 266 |
merged = kmeans(data_df,num_clusters=num_clusters)
|
| 267 |
-
|
| 268 |
-
st.markdown('<h3>Error Slices</h3>',unsafe_allow_html=True)
|
| 269 |
-
with st.expander("How to read the table:"):
|
| 270 |
-
st.markdown("* *Error slice* refers to the subset of evaluation dataset the model performs poorly on.")
|
| 271 |
-
st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
|
| 272 |
-
st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
|
| 273 |
-
with st.spinner(text='loading error slice...'):
|
| 274 |
-
dataframe=read_file_to_df('./assets/data/'+dataset+ '_'+ model+'_error-slices.parquet')
|
| 275 |
-
#uncomment the next next line to run dynamically and not from file
|
| 276 |
-
# dataframe = merged[['content', 'label', 'pred', 'loss', 'cluster']].sort_values(
|
| 277 |
-
# by=['loss'], ascending=False)
|
| 278 |
-
# table_html = dataframe.to_html(
|
| 279 |
-
# columns=['content', 'label', 'pred', 'loss', 'cluster'], max_rows=50)
|
| 280 |
-
# table_html = table_html.replace("<th>", '<th align="left">') # left-align the headers
|
| 281 |
-
st.write(dataframe,width=900, height=300)
|
| 282 |
with st.spinner(text='loading visualization...'):
|
| 283 |
quant_panel(merged)
|
|
|
|
| 80 |
).interactive()
|
| 81 |
|
| 82 |
legend = alt.Chart(df).mark_point(size=100, filled=True).encode(
|
| 83 |
+
x=alt.X("label:N"),
|
| 84 |
+
y=alt.Y('cluster:N', axis=alt.Axis(orient='right'), sort='descending', title=''),
|
| 85 |
shape=alt.Shape('label:N', scale=alt.Scale(
|
| 86 |
range=['circle', 'diamond']), legend=None),
|
| 87 |
color=color,
|
|
|
|
| 247 |
data_df['slice'] = 'high-loss'
|
| 248 |
data_df['slice'] = data_df['slice'].where(data_df['loss'] > high_loss, 'low-loss')
|
| 249 |
|
| 250 |
+
with lcol:
|
| 251 |
+
st.markdown('<h3>Error Slices</h3>',unsafe_allow_html=True)
|
| 252 |
+
with st.expander("How to read the table:"):
|
| 253 |
+
st.markdown("* *Error slice* refers to the subset of evaluation dataset the model performs poorly on.")
|
| 254 |
+
st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
|
| 255 |
+
st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
|
| 256 |
+
with st.spinner(text='loading error slice...'):
|
| 257 |
+
dataframe=read_file_to_df('./assets/data/'+dataset+ '_'+ model+'_error-slices.parquet')
|
| 258 |
+
#uncomment the next next line to run dynamically and not from file
|
| 259 |
+
# dataframe = merged[['content', 'label', 'pred', 'loss', 'cluster']].sort_values(
|
| 260 |
+
# by=['loss'], ascending=False)
|
| 261 |
+
# table_html = dataframe.to_html(
|
| 262 |
+
# columns=['content', 'label', 'pred', 'loss', 'cluster'], max_rows=50)
|
| 263 |
+
# table_html = table_html.replace("<th>", '<th align="left">') # left-align the headers
|
| 264 |
+
st.write(dataframe,width=900, height=300)
|
| 265 |
+
|
| 266 |
with rcol:
|
| 267 |
with st.spinner(text='loading...'):
|
| 268 |
st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
|
|
|
|
| 280 |
if run_kmeans == 'True':
|
| 281 |
with st.spinner(text='running kmeans...'):
|
| 282 |
merged = kmeans(data_df,num_clusters=num_clusters)
|
| 283 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
with st.spinner(text='loading visualization...'):
|
| 285 |
quant_panel(merged)
|