Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -26,10 +26,11 @@ def infer(model, replay_file,
|
|
26 |
nullify_goal_difference=False,
|
27 |
predict_with_all_quadrants=True,
|
28 |
ignore_ties=False):
|
29 |
-
|
30 |
-
|
|
|
31 |
swap_team_idx[mid:-1] = swap_team_idx[:mid]
|
32 |
-
swap_team_idx[:mid] +=
|
33 |
|
34 |
replay = ParsedReplay.load(replay_file)
|
35 |
it = tqdm(replay_to_rlgym(replay), desc="Loading replay", total=len(replay.game_df))
|
@@ -73,7 +74,7 @@ def infer(model, replay_file,
|
|
73 |
predictions = torch.cat(predictions, dim=0)
|
74 |
probs = predictions.softmax(dim=-1)
|
75 |
|
76 |
-
bin_seconds = torch.linspace(0, 60,
|
77 |
class_names = [
|
78 |
f"{t}: {s:g}s" for t in ["Blue", "Orange"]
|
79 |
for s in bin_seconds.tolist()
|
|
|
26 |
nullify_goal_difference=False,
|
27 |
predict_with_all_quadrants=True,
|
28 |
ignore_ties=False):
|
29 |
+
num_outputs = 123
|
30 |
+
swap_team_idx = torch.arange(num_outputs)
|
31 |
+
mid = num_outputs // 2
|
32 |
swap_team_idx[mid:-1] = swap_team_idx[:mid]
|
33 |
+
swap_team_idx[:mid] += num_outputs // 2
|
34 |
|
35 |
replay = ParsedReplay.load(replay_file)
|
36 |
it = tqdm(replay_to_rlgym(replay), desc="Loading replay", total=len(replay.game_df))
|
|
|
74 |
predictions = torch.cat(predictions, dim=0)
|
75 |
probs = predictions.softmax(dim=-1)
|
76 |
|
77 |
+
bin_seconds = torch.linspace(0, 60, num_outputs // 2)
|
78 |
class_names = [
|
79 |
f"{t}: {s:g}s" for t in ["Blue", "Orange"]
|
80 |
for s in bin_seconds.tolist()
|