Rolv-Arild commited on
Commit
d8f30c7
·
verified ·
1 Parent(s): 204b85d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -11
app.py CHANGED
@@ -4,8 +4,9 @@ from tempfile import TemporaryDirectory
4
  import numpy as np
5
  import pandas as pd
6
  import torch
 
7
  from rlgym_tools.rocket_league.misc.serialize import serialize_game_state, serialize_scoreboard, \
8
- RF_SCOREBOARD, RF_EPISODE_SECONDS_REMAINING, SB_GAME_TIMER_SECONDS, SB_BLUE_SCORE, SB_ORANGE_SCORE
9
  from rlgym_tools.rocket_league.replays.convert import replay_to_rlgym
10
  from rlgym_tools.rocket_league.replays.parsed_replay import ParsedReplay
11
  from tqdm import trange, tqdm
@@ -14,8 +15,6 @@ os.chmod("/usr/local/lib/python3.10/site-packages/rlgym_tools/rocket_league/repl
14
 
15
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
- from huggingface_hub import Repository
18
-
19
  repo = Repository(local_dir="vortex-ngp", clone_from="Rolv-Arild/vortex-ngp", token=os.getenv("HF_TOKEN"))
20
  repo.git_pull()
21
 
@@ -252,15 +251,18 @@ Upload a replay file to get a plot of the next goal prediction.
252
 
253
  The model is trained on about 50,000 hours of SSL and RLCS replays in 1s, 2s, and 3s.
254
 
255
- It predicts the probability of each team scoring the next goal, in addition to ties.
256
- It also predicts the probability that each team will score in 1s, 2s, etc. up to 60+ seconds.
257
-
258
- The plot only shows the team predictions, but you can download the full predictions.
259
  """.strip()
260
 
261
  RADIO_OPTIONS = [
 
 
262
  "**Nullify goal difference**<br>Makes the model think the goal difference is always 0, so it doesn't have a bias towards one team.",
263
- "**Ignore ties**<br>Makes the model pretend every situation is an overtime (e.g. ties are impossible)."
 
 
264
  ]
265
 
266
 
@@ -274,15 +276,15 @@ def gradio_app():
274
  # Use gr.Column to stack components vertically
275
  with gr.Column():
276
  file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
277
- checkboxes = gr.Radio(label="Options", choices=RADIO_OPTIONS, type="index")
278
  submit_button = gr.Button("Generate Predictions")
279
  plot_output = gr.Plot(label="Predictions")
280
  download_button = gr.DownloadButton("Download Predictions", visible=False)
281
 
282
  # Make plot on button click
283
  def make_plot(replay_file, radio_option, progress=gr.Progress(track_tqdm=True)):
284
- nullify_goal_difference = radio_option == 0
285
- ignore_ties = radio_option == 1
286
  print(f"Processing file: {replay_file}")
287
 
288
  replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
 
4
  import numpy as np
5
  import pandas as pd
6
  import torch
7
+ from huggingface_hub import Repository
8
  from rlgym_tools.rocket_league.misc.serialize import serialize_game_state, serialize_scoreboard, \
9
+ SB_GAME_TIMER_SECONDS, SB_BLUE_SCORE, SB_ORANGE_SCORE
10
  from rlgym_tools.rocket_league.replays.convert import replay_to_rlgym
11
  from rlgym_tools.rocket_league.replays.parsed_replay import ParsedReplay
12
  from tqdm import trange, tqdm
 
15
 
16
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
 
 
18
  repo = Repository(local_dir="vortex-ngp", clone_from="Rolv-Arild/vortex-ngp", token=os.getenv("HF_TOKEN"))
19
  repo.git_pull()
20
 
 
251
 
252
  The model is trained on about 50,000 hours of SSL and RLCS replays in 1s, 2s, and 3s.
253
 
254
+ It predicts the probability of each team scoring the next goal, in addition to ties.<br>
255
+ It also predicts the probability that each team will score in 1s, 2s, etc. up to 60+ seconds.<br>
256
+ The plot only shows the team predictions, but you can download the full predictions if you want.
 
257
  """.strip()
258
 
259
  RADIO_OPTIONS = [
260
+ ("**Default**<br>Uses the model as it is trained, with no modifications.", "default"),
261
+ (
262
  "**Nullify goal difference**<br>Makes the model think the goal difference is always 0, so it doesn't have a bias towards one team.",
263
+ "nullify_goal_difference"),
264
+ ("**Ignore ties**<br>Makes the model pretend every situation is an overtime (e.g. ties are impossible).",
265
+ "ignore_ties"),
266
  ]
267
 
268
 
 
276
  # Use gr.Column to stack components vertically
277
  with gr.Column():
278
  file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
279
+ checkboxes = gr.Radio(label="Options", choices=RADIO_OPTIONS, type="index", value=0)
280
  submit_button = gr.Button("Generate Predictions")
281
  plot_output = gr.Plot(label="Predictions")
282
  download_button = gr.DownloadButton("Download Predictions", visible=False)
283
 
284
  # Make plot on button click
285
  def make_plot(replay_file, radio_option, progress=gr.Progress(track_tqdm=True)):
286
+ nullify_goal_difference = radio_option == 1
287
+ ignore_ties = radio_option == 2
288
  print(f"Processing file: {replay_file}")
289
 
290
  replay_stem = os.path.splitext(os.path.basename(replay_file))[0]