Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -58,18 +58,15 @@ def infer(model, replay_file,
|
|
58 |
ot_timer = ot_time_remaining[0] - ot_time_remaining
|
59 |
timer[is_ot] = -ot_timer # Negate to indicate overtime
|
60 |
|
61 |
-
goal_diff =
|
62 |
goal_diff_diff = goal_diff.diff(prepend=torch.Tensor([0]))
|
63 |
|
64 |
-
# if nullify_goal_difference:
|
65 |
-
# features.scoreboard[..., 2] = 0
|
66 |
-
|
67 |
bs = 512
|
68 |
predictions = []
|
69 |
it = trange(len(serialized_states), desc="Running model")
|
70 |
for i in range(0, len(serialized_states), bs):
|
71 |
batch = (serialized_states[i:i + bs].clone().to(DEVICE),
|
72 |
-
|
73 |
if nullify_goal_difference:
|
74 |
batch[:, SB_BLUE_SCORE] = 0
|
75 |
batch[:, SB_ORANGE_SCORE] = 0
|
|
|
58 |
ot_timer = ot_time_remaining[0] - ot_time_remaining
|
59 |
timer[is_ot] = -ot_timer # Negate to indicate overtime
|
60 |
|
61 |
+
goal_diff = serialized_scoreboards[:, SB_BLUE_SCORE] - serialized_scoreboards[:, SB_ORANGE_SCORE]
|
62 |
goal_diff_diff = goal_diff.diff(prepend=torch.Tensor([0]))
|
63 |
|
|
|
|
|
|
|
64 |
bs = 512
|
65 |
predictions = []
|
66 |
it = trange(len(serialized_states), desc="Running model")
|
67 |
for i in range(0, len(serialized_states), bs):
|
68 |
batch = (serialized_states[i:i + bs].clone().to(DEVICE),
|
69 |
+
serialized_scoreboards[i:i + bs].clone().to(DEVICE))
|
70 |
if nullify_goal_difference:
|
71 |
batch[:, SB_BLUE_SCORE] = 0
|
72 |
batch[:, SB_ORANGE_SCORE] = 0
|