Rolv-Arild commited on
Commit
355e78e
·
verified ·
1 Parent(s): d87e40e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -26,10 +26,11 @@ def infer(model, replay_file,
26
  nullify_goal_difference=False,
27
  predict_with_all_quadrants=True,
28
  ignore_ties=False):
29
- swap_team_idx = torch.arange(model.num_outputs)
30
- mid = model.num_outputs // 2
 
31
  swap_team_idx[mid:-1] = swap_team_idx[:mid]
32
- swap_team_idx[:mid] += model.num_outputs // 2
33
 
34
  replay = ParsedReplay.load(replay_file)
35
  it = tqdm(replay_to_rlgym(replay), desc="Loading replay", total=len(replay.game_df))
@@ -73,7 +74,7 @@ def infer(model, replay_file,
73
  predictions = torch.cat(predictions, dim=0)
74
  probs = predictions.softmax(dim=-1)
75
 
76
- bin_seconds = torch.linspace(0, 60, model.num_outputs // 2)
77
  class_names = [
78
  f"{t}: {s:g}s" for t in ["Blue", "Orange"]
79
  for s in bin_seconds.tolist()
 
26
  nullify_goal_difference=False,
27
  predict_with_all_quadrants=True,
28
  ignore_ties=False):
29
+ num_outputs = 123
30
+ swap_team_idx = torch.arange(num_outputs)
31
+ mid = num_outputs // 2
32
  swap_team_idx[mid:-1] = swap_team_idx[:mid]
33
+ swap_team_idx[:mid] += num_outputs // 2
34
 
35
  replay = ParsedReplay.load(replay_file)
36
  it = tqdm(replay_to_rlgym(replay), desc="Loading replay", total=len(replay.game_df))
 
74
  predictions = torch.cat(predictions, dim=0)
75
  probs = predictions.softmax(dim=-1)
76
 
77
+ bin_seconds = torch.linspace(0, 60, num_outputs // 2)
78
  class_names = [
79
  f"{t}: {s:g}s" for t in ["Blue", "Orange"]
80
  for s in bin_seconds.tolist()