Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -4,8 +4,9 @@ from tempfile import TemporaryDirectory
|
|
4 |
import numpy as np
|
5 |
import pandas as pd
|
6 |
import torch
|
|
|
7 |
from rlgym_tools.rocket_league.misc.serialize import serialize_game_state, serialize_scoreboard, \
|
8 |
-
|
9 |
from rlgym_tools.rocket_league.replays.convert import replay_to_rlgym
|
10 |
from rlgym_tools.rocket_league.replays.parsed_replay import ParsedReplay
|
11 |
from tqdm import trange, tqdm
|
@@ -14,8 +15,6 @@ os.chmod("/usr/local/lib/python3.10/site-packages/rlgym_tools/rocket_league/repl
|
|
14 |
|
15 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
|
17 |
-
from huggingface_hub import Repository
|
18 |
-
|
19 |
repo = Repository(local_dir="vortex-ngp", clone_from="Rolv-Arild/vortex-ngp", token=os.getenv("HF_TOKEN"))
|
20 |
repo.git_pull()
|
21 |
|
@@ -252,15 +251,18 @@ Upload a replay file to get a plot of the next goal prediction.
|
|
252 |
|
253 |
The model is trained on about 50,000 hours of SSL and RLCS replays in 1s, 2s, and 3s.
|
254 |
|
255 |
-
It predicts the probability of each team scoring the next goal, in addition to ties
|
256 |
-
It also predicts the probability that each team will score in 1s, 2s, etc. up to 60+ seconds
|
257 |
-
|
258 |
-
The plot only shows the team predictions, but you can download the full predictions.
|
259 |
""".strip()
|
260 |
|
261 |
RADIO_OPTIONS = [
|
|
|
|
|
262 |
"**Nullify goal difference**<br>Makes the model think the goal difference is always 0, so it doesn't have a bias towards one team.",
|
263 |
-
"
|
|
|
|
|
264 |
]
|
265 |
|
266 |
|
@@ -274,15 +276,15 @@ def gradio_app():
|
|
274 |
# Use gr.Column to stack components vertically
|
275 |
with gr.Column():
|
276 |
file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
|
277 |
-
checkboxes = gr.Radio(label="Options", choices=RADIO_OPTIONS, type="index")
|
278 |
submit_button = gr.Button("Generate Predictions")
|
279 |
plot_output = gr.Plot(label="Predictions")
|
280 |
download_button = gr.DownloadButton("Download Predictions", visible=False)
|
281 |
|
282 |
# Make plot on button click
|
283 |
def make_plot(replay_file, radio_option, progress=gr.Progress(track_tqdm=True)):
|
284 |
-
nullify_goal_difference = radio_option ==
|
285 |
-
ignore_ties = radio_option ==
|
286 |
print(f"Processing file: {replay_file}")
|
287 |
|
288 |
replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
|
|
|
4 |
import numpy as np
|
5 |
import pandas as pd
|
6 |
import torch
|
7 |
+
from huggingface_hub import Repository
|
8 |
from rlgym_tools.rocket_league.misc.serialize import serialize_game_state, serialize_scoreboard, \
|
9 |
+
SB_GAME_TIMER_SECONDS, SB_BLUE_SCORE, SB_ORANGE_SCORE
|
10 |
from rlgym_tools.rocket_league.replays.convert import replay_to_rlgym
|
11 |
from rlgym_tools.rocket_league.replays.parsed_replay import ParsedReplay
|
12 |
from tqdm import trange, tqdm
|
|
|
15 |
|
16 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
17 |
|
|
|
|
|
18 |
repo = Repository(local_dir="vortex-ngp", clone_from="Rolv-Arild/vortex-ngp", token=os.getenv("HF_TOKEN"))
|
19 |
repo.git_pull()
|
20 |
|
|
|
251 |
|
252 |
The model is trained on about 50,000 hours of SSL and RLCS replays in 1s, 2s, and 3s.
|
253 |
|
254 |
+
It predicts the probability of each team scoring the next goal, in addition to ties.<br>
|
255 |
+
It also predicts the probability that each team will score in 1s, 2s, etc. up to 60+ seconds.<br>
|
256 |
+
The plot only shows the team predictions, but you can download the full predictions if you want.
|
|
|
257 |
""".strip()
|
258 |
|
259 |
RADIO_OPTIONS = [
|
260 |
+
("**Default**<br>Uses the model as it is trained, with no modifications.", "default"),
|
261 |
+
(
|
262 |
"**Nullify goal difference**<br>Makes the model think the goal difference is always 0, so it doesn't have a bias towards one team.",
|
263 |
+
"nullify_goal_difference"),
|
264 |
+
("**Ignore ties**<br>Makes the model pretend every situation is an overtime (e.g. ties are impossible).",
|
265 |
+
"ignore_ties"),
|
266 |
]
|
267 |
|
268 |
|
|
|
276 |
# Use gr.Column to stack components vertically
|
277 |
with gr.Column():
|
278 |
file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
|
279 |
+
checkboxes = gr.Radio(label="Options", choices=RADIO_OPTIONS, type="index", value=0)
|
280 |
submit_button = gr.Button("Generate Predictions")
|
281 |
plot_output = gr.Plot(label="Predictions")
|
282 |
download_button = gr.DownloadButton("Download Predictions", visible=False)
|
283 |
|
284 |
# Make plot on button click
|
285 |
def make_plot(replay_file, radio_option, progress=gr.Progress(track_tqdm=True)):
|
286 |
+
nullify_goal_difference = radio_option == 1
|
287 |
+
ignore_ties = radio_option == 2
|
288 |
print(f"Processing file: {replay_file}")
|
289 |
|
290 |
replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
|