kovacsvi committed on
Commit
a1eb2dd
·
1 Parent(s): b2b1d1f

normalize filtered probs (minor)

Browse files
Files changed (1) hide show
  1. interfaces/cap_minor_media.py +14 -0
interfaces/cap_minor_media.py CHANGED
@@ -39,6 +39,18 @@ for code in CAP_MIN_CODES:
39
  major_to_minor_map[major_id].append(code)
40
  major_to_minor_map = dict(major_to_minor_map)
41
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  def check_huggingface_path(checkpoint_path: str):
44
  try:
@@ -102,6 +114,8 @@ def predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=None):
102
  # Restrict to valid minor codes
103
  valid_indices = [minor_id_to_index[mid] for mid in valid_minor_ids if mid in minor_id_to_index]
104
  filtered_probs = {minor_index_to_id[i]: float(minor_probs[0][i]) for i in valid_indices}
 
 
105
  output_pred = {
106
  f"[{k}] {CAP_MIN_LABEL_NAMES[k]}": v
107
  for k, v in sorted(filtered_probs.items(), key=lambda item: item[1], reverse=True)
 
39
  major_to_minor_map[major_id].append(code)
40
  major_to_minor_map = dict(major_to_minor_map)
41
 
42
+ def normalize_probs(probs: dict):
43
+ if not probs:
44
+ return {}
45
+
46
+ min_val = min(probs.values())
47
+ max_val = max(probs.values())
48
+ range_val = max_val - min_val
49
+
50
+ if range_val == 0:
51
+ return {k: 1.0 for k in probs}
52
+
53
+ return {k: (v - min_val) / range_val for k, v in probs.items()}
54
 
55
  def check_huggingface_path(checkpoint_path: str):
56
  try:
 
114
  # Restrict to valid minor codes
115
  valid_indices = [minor_id_to_index[mid] for mid in valid_minor_ids if mid in minor_id_to_index]
116
  filtered_probs = {minor_index_to_id[i]: float(minor_probs[0][i]) for i in valid_indices}
117
+ filtered_probs = normalize_probs(filtered_probs)
118
+
119
  output_pred = {
120
  f"[{k}] {CAP_MIN_LABEL_NAMES[k]}": v
121
  for k, v in sorted(filtered_probs.items(), key=lambda item: item[1], reverse=True)