Rolv-Arild commited on
Commit
c072f66
·
verified ·
1 Parent(s): c97d36a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -53
app.py CHANGED
@@ -1,8 +1,10 @@
1
  import os
2
  from tempfile import TemporaryDirectory
3
 
 
4
  import numpy as np
5
  import pandas as pd
 
6
  import torch
7
  from huggingface_hub import Repository
8
  from rlgym_tools.rocket_league.misc.serialize import serialize_game_state, serialize_scoreboard, \
@@ -23,6 +25,7 @@ MODEL = torch.jit.load("vortex-ngp/vortex-ngp-avid-frog.pt", map_location=DEVICE
23
  MODEL.eval()
24
 
25
 
 
26
  @torch.inference_mode()
27
  def infer(model, replay_file,
28
  nullify_goal_difference=False,
@@ -263,60 +266,56 @@ RADIO_INFO = """
263
  - **Ignore ties**: Makes the model pretend every situation is an overtime (e.g. ties are impossible).
264
  """.strip()
265
 
 
 
 
266
 
267
- def gradio_app():
268
- import gradio as gr
 
 
 
 
 
 
269
 
270
- with TemporaryDirectory() as temp_dir:
271
- with gr.Blocks() as demo:
272
- gr.Markdown(DESCRIPTION)
273
-
274
- # Use gr.Column to stack components vertically
275
- with gr.Column():
276
- file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
277
- checkboxes = gr.Radio(label="Options", choices=RADIO_OPTIONS, type="index", value=RADIO_OPTIONS[0],
278
- info=RADIO_INFO)
279
- submit_button = gr.Button("Generate Predictions")
280
- plot_output = gr.Plot(label="Predictions")
281
- download_button = gr.DownloadButton("Download Predictions", visible=False)
282
 
 
283
  # Make plot on button click
284
- def make_plot(replay_file, radio_option, progress=gr.Progress(track_tqdm=True)):
285
- nullify_goal_difference = radio_option == 1
286
- ignore_ties = radio_option == 2
287
- print(f"Processing file: {replay_file}")
288
-
289
- replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
290
- postfix = ""
291
- if nullify_goal_difference:
292
- postfix += "_nullify_goal_difference"
293
- elif ignore_ties:
294
- postfix += "_ignore_ties"
295
- preds_file = os.path.join(temp_dir, f"predictions_{replay_stem}{postfix}.csv")
296
- if os.path.exists(preds_file):
297
- print(f"Predictions file already exists: {preds_file}")
298
- preds = pd.read_csv(preds_file, dtype={"Touch": str})
299
- else:
300
- preds = infer(MODEL, replay_file,
301
- nullify_goal_difference=nullify_goal_difference,
302
- ignore_ties=ignore_ties)
303
- plt = plot_plotly(preds)
304
- print(f"Plot generated for file: {replay_file}")
305
- preds.to_csv(preds_file)
306
- if len(os.listdir(temp_dir)) > 100:
307
- # Delete least recent file
308
- oldest_file = min(os.listdir(temp_dir), key=lambda f: os.path.getctime(os.path.join(temp_dir, f)))
309
- os.remove(os.path.join(temp_dir, oldest_file))
310
- return plt, gr.DownloadButton(value=preds_file, visible=True)
311
-
312
- submit_button.click(
313
- fn=make_plot,
314
- inputs=[file_input, checkboxes],
315
- outputs=[plot_output, download_button],
316
- show_progress="full",
317
- )
318
- demo.launch()
319
-
320
-
321
- if __name__ == '__main__':
322
- gradio_app()
 
1
  import os
2
  from tempfile import TemporaryDirectory
3
 
4
+ import gradio as gr
5
  import numpy as np
6
  import pandas as pd
7
+ import spaces
8
  import torch
9
  from huggingface_hub import Repository
10
  from rlgym_tools.rocket_league.misc.serialize import serialize_game_state, serialize_scoreboard, \
 
25
  MODEL.eval()
26
 
27
 
28
+ @spaces.GPU
29
  @torch.inference_mode()
30
  def infer(model, replay_file,
31
  nullify_goal_difference=False,
 
266
  - **Ignore ties**: Makes the model pretend every situation is an overtime (e.g. ties are impossible).
267
  """.strip()
268
 
269
+ with TemporaryDirectory() as temp_dir:
270
+ with gr.Blocks() as demo:
271
+ gr.Markdown(DESCRIPTION)
272
 
273
+ # Use gr.Column to stack components vertically
274
+ with gr.Column():
275
+ file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
276
+ checkboxes = gr.Radio(label="Options", choices=RADIO_OPTIONS, type="index", value=RADIO_OPTIONS[0],
277
+ info=RADIO_INFO)
278
+ submit_button = gr.Button("Generate Predictions")
279
+ plot_output = gr.Plot(label="Predictions")
280
+ download_button = gr.DownloadButton("Download Predictions", visible=False)
281
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
+ def make_plot(replay_file, radio_option, progress=gr.Progress(track_tqdm=True)):
284
  # Make plot on button click
285
+ nullify_goal_difference = radio_option == 1
286
+ ignore_ties = radio_option == 2
287
+ print(f"Processing file: {replay_file}")
288
+
289
+ replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
290
+ postfix = ""
291
+ if nullify_goal_difference:
292
+ postfix += "_nullify_goal_difference"
293
+ elif ignore_ties:
294
+ postfix += "_ignore_ties"
295
+ preds_file = os.path.join(temp_dir, f"predictions_{replay_stem}{postfix}.csv")
296
+ if os.path.exists(preds_file):
297
+ print(f"Predictions file already exists: {preds_file}")
298
+ preds = pd.read_csv(preds_file, dtype={"Touch": str})
299
+ else:
300
+ preds = infer(MODEL, replay_file,
301
+ nullify_goal_difference=nullify_goal_difference,
302
+ ignore_ties=ignore_ties)
303
+ plt = plot_plotly(preds)
304
+ print(f"Plot generated for file: {replay_file}")
305
+ preds.to_csv(preds_file)
306
+ if len(os.listdir(temp_dir)) > 100:
307
+ # Delete least recent file
308
+ oldest_file = min(os.listdir(temp_dir), key=lambda f: os.path.getctime(os.path.join(temp_dir, f)))
309
+ os.remove(os.path.join(temp_dir, oldest_file))
310
+ return plt, gr.DownloadButton(value=preds_file, visible=True)
311
+
312
+
313
+ submit_button.click(
314
+ fn=make_plot,
315
+ inputs=[file_input, checkboxes],
316
+ outputs=[plot_output, download_button],
317
+ show_progress="full",
318
+ )
319
+
320
+ demo.queue(default_concurrency_limit=None)
321
+ demo.launch()