Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -4,7 +4,8 @@ from tempfile import TemporaryDirectory
|
|
4 |
import numpy as np
|
5 |
import pandas as pd
|
6 |
import torch
|
7 |
-
from rlgym_tools.rocket_league.misc.serialize import serialize_game_state, serialize_scoreboard,
|
|
|
8 |
from rlgym_tools.rocket_league.replays.convert import replay_to_rlgym
|
9 |
from rlgym_tools.rocket_league.replays.parsed_replay import ParsedReplay
|
10 |
from tqdm import trange, tqdm
|
@@ -247,12 +248,15 @@ def gradio_app():
|
|
247 |
with TemporaryDirectory() as temp_dir:
|
248 |
with gr.Blocks() as demo:
|
249 |
gr.Markdown("# Next Goal Predictor")
|
250 |
-
gr.Markdown("Upload a replay file to get a plot of the next goal prediction
|
|
|
251 |
|
252 |
# Use gr.Column to stack components vertically
|
253 |
with gr.Column():
|
254 |
file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
|
255 |
-
checkboxes = gr.
|
|
|
|
|
256 |
submit_button = gr.Button("Generate Predictions")
|
257 |
plot_output = gr.Plot(label="Predictions")
|
258 |
download_button = gr.DownloadButton("Download Predictions", visible=False)
|
@@ -262,8 +266,14 @@ def gradio_app():
|
|
262 |
nullify_goal_difference = 0 in checkbox_options
|
263 |
ignore_ties = 1 in checkbox_options
|
264 |
print(f"Processing file: {replay_file}")
|
|
|
265 |
replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
|
266 |
-
|
|
|
|
|
|
|
|
|
|
|
267 |
if os.path.exists(preds_file):
|
268 |
print(f"Predictions file already exists: {preds_file}")
|
269 |
preds = pd.read_csv(preds_file)
|
|
|
4 |
import numpy as np
|
5 |
import pandas as pd
|
6 |
import torch
|
7 |
+
from rlgym_tools.rocket_league.misc.serialize import serialize_game_state, serialize_scoreboard, \
|
8 |
+
RF_SCOREBOARD, RF_EPISODE_SECONDS_REMAINING, SB_GAME_TIMER_SECONDS, SB_BLUE_SCORE, SB_ORANGE_SCORE
|
9 |
from rlgym_tools.rocket_league.replays.convert import replay_to_rlgym
|
10 |
from rlgym_tools.rocket_league.replays.parsed_replay import ParsedReplay
|
11 |
from tqdm import trange, tqdm
|
|
|
248 |
with TemporaryDirectory() as temp_dir:
|
249 |
with gr.Blocks() as demo:
|
250 |
gr.Markdown("# Next Goal Predictor")
|
251 |
+
gr.Markdown("Upload a replay file to get a plot of the next goal prediction.<br>"
|
252 |
+
"The model is trained on about 50 000 hours of SSL and RLCS replays in 1s, 2s, and 3s.")
|
253 |
|
254 |
# Use gr.Column to stack components vertically
|
255 |
with gr.Column():
|
256 |
file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
|
257 |
+
checkboxes = gr.Radio(label="Options", choices=["Nullify goal difference", "Ignore ties"], type="index",
|
258 |
+
info="You can choose to make the model think the goal difference is always 0 (so it doesn't prefer one team), "
|
259 |
+
"or make it pretend every situation is an overtime (e.g. ties are impossible).")
|
260 |
submit_button = gr.Button("Generate Predictions")
|
261 |
plot_output = gr.Plot(label="Predictions")
|
262 |
download_button = gr.DownloadButton("Download Predictions", visible=False)
|
|
|
266 |
nullify_goal_difference = 0 in checkbox_options
|
267 |
ignore_ties = 1 in checkbox_options
|
268 |
print(f"Processing file: {replay_file}")
|
269 |
+
|
270 |
replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
|
271 |
+
postfix = ""
|
272 |
+
if nullify_goal_difference:
|
273 |
+
postfix += "_nullify_goal_difference"
|
274 |
+
elif ignore_ties:
|
275 |
+
postfix += "_ignore_ties"
|
276 |
+
preds_file = os.path.join(temp_dir, f"predictions_{replay_stem}_{postfix}.csv")
|
277 |
if os.path.exists(preds_file):
|
278 |
print(f"Predictions file already exists: {preds_file}")
|
279 |
preds = pd.read_csv(preds_file)
|