Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ from tempfile import TemporaryDirectory
|
|
4 |
import numpy as np
|
5 |
import pandas as pd
|
6 |
import torch
|
7 |
-
from rlgym_tools.rocket_league.misc.serialize import
|
8 |
from rlgym_tools.rocket_league.replays.convert import replay_to_rlgym
|
9 |
from rlgym_tools.rocket_league.replays.parsed_replay import ParsedReplay
|
10 |
from tqdm import trange, tqdm
|
@@ -26,7 +26,6 @@ MODEL.eval()
|
|
26 |
@torch.inference_mode()
|
27 |
def infer(model, replay_file,
|
28 |
nullify_goal_difference=False,
|
29 |
-
predict_with_all_quadrants=True,
|
30 |
ignore_ties=False):
|
31 |
num_outputs = 123
|
32 |
swap_team_idx = torch.arange(num_outputs)
|
@@ -37,20 +36,24 @@ def infer(model, replay_file,
|
|
37 |
replay = ParsedReplay.load(replay_file)
|
38 |
it = tqdm(replay_to_rlgym(replay), desc="Loading replay", total=len(replay.game_df))
|
39 |
replay_frames = []
|
40 |
-
|
|
|
|
|
41 |
for replay_frame in it:
|
42 |
replay_frames.append(replay_frame)
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
47 |
it.close()
|
48 |
|
49 |
-
|
50 |
-
|
51 |
-
timer = scoreboard[:, SB_GAME_TIMER_SECONDS]
|
52 |
is_ot = timer > 450
|
53 |
-
ot_time_remaining =
|
54 |
if len(ot_time_remaining) > 0:
|
55 |
ot_timer = ot_time_remaining[0] - ot_time_remaining
|
56 |
timer[is_ot] = -ot_timer # Negate to indicate overtime
|
@@ -63,9 +66,10 @@ def infer(model, replay_file,
|
|
63 |
|
64 |
bs = 512
|
65 |
predictions = []
|
66 |
-
it = trange(len(
|
67 |
-
for i in range(0, len(
|
68 |
-
batch =
|
|
|
69 |
if nullify_goal_difference:
|
70 |
batch[:, SB_BLUE_SCORE] = 0
|
71 |
batch[:, SB_ORANGE_SCORE] = 0
|
|
|
4 |
import numpy as np
|
5 |
import pandas as pd
|
6 |
import torch
|
7 |
+
from rlgym_tools.rocket_league.misc.serialize import serialize_game_state, serialize_scoreboard, RF_SCOREBOARD, RF_EPISODE_SECONDS_REMAINING, SB_GAME_TIMER_SECONDS, SB_BLUE_SCORE, SB_ORANGE_SCORE
|
8 |
from rlgym_tools.rocket_league.replays.convert import replay_to_rlgym
|
9 |
from rlgym_tools.rocket_league.replays.parsed_replay import ParsedReplay
|
10 |
from tqdm import trange, tqdm
|
|
|
26 |
@torch.inference_mode()
|
27 |
def infer(model, replay_file,
|
28 |
nullify_goal_difference=False,
|
|
|
29 |
ignore_ties=False):
|
30 |
num_outputs = 123
|
31 |
swap_team_idx = torch.arange(num_outputs)
|
|
|
36 |
replay = ParsedReplay.load(replay_file)
|
37 |
it = tqdm(replay_to_rlgym(replay), desc="Loading replay", total=len(replay.game_df))
|
38 |
replay_frames = []
|
39 |
+
serialized_states = []
|
40 |
+
serialized_scoreboards = []
|
41 |
+
seconds_remaining = []
|
42 |
for replay_frame in it:
|
43 |
replay_frames.append(replay_frame)
|
44 |
+
sstate = serialize_game_state(replay_frame.state)
|
45 |
+
sscoreboard = serialize_scoreboard(replay_frame.scoreboard)
|
46 |
+
serialized_states.append(sstate)
|
47 |
+
serialized_scoreboards.append(sscoreboard)
|
48 |
+
seconds_remaining.append(replay_frame.episode_seconds_remaining)
|
49 |
+
serialized_states = torch.from_numpy(np.stack(serialized_states))
|
50 |
+
serialized_scoreboards = torch.from_numpy(np.stack(serialized_scoreboards))
|
51 |
+
seconds_remaining = torch.tensor(seconds_remaining)
|
52 |
it.close()
|
53 |
|
54 |
+
timer = serialized_scoreboards[:, SB_GAME_TIMER_SECONDS]
|
|
|
|
|
55 |
is_ot = timer > 450
|
56 |
+
ot_time_remaining = seconds_remaining[is_ot]
|
57 |
if len(ot_time_remaining) > 0:
|
58 |
ot_timer = ot_time_remaining[0] - ot_time_remaining
|
59 |
timer[is_ot] = -ot_timer # Negate to indicate overtime
|
|
|
66 |
|
67 |
bs = 512
|
68 |
predictions = []
|
69 |
+
it = trange(len(serialized_states), desc="Running model")
|
70 |
+
for i in range(0, len(serialized_states), bs):
|
71 |
+
batch = (serialized_states[i:i + bs].clone().to(DEVICE),
|
72 |
+
scoreboard[i:i + bs].clone().to(DEVICE))
|
73 |
if nullify_goal_difference:
|
74 |
batch[:, SB_BLUE_SCORE] = 0
|
75 |
batch[:, SB_ORANGE_SCORE] = 0
|