Rolv-Arild commited on
Commit
eead613
·
verified ·
1 Parent(s): 7708de2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -5
app.py CHANGED
@@ -58,18 +58,15 @@ def infer(model, replay_file,
58
  ot_timer = ot_time_remaining[0] - ot_time_remaining
59
  timer[is_ot] = -ot_timer # Negate to indicate overtime
60
 
61
- goal_diff = scoreboard[:, SB_BLUE_SCORE] - scoreboard[:, SB_ORANGE_SCORE]
62
  goal_diff_diff = goal_diff.diff(prepend=torch.Tensor([0]))
63
 
64
- # if nullify_goal_difference:
65
- # features.scoreboard[..., 2] = 0
66
-
67
  bs = 512
68
  predictions = []
69
  it = trange(len(serialized_states), desc="Running model")
70
  for i in range(0, len(serialized_states), bs):
71
  batch = (serialized_states[i:i + bs].clone().to(DEVICE),
72
- scoreboard[i:i + bs].clone().to(DEVICE))
73
  if nullify_goal_difference:
74
  batch[:, SB_BLUE_SCORE] = 0
75
  batch[:, SB_ORANGE_SCORE] = 0
 
58
  ot_timer = ot_time_remaining[0] - ot_time_remaining
59
  timer[is_ot] = -ot_timer # Negate to indicate overtime
60
 
61
+ goal_diff = serialized_scoreboards[:, SB_BLUE_SCORE] - serialized_scoreboards[:, SB_ORANGE_SCORE]
62
  goal_diff_diff = goal_diff.diff(prepend=torch.Tensor([0]))
63
 
 
 
 
64
  bs = 512
65
  predictions = []
66
  it = trange(len(serialized_states), desc="Running model")
67
  for i in range(0, len(serialized_states), bs):
68
  batch = (serialized_states[i:i + bs].clone().to(DEVICE),
69
+ serialized_scoreboards[i:i + bs].clone().to(DEVICE))
70
  if nullify_goal_difference:
71
  batch[:, SB_BLUE_SCORE] = 0
72
  batch[:, SB_ORANGE_SCORE] = 0