kovacsvi commited on
Commit
75bb7a1
·
1 Parent(s): 82b9aeb

refactored normalize_probs

Browse files
Files changed (1) hide show
  1. interfaces/cap_minor_media.py +12 -8
interfaces/cap_minor_media.py CHANGED
@@ -39,16 +39,20 @@ for code in CAP_MIN_CODES:
39
  major_to_minor_map[major_id].append(code)
40
  major_to_minor_map = dict(major_to_minor_map)
41
 
42
- def normalize_probs(probs: dict, n=None):
43
- probs_ = list(probs.values())
44
- if n:
45
- if len(probs_) > n:
46
- probs_.sort(reverse=True)
47
- probs_ = probs_[:5]
48
- values = np.array(probs_)
 
 
49
  exp_values = np.exp(values)
50
  sum_exp = np.sum(exp_values)
51
- return {k: float(v) for k, v in zip(probs.keys(), exp_values / sum_exp)}
 
 
52
 
53
  def check_huggingface_path(checkpoint_path: str):
54
  try:
 
39
  major_to_minor_map[major_id].append(code)
40
  major_to_minor_map = dict(major_to_minor_map)
41
 
42
+ def normalize_probs(probs: dict, n: int = None) -> dict:
43
+ if n is not None and len(probs) > n:
44
+ # Sort items by value in descending order and keep top-n
45
+ top_items = sorted(probs.items(), key=lambda item: item[1], reverse=True)[:n]
46
+ else:
47
+ top_items = list(probs.items())
48
+
49
+ keys, values = zip(*top_items)
50
+ values = np.array(values)
51
  exp_values = np.exp(values)
52
  sum_exp = np.sum(exp_values)
53
+ normalized = exp_values / sum_exp
54
+
55
+ return dict(zip(keys, map(float, normalized)))
56
 
57
  def check_huggingface_path(checkpoint_path: str):
58
  try: