kovacsvi committed on
Commit
82b9aeb
·
1 Parent(s): 952b4ed

Removed top-n logic from callers (normalize_probs `n` is now optional; gr.Label no longer sets num_top_classes)

Browse files
Files changed (1) hide show
  1. interfaces/cap_minor_media.py +9 -8
interfaces/cap_minor_media.py CHANGED
@@ -25,7 +25,7 @@ domains = {
25
  "media": "media"
26
  }
27
 
28
-
29
  CAP_MEDIA_CODES = list(CAP_MEDIA_NUM_DICT.values())
30
  CAP_MIN_CODES = list(CAP_MIN_NUM_DICT.values())
31
 
@@ -39,11 +39,12 @@ for code in CAP_MIN_CODES:
39
  major_to_minor_map[major_id].append(code)
40
  major_to_minor_map = dict(major_to_minor_map)
41
 
42
- def normalize_probs(probs: dict, n: int):
43
  probs_ = list(probs.values())
44
- if len(probs_) > n:
45
- probs_.sort(reverse=True)
46
- probs_ = probs_[:5]
 
47
  values = np.array(probs_)
48
  exp_values = np.exp(values)
49
  sum_exp = np.sum(exp_values)
@@ -99,7 +100,7 @@ def predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=None):
99
  i: float(major_probs_np[i])
100
  for i in np.argsort(major_probs_np)[::-1]
101
  }
102
- filtered_probs = normalize_probs(filtered_probs, n=5)
103
 
104
  output_pred = {
105
  f"[{major_index_to_id[k]}] {CAP_MEDIA_LABEL_NAMES[major_index_to_id[k]]}": v
@@ -117,7 +118,7 @@ def predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=None):
117
  # Restrict to valid minor codes
118
  valid_indices = [minor_id_to_index[mid] for mid in valid_minor_ids if mid in minor_id_to_index]
119
  filtered_probs = {minor_index_to_id[i]: float(minor_probs[0][i]) for i in valid_indices}
120
- filtered_probs = normalize_probs(filtered_probs, n=5)
121
 
122
  output_pred = {
123
  f"[{top_major_id}] {CAP_MEDIA_LABEL_NAMES[top_major_id]} [{k}] {CAP_MIN_LABEL_NAMES[k]}": v
@@ -161,4 +162,4 @@ demo = gr.Interface(
161
  gr.Textbox(lines=6, label="Input"),
162
  gr.Dropdown(languages, label="Language"),
163
  gr.Dropdown(domains.keys(), label="Domain")],
164
- outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()])
 
25
  "media": "media"
26
  }
27
 
28
+ NUM_TOP_CLASSES = 5
29
  CAP_MEDIA_CODES = list(CAP_MEDIA_NUM_DICT.values())
30
  CAP_MIN_CODES = list(CAP_MIN_NUM_DICT.values())
31
 
 
39
  major_to_minor_map[major_id].append(code)
40
  major_to_minor_map = dict(major_to_minor_map)
41
 
42
+ def normalize_probs(probs: dict, n=None):
43
  probs_ = list(probs.values())
44
+ if n:
45
+ if len(probs_) > n:
46
+ probs_.sort(reverse=True)
47
+ probs_ = probs_[:5]
48
  values = np.array(probs_)
49
  exp_values = np.exp(values)
50
  sum_exp = np.sum(exp_values)
 
100
  i: float(major_probs_np[i])
101
  for i in np.argsort(major_probs_np)[::-1]
102
  }
103
+ filtered_probs = normalize_probs(filtered_probs)
104
 
105
  output_pred = {
106
  f"[{major_index_to_id[k]}] {CAP_MEDIA_LABEL_NAMES[major_index_to_id[k]]}": v
 
118
  # Restrict to valid minor codes
119
  valid_indices = [minor_id_to_index[mid] for mid in valid_minor_ids if mid in minor_id_to_index]
120
  filtered_probs = {minor_index_to_id[i]: float(minor_probs[0][i]) for i in valid_indices}
121
+ filtered_probs = normalize_probs(filtered_probs)
122
 
123
  output_pred = {
124
  f"[{top_major_id}] {CAP_MEDIA_LABEL_NAMES[top_major_id]} [{k}] {CAP_MIN_LABEL_NAMES[k]}": v
 
162
  gr.Textbox(lines=6, label="Input"),
163
  gr.Dropdown(languages, label="Language"),
164
  gr.Dropdown(domains.keys(), label="Domain")],
165
+ outputs=[gr.Label(label="Output"), gr.Markdown()])