Spaces:
Build error
Build error
Update data
Browse files- .gitattributes +3 -0
- README.md +3 -3
- app.py +66 -0
- data/full_pred_test_w_plurals_w_iou.json +3 -0
- data/full_pred_val_w_plurals_w_iou.json +3 -0
- data/saiapr_tc-12.zip +3 -0
- requirements.txt +6 -0
- utils.py +167 -0
.gitattributes
CHANGED
|
@@ -29,3 +29,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 29 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 30 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 31 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 30 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 31 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
data/saiapr_tc-12.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
data/full_pred_val_w_plurals_w_iou.json filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
data/full_pred_test_w_plurals_w_iou.json filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
---
|
| 2 |
title: Categories Error Analysis
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 3.2
|
| 8 |
app_file: app.py
|
|
|
|
| 1 |
---
|
| 2 |
title: Categories Error Analysis
|
| 3 |
+
emoji: π± π
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: gray
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 3.2
|
| 8 |
app_file: app.py
|
app.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gradio front-end for browsing model predictions by linguistic category."""

# NOTE: the previous `from turtle import width` import was unused; it has been
# dropped because importing `turtle` pulls in `tkinter`, which is frequently
# missing on headless servers and would make the app fail at start-up.
import gradio as gr
from utils import SampleClass

# Loads the prediction tables and opens the image archive once, at start-up.
sample = SampleClass()

# --- Interface ---

demo = gr.Blocks(
    title="Categories_error_analysis.ipynb",
    # The stray ';' that used to follow the first rule made the '#md' rule
    # invalid CSS; the rules are now simply separated by whitespace.
    css=".container { max-width: 98%; margin: auto;} #md {width: 30%} #large {width: 70%}",
)

with demo:
    # NOTE(review): the emoji in this heading look mis-encoded in the stored
    # source; the literal is reproduced as found - confirm the intended glyphs.
    gr.Markdown("<h2><center> π± Categories Error Analysis π</center></h2>")
    with gr.Row():
        # Left column: filters, per-user navigation and progress.
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    category = gr.Dropdown(
                        label="Category",
                        value="relational",
                        choices=["intrinsic", "spatial", "ordinal", "relational", "plural"])
                with gr.Column():
                    predictions = gr.Dropdown(
                        label='Predictions',
                        value='fail',
                        choices=["fail", "correct"])
            with gr.Row():
                with gr.Column():
                    model = gr.Dropdown(
                        label='Model',
                        value='baseline',
                        choices=["baseline", "extended"])
                with gr.Column():
                    split = gr.Dropdown(
                        label='Split',
                        value='val',
                        choices=["test", "val"])
            with gr.Row():
                with gr.Column():
                    username = gr.Dropdown(
                        label="UserName",
                        value="luciana",
                        choices=["luciana", 'mauri', 'jorge', 'nano'])
                with gr.Column():
                    # Holds the sample index that the next click will display;
                    # it is overwritten by the callback after every click.
                    next_idx_sample = gr.Number(
                        label='Next Idx Sample',
                        value=0)
            with gr.Row():
                progress = gr.Label(label='Progress', num_top_classes=10)
            with gr.Row():
                btn_next = gr.Button("Get Next Sample")

        # Right column: textual details and the annotated image.
        with gr.Column():
            with gr.Row():
                info = gr.Text(label="Sample Info")
            with gr.Row():
                img = gr.Image(label="Sample", type="numpy")

    # The input order must match the positional parameters of
    # SampleClass.explorateSamples; the output order must match its return tuple.
    btn_next.click(
        fn=sample.explorateSamples,
        inputs=[username, predictions, category, model, split, next_idx_sample],
        outputs=[next_idx_sample, progress, img, info])

# demo.queue(concurrency_count=10)
demo.launch(debug=False)
|
data/full_pred_test_w_plurals_w_iou.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:32d8075d6de0d60702cdfd884b43fe23140c7a10e28ddc865a03bc6cb84669fb
|
| 3 |
+
size 28012574
|
data/full_pred_val_w_plurals_w_iou.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e285882c1ad5c8ad5d9485bfce5f043b909564c5b68a61cf925c2c548b5a1d24
|
| 3 |
+
size 2581176
|
data/saiapr_tc-12.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cbeb079b66dd88ba58d15c5c421e983a65347527418228ad55022b7535983b35
|
| 3 |
+
size 2751748544
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
matplotlib
|
| 2 |
+
numpy
|
| 3 |
+
pandas
|
| 4 |
+
pillow
|
| 5 |
+
zipfile36
|
| 6 |
+
zipfile38
|
utils.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib as mpl
|
| 2 |
+
mpl.use('Agg')
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import matplotlib.patches as patches
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from zipfile import ZipFile
|
| 10 |
+
import gradio as gr
|
| 11 |
+
|
| 12 |
+
class SampleClass:
    """Serve annotated SAIAPR TC-12 samples for manual error analysis.

    The prediction tables for the ``test`` and ``val`` splits are read once
    from JSON, and the image archive is kept open for the lifetime of the
    instance. ``explorateSamples`` is the public entry point used as the
    Gradio callback.

    NOTE(review): ``filtered_df`` is instance state rewritten on every call,
    so concurrent callers sharing one instance may interfere - confirm
    whether that matters for the deployment.
    """

    # Annotators between whom the filtered samples are split, in split order.
    _USERNAMES = ('luciana', 'mauri', 'jorge', 'nano')
    _MODELS = ('baseline', 'extended')
    # Category columns, in the order they are listed in the sample info.
    _CATEGORIES = ('intrinsic', 'spatial', 'ordinal', 'relational', 'plural')

    def __init__(self):
        self.test_df = pd.read_json("data/full_pred_test_w_plurals_w_iou.json")
        self.val_df = pd.read_json("data/full_pred_val_w_plurals_w_iou.json")
        self.zip_file = ZipFile("data/saiapr_tc-12.zip", 'r')
        # Set by explorateSamples; read by __getSample.
        self.filtered_df = None

    def __get(self, img_path):
        """Return the archive member ``img_path`` decoded as a numpy array."""
        # Both handles are closed as soon as the pixels have been copied out.
        with self.zip_file.open(img_path) as img_obj:
            with Image.open(img_obj) as pil_img:
                return np.array(pil_img)

    def __loadPredictions(self, split, model):
        """Return the table for ``split`` with ``model``'s columns made generic.

        The selected model's ``*_hit``/``*_pred``/``*_iou`` columns are renamed
        to ``hit``/``predictions``/``iou``; the other model's columns get the
        same names with an ``_other`` suffix.

        Raises:
            ValueError: if ``split`` or ``model`` is not a supported value.
        """
        frames = {'test': self.test_df, 'val': self.val_df}
        # Explicit checks instead of `assert`, which is stripped under -O.
        if split not in frames:
            raise ValueError(f"Unknown split {split!r}; expected 'test' or 'val'")
        if model not in self._MODELS:
            raise ValueError(f"Unknown model {model!r}; expected 'baseline' or 'extended'")

        other = 'extended' if model == 'baseline' else 'baseline'
        return frames[split].rename(columns={
            f'{model}_hit': 'hit',
            f'{model}_pred': 'predictions',
            f'{model}_iou': 'iou',
            f'{other}_hit': 'hit_other',
            f'{other}_pred': 'predictions_other',
            f'{other}_iou': 'iou_other',
        })

    def __getSample(self, id):
        """Build the info dict and rendered image for sample ``id``.

        Looks ``id`` up in ``self.filtered_df`` (``sample_idx`` column) and
        draws three boxes over the image: ground truth in blue, the selected
        model's prediction in light green and the other model's in red.

        Returns:
            tuple: ``(info, img)`` where ``info`` is an ordered dict whose
            first entry is the referring expression and ``img`` is an
            ``(height, width, 3)`` uint8 array of the rendered figure.
        """
        sample = self.filtered_df[self.filtered_df.sample_idx == id]

        def first(column):
            # The filter above is expected to match a single row.
            return sample[column].values[0]

        sent = first('sent')
        pos_tags = first('pos_tags')
        plural_tks = first('plural_tks')
        active_categories = [name for name in self._CATEGORIES if first(name)]

        # Keep only the part of the stored path that lives inside the archive.
        img_path = "saiapr_tc-12" + first('file_path').split("saiapr_tc-12")[1]

        hit = first('hit')
        hit_o = first('hit_other')
        iou = first('iou')
        iou_o = first('iou_other')
        prediction = {0: ' FAIL ', 1: ' CORRECT '}

        # Each box is unpacked as (x1, y1, x2, y2) corner coordinates.
        boxes = (
            (first('bbox'), 'blue'),
            (first('predictions'), 'lightgreen'),
            (first('predictions_other'), 'red'),
        )
        rectangles = [
            patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                              linewidth=2, edgecolor=color, facecolor='none')
            for (x1, y1, x2, y2), color in boxes
        ]

        info = {'Expresion': sent,
                'Idx Sample': str(id),
                'IoU': str(round(iou, 2)) + "(" + prediction[hit] + ")",
                'IoU other': str(round(iou_o, 2)) + "(" + prediction[hit_o] + ")",
                'Pos Tags': str(pos_tags),
                'PluralTks ': plural_tks,
                'Categories': ",".join(active_categories)
                }

        fig, ax = plt.subplots(1)
        try:
            ax.imshow(self.__get(img_path), interpolation='bilinear')
            for rectangle in rectangles:
                ax.add_patch(rectangle)
            ax.axis('off')
            ax.set_title(info['Expresion'], fontsize=12)
            fig.tight_layout()

            # Render first, then read the pixels; the alpha channel is dropped
            # so the result stays an RGB array.
            fig.canvas.draw()
            img = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[..., :3])
        finally:
            # Always release the figure, even if loading or drawing failed.
            plt.close(fig)
        return info, img

    @staticmethod
    def _progress(position, total):
        """Return ``{"position/last": fraction}`` for a list of ``total`` items.

        With a single item the fraction is ``1.0`` rather than a division by
        zero.
        """
        last = total - 1
        fraction = position / last if last > 0 else 1.0
        return {f"{position}/{last}": fraction}

    def explorateSamples(self,
                         username,
                         predictions,
                         category,
                         model,
                         split,
                         next_idx_sample):
        """Return the sample to show now and the index to request next.

        The rows of ``split`` where ``category`` is set and ``model``'s hit
        flag matches ``predictions`` (``'fail'`` or ``'correct'``) are divided
        into four consecutive parts, one per user. The sample whose index is
        ``next_idx_sample`` is shown; if that index is not in the user's part,
        the first sample of the part is shown instead.

        Returns:
            tuple: ``(number_update, progress, img, info)`` in the order of
            the Gradio outputs. When no sample matches the filters the image
            is ``None`` and ``info`` explains why.
        """
        next_idx_sample = int(next_idx_sample)
        hit = {'fail': 0, 'correct': 1}
        df = self.__loadPredictions(split, model)
        self.filtered_df = df[(df[category] == 1) & (df.hit == hit[predictions])]

        all_idx_samples = self.filtered_df.sample_idx.to_list()
        parts = np.array_split(all_idx_samples, len(self._USERNAMES))
        user_ids = {name: list(part) for name, part in zip(self._USERNAMES, parts)}
        ids = user_ids[username]

        if not ids:
            # Nothing to display for this user/filter combination.
            return (gr.Number.update(value=next_idx_sample),
                    {"0/0": 0.0},
                    None,
                    "No samples match the selected filters.")

        try:
            id_ = ids.index(next_idx_sample)
        except ValueError:
            # Unknown index (e.g. the initial 0 or changed filters): restart.
            id_ = 0

        # Stay on the last sample once the end of the list is reached.
        next_idx_sample = ids[min(id_ + 1, len(ids) - 1)]
        progress = self._progress(id_, len(ids))
        info, img = self.__getSample(ids[id_])
        # The first entry (the expression) is already shown as the image title.
        info = "".join([str(k) + ":\t" + str(v) + "\n" for k, v in list(info.items())[1:]]).strip()

        return (gr.Number.update(value=next_idx_sample), progress, img, info)
|
| 167 |
+
|