Rolv-Arild commited on
Commit
dbb5b76
·
verified ·
1 Parent(s): e98dedf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -4
app.py CHANGED
@@ -4,7 +4,8 @@ from tempfile import TemporaryDirectory
4
  import numpy as np
5
  import pandas as pd
6
  import torch
7
- from rlgym_tools.rocket_league.misc.serialize import serialize_game_state, serialize_scoreboard, RF_SCOREBOARD, RF_EPISODE_SECONDS_REMAINING, SB_GAME_TIMER_SECONDS, SB_BLUE_SCORE, SB_ORANGE_SCORE
 
8
  from rlgym_tools.rocket_league.replays.convert import replay_to_rlgym
9
  from rlgym_tools.rocket_league.replays.parsed_replay import ParsedReplay
10
  from tqdm import trange, tqdm
@@ -247,12 +248,15 @@ def gradio_app():
247
  with TemporaryDirectory() as temp_dir:
248
  with gr.Blocks() as demo:
249
  gr.Markdown("# Next Goal Predictor")
250
- gr.Markdown("Upload a replay file to get a plot of the next goal prediction.")
 
251
 
252
  # Use gr.Column to stack components vertically
253
  with gr.Column():
254
  file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
255
- checkboxes = gr.CheckboxGroup(label="Options", choices=["Nullify goal difference", "Ignore ties"], type="index")
 
 
256
  submit_button = gr.Button("Generate Predictions")
257
  plot_output = gr.Plot(label="Predictions")
258
  download_button = gr.DownloadButton("Download Predictions", visible=False)
@@ -262,8 +266,14 @@ def gradio_app():
262
  nullify_goal_difference = 0 in checkbox_options
263
  ignore_ties = 1 in checkbox_options
264
  print(f"Processing file: {replay_file}")
 
265
  replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
266
- preds_file = os.path.join(temp_dir, f"predictions_{replay_stem}_{nullify_goal_difference=}_{ignore_ties=}.csv")
 
 
 
 
 
267
  if os.path.exists(preds_file):
268
  print(f"Predictions file already exists: {preds_file}")
269
  preds = pd.read_csv(preds_file)
 
4
  import numpy as np
5
  import pandas as pd
6
  import torch
7
+ from rlgym_tools.rocket_league.misc.serialize import serialize_game_state, serialize_scoreboard, \
8
+ RF_SCOREBOARD, RF_EPISODE_SECONDS_REMAINING, SB_GAME_TIMER_SECONDS, SB_BLUE_SCORE, SB_ORANGE_SCORE
9
  from rlgym_tools.rocket_league.replays.convert import replay_to_rlgym
10
  from rlgym_tools.rocket_league.replays.parsed_replay import ParsedReplay
11
  from tqdm import trange, tqdm
 
248
  with TemporaryDirectory() as temp_dir:
249
  with gr.Blocks() as demo:
250
  gr.Markdown("# Next Goal Predictor")
251
+ gr.Markdown("Upload a replay file to get a plot of the next goal prediction.<br>"
252
+ "The model is trained on about 50 000 hours of SSL and RLCS replays in 1s, 2s, and 3s.")
253
 
254
  # Use gr.Column to stack components vertically
255
  with gr.Column():
256
  file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
257
+ checkboxes = gr.Radio(label="Options", choices=["Nullify goal difference", "Ignore ties"], type="index",
258
+ info="You can choose to make the model think the goal difference is always 0 (so it doesn't prefer one team), "
259
+ "or make it pretend every situation is an overtime (e.g. ties are impossible).")
260
  submit_button = gr.Button("Generate Predictions")
261
  plot_output = gr.Plot(label="Predictions")
262
  download_button = gr.DownloadButton("Download Predictions", visible=False)
 
266
  nullify_goal_difference = 0 in checkbox_options
267
  ignore_ties = 1 in checkbox_options
268
  print(f"Processing file: {replay_file}")
269
+
270
  replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
271
+ postfix = ""
272
+ if nullify_goal_difference:
273
+ postfix += "_nullify_goal_difference"
274
+ elif ignore_ties:
275
+ postfix += "_ignore_ties"
276
+ preds_file = os.path.join(temp_dir, f"predictions_{replay_stem}_{postfix}.csv")
277
  if os.path.exists(preds_file):
278
  print(f"Predictions file already exists: {preds_file}")
279
  preds = pd.read_csv(preds_file)