Spaces:
Running
Running
kovacsvi
committed on
Commit
·
82b9aeb
1
Parent(s):
952b4ed
removed top n logic
Browse files
interfaces/cap_minor_media.py
CHANGED
@@ -25,7 +25,7 @@ domains = {
|
|
25 |
"media": "media"
|
26 |
}
|
27 |
|
28 |
-
|
29 |
CAP_MEDIA_CODES = list(CAP_MEDIA_NUM_DICT.values())
|
30 |
CAP_MIN_CODES = list(CAP_MIN_NUM_DICT.values())
|
31 |
|
@@ -39,11 +39,12 @@ for code in CAP_MIN_CODES:
|
|
39 |
major_to_minor_map[major_id].append(code)
|
40 |
major_to_minor_map = dict(major_to_minor_map)
|
41 |
|
42 |
-
def normalize_probs(probs: dict, n
|
43 |
probs_ = list(probs.values())
|
44 |
-
if
|
45 |
-
probs_
|
46 |
-
|
|
|
47 |
values = np.array(probs_)
|
48 |
exp_values = np.exp(values)
|
49 |
sum_exp = np.sum(exp_values)
|
@@ -99,7 +100,7 @@ def predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=None):
|
|
99 |
i: float(major_probs_np[i])
|
100 |
for i in np.argsort(major_probs_np)[::-1]
|
101 |
}
|
102 |
-
filtered_probs = normalize_probs(filtered_probs
|
103 |
|
104 |
output_pred = {
|
105 |
f"[{major_index_to_id[k]}] {CAP_MEDIA_LABEL_NAMES[major_index_to_id[k]]}": v
|
@@ -117,7 +118,7 @@ def predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=None):
|
|
117 |
# Restrict to valid minor codes
|
118 |
valid_indices = [minor_id_to_index[mid] for mid in valid_minor_ids if mid in minor_id_to_index]
|
119 |
filtered_probs = {minor_index_to_id[i]: float(minor_probs[0][i]) for i in valid_indices}
|
120 |
-
filtered_probs = normalize_probs(filtered_probs
|
121 |
|
122 |
output_pred = {
|
123 |
f"[{top_major_id}] {CAP_MEDIA_LABEL_NAMES[top_major_id]} [{k}] {CAP_MIN_LABEL_NAMES[k]}": v
|
@@ -161,4 +162,4 @@ demo = gr.Interface(
|
|
161 |
gr.Textbox(lines=6, label="Input"),
|
162 |
gr.Dropdown(languages, label="Language"),
|
163 |
gr.Dropdown(domains.keys(), label="Domain")],
|
164 |
-
outputs=[gr.Label(
|
|
|
25 |
"media": "media"
|
26 |
}
|
27 |
|
28 |
+
NUM_TOP_CLASSES = 5
|
29 |
CAP_MEDIA_CODES = list(CAP_MEDIA_NUM_DICT.values())
|
30 |
CAP_MIN_CODES = list(CAP_MIN_NUM_DICT.values())
|
31 |
|
|
|
39 |
major_to_minor_map[major_id].append(code)
|
40 |
major_to_minor_map = dict(major_to_minor_map)
|
41 |
|
42 |
+
def normalize_probs(probs: dict, n=None):
|
43 |
probs_ = list(probs.values())
|
44 |
+
if n:
|
45 |
+
if len(probs_) > n:
|
46 |
+
probs_.sort(reverse=True)
|
47 |
+
probs_ = probs_[:5]
|
48 |
values = np.array(probs_)
|
49 |
exp_values = np.exp(values)
|
50 |
sum_exp = np.sum(exp_values)
|
|
|
100 |
i: float(major_probs_np[i])
|
101 |
for i in np.argsort(major_probs_np)[::-1]
|
102 |
}
|
103 |
+
filtered_probs = normalize_probs(filtered_probs)
|
104 |
|
105 |
output_pred = {
|
106 |
f"[{major_index_to_id[k]}] {CAP_MEDIA_LABEL_NAMES[major_index_to_id[k]]}": v
|
|
|
118 |
# Restrict to valid minor codes
|
119 |
valid_indices = [minor_id_to_index[mid] for mid in valid_minor_ids if mid in minor_id_to_index]
|
120 |
filtered_probs = {minor_index_to_id[i]: float(minor_probs[0][i]) for i in valid_indices}
|
121 |
+
filtered_probs = normalize_probs(filtered_probs)
|
122 |
|
123 |
output_pred = {
|
124 |
f"[{top_major_id}] {CAP_MEDIA_LABEL_NAMES[top_major_id]} [{k}] {CAP_MIN_LABEL_NAMES[k]}": v
|
|
|
162 |
gr.Textbox(lines=6, label="Input"),
|
163 |
gr.Dropdown(languages, label="Language"),
|
164 |
gr.Dropdown(domains.keys(), label="Domain")],
|
165 |
+
outputs=[gr.Label(label="Output"), gr.Markdown()])
|