Rolv-Arild commited on
Commit
7af02d9
·
verified ·
1 Parent(s): d00e0ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -117,15 +117,15 @@ def infer(model, replay_file,
117
  remove_ties_mask = is_ot if not ignore_ties else torch.ones(len(preds), dtype=torch.bool)
118
  remove_ties_mask = remove_ties_mask.numpy()
119
  if remove_ties_mask.any():
120
- tie_probs = preds[remove_ties_mask, "Tie"]
121
  q = (1 - tie_probs)
122
  for c in preds.columns:
123
  if c.startswith("Blue") or c.startswith("Orange"):
124
- preds[remove_ties_mask, c] /= q
125
  if ignore_ties:
126
  preds = preds.drop("Tie", axis=1)
127
  else:
128
- preds[remove_ties_mask, "Tie"] = 0.0
129
 
130
  return preds
131
 
 
117
  remove_ties_mask = is_ot if not ignore_ties else torch.ones(len(preds), dtype=torch.bool)
118
  remove_ties_mask = remove_ties_mask.numpy()
119
  if remove_ties_mask.any():
120
+ tie_probs = preds.loc[remove_ties_mask, "Tie"]
121
  q = (1 - tie_probs)
122
  for c in preds.columns:
123
  if c.startswith("Blue") or c.startswith("Orange"):
124
+ preds.loc[remove_ties_mask, c] /= q
125
  if ignore_ties:
126
  preds = preds.drop("Tie", axis=1)
127
  else:
128
+ preds.loc[remove_ties_mask, "Tie"] = 0.0
129
 
130
  return preds
131