Rolv-Arild commited on
Commit
7708de2
·
verified ·
1 Parent(s): b1be6eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -14
app.py CHANGED
@@ -4,7 +4,7 @@ from tempfile import TemporaryDirectory
4
  import numpy as np
5
  import pandas as pd
6
  import torch
7
- from rlgym_tools.rocket_league.misc.serialize import serialize_replay_frame, RF_SCOREBOARD, RF_EPISODE_SECONDS_REMAINING, SB_GAME_TIMER_SECONDS, SB_BLUE_SCORE, SB_ORANGE_SCORE
8
  from rlgym_tools.rocket_league.replays.convert import replay_to_rlgym
9
  from rlgym_tools.rocket_league.replays.parsed_replay import ParsedReplay
10
  from tqdm import trange, tqdm
@@ -26,7 +26,6 @@ MODEL.eval()
26
  @torch.inference_mode()
27
  def infer(model, replay_file,
28
  nullify_goal_difference=False,
29
- predict_with_all_quadrants=True,
30
  ignore_ties=False):
31
  num_outputs = 123
32
  swap_team_idx = torch.arange(num_outputs)
@@ -37,20 +36,24 @@ def infer(model, replay_file,
37
  replay = ParsedReplay.load(replay_file)
38
  it = tqdm(replay_to_rlgym(replay), desc="Loading replay", total=len(replay.game_df))
39
  replay_frames = []
40
- serialized = []
 
 
41
  for replay_frame in it:
42
  replay_frames.append(replay_frame)
43
- s = serialize_replay_frame(replay_frame)
44
- serialized.append(s)
45
- serialized = np.stack(serialized)
46
- serialized = torch.from_numpy(serialized)
 
 
 
 
47
  it.close()
48
 
49
- scoreboard = serialized[:, RF_SCOREBOARD]
50
-
51
- timer = scoreboard[:, SB_GAME_TIMER_SECONDS]
52
  is_ot = timer > 450
53
- ot_time_remaining = serialized[is_ot, RF_EPISODE_SECONDS_REMAINING]
54
  if len(ot_time_remaining) > 0:
55
  ot_timer = ot_time_remaining[0] - ot_time_remaining
56
  timer[is_ot] = -ot_timer # Negate to indicate overtime
@@ -63,9 +66,10 @@ def infer(model, replay_file,
63
 
64
  bs = 512
65
  predictions = []
66
- it = trange(len(serialized), desc="Running model")
67
- for i in range(0, len(serialized), bs):
68
- batch = serialized[i:i + bs].clone().to(DEVICE)
 
69
  if nullify_goal_difference:
70
  batch[:, SB_BLUE_SCORE] = 0
71
  batch[:, SB_ORANGE_SCORE] = 0
 
4
  import numpy as np
5
  import pandas as pd
6
  import torch
7
+ from rlgym_tools.rocket_league.misc.serialize import serialize_game_state, serialize_scoreboard, RF_SCOREBOARD, RF_EPISODE_SECONDS_REMAINING, SB_GAME_TIMER_SECONDS, SB_BLUE_SCORE, SB_ORANGE_SCORE
8
  from rlgym_tools.rocket_league.replays.convert import replay_to_rlgym
9
  from rlgym_tools.rocket_league.replays.parsed_replay import ParsedReplay
10
  from tqdm import trange, tqdm
 
26
  @torch.inference_mode()
27
  def infer(model, replay_file,
28
  nullify_goal_difference=False,
 
29
  ignore_ties=False):
30
  num_outputs = 123
31
  swap_team_idx = torch.arange(num_outputs)
 
36
  replay = ParsedReplay.load(replay_file)
37
  it = tqdm(replay_to_rlgym(replay), desc="Loading replay", total=len(replay.game_df))
38
  replay_frames = []
39
+ serialized_states = []
40
+ serialized_scoreboards = []
41
+ seconds_remaining = []
42
  for replay_frame in it:
43
  replay_frames.append(replay_frame)
44
+ sstate = serialize_game_state(replay_frame.state)
45
+ sscoreboard = serialize_scoreboard(replay_frame.scoreboard)
46
+ serialized_states.append(sstate)
47
+ serialized_scoreboards.append(sscoreboard)
48
+ seconds_remaining.append(replay_frame.episode_seconds_remaining)
49
+ serialized_states = torch.from_numpy(np.stack(serialized_states))
50
+ serialized_scoreboards = torch.from_numpy(np.stack(serialized_scoreboards))
51
+ seconds_remaining = torch.tensor(seconds_remaining)
52
  it.close()
53
 
54
+ timer = serialized_scoreboards[:, SB_GAME_TIMER_SECONDS]
 
 
55
  is_ot = timer > 450
56
+ ot_time_remaining = seconds_remaining[is_ot]
57
  if len(ot_time_remaining) > 0:
58
  ot_timer = ot_time_remaining[0] - ot_time_remaining
59
  timer[is_ot] = -ot_timer # Negate to indicate overtime
 
66
 
67
  bs = 512
68
  predictions = []
69
+ it = trange(len(serialized_states), desc="Running model")
70
+ for i in range(0, len(serialized_states), bs):
71
+ batch = (serialized_states[i:i + bs].clone().to(DEVICE),
72
+ scoreboard[i:i + bs].clone().to(DEVICE))
73
  if nullify_goal_difference:
74
  batch[:, SB_BLUE_SCORE] = 0
75
  batch[:, SB_ORANGE_SCORE] = 0