Rolv-Arild commited on
Commit
dce44ae
·
verified ·
1 Parent(s): 9a6da2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -14
app.py CHANGED
@@ -66,7 +66,7 @@ def infer(model, replay_file,
66
  predictions = []
67
  it = trange(len(serialized_states), desc="Running model")
68
  for i in range(0, len(serialized_states), bs):
69
- batch = (serialized_states[i:i + bs].clone().to(DEVICE),
70
  serialized_scoreboards[i:i + bs].clone().to(DEVICE))
71
  if nullify_goal_difference or ignore_ties:
72
  batch[1][:, SB_BLUE_SCORE] = 0
@@ -115,13 +115,17 @@ def infer(model, replay_file,
115
  preds = preds[main_cols + [c for c in preds.columns if c not in main_cols]]
116
  # Set index name
117
  preds.index.name = "Frame"
118
- if ignore_ties:
119
- tie_probs = preds["Tie"]
 
120
  q = (1 - tie_probs)
121
- preds = preds.drop("Tie", axis=1)
122
  for c in preds.columns:
123
  if c.startswith("Blue") or c.startswith("Orange"):
124
- preds[c] /= q
 
 
 
 
125
 
126
  return preds
127
 
@@ -242,21 +246,36 @@ def plot_plotly(preds: pd.DataFrame):
242
  return fig
243
 
244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  def gradio_app():
246
  import gradio as gr
247
 
248
  with TemporaryDirectory() as temp_dir:
249
  with gr.Blocks() as demo:
250
- gr.Markdown("# Next Goal Predictor")
251
- gr.Markdown("Upload a replay file to get a plot of the next goal prediction.<br>"
252
- "The model is trained on about 50 000 hours of SSL and RLCS replays in 1s, 2s, and 3s.")
253
 
254
  # Use gr.Column to stack components vertically
255
  with gr.Column():
256
  file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
257
  checkboxes = gr.Radio(label="Options", choices=["Nullify goal difference", "Ignore ties"], type="index",
258
- info="You can choose to make the model think the goal difference is always 0 (so it doesn't prefer one team), "
259
- "or make it pretend every situation is an overtime (e.g. ties are impossible).")
260
  submit_button = gr.Button("Generate Predictions")
261
  plot_output = gr.Plot(label="Predictions")
262
  download_button = gr.DownloadButton("Download Predictions", visible=False)
@@ -266,7 +285,7 @@ def gradio_app():
266
  nullify_goal_difference = radio_option == 0
267
  ignore_ties = radio_option == 1
268
  print(f"Processing file: {replay_file}")
269
-
270
  replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
271
  postfix = ""
272
  if nullify_goal_difference:
@@ -278,9 +297,9 @@ def gradio_app():
278
  print(f"Predictions file already exists: {preds_file}")
279
  preds = pd.read_csv(preds_file, dtype={"Touch": str})
280
  else:
281
- preds = infer(MODEL, replay_file,
282
- nullify_goal_difference=nullify_goal_difference,
283
- ignore_ties=ignore_ties)
284
  plt = plot_plotly(preds)
285
  print(f"Plot generated for file: {replay_file}")
286
  preds.to_csv(preds_file)
 
66
  predictions = []
67
  it = trange(len(serialized_states), desc="Running model")
68
  for i in range(0, len(serialized_states), bs):
69
+ batch = (serialized_states[i:i + bs].clone().to(DEVICE),
70
  serialized_scoreboards[i:i + bs].clone().to(DEVICE))
71
  if nullify_goal_difference or ignore_ties:
72
  batch[1][:, SB_BLUE_SCORE] = 0
 
115
  preds = preds[main_cols + [c for c in preds.columns if c not in main_cols]]
116
  # Set index name
117
  preds.index.name = "Frame"
118
+ remove_ties_mask = is_ot if not ignore_ties else torch.ones(len(preds), dtype=torch.bool)
119
+ if remove_ties_mask.any():
120
+ tie_probs = preds[remove_ties_mask, "Tie"]
121
  q = (1 - tie_probs)
 
122
  for c in preds.columns:
123
  if c.startswith("Blue") or c.startswith("Orange"):
124
+ preds[remove_ties_mask, c] /= q
125
+ if ignore_ties:
126
+ preds = preds.drop("Tie", axis=1)
127
+ else:
128
+ preds[remove_ties_mask, "Tie"] = 0.0
129
 
130
  return preds
131
 
 
246
  return fig
247
 
248
 
249
+ DESCRIPTION = """
250
+ # Next Goal Predictor
251
+ Upload a replay file to get a plot of the next goal prediction.
252
+
253
+ The model is trained on about 50,000 hours of SSL and RLCS replays in 1s, 2s, and 3s.
254
+
255
+ It predicts the probability of each team scoring the next goal, in addition to ties.
256
+ It also predicts the probability that each team will score in 1s, 2s, etc. up to 60+ seconds.
257
+
258
+ The plot only shows the team predictions, but you can download the full predictions.
259
+ """.strip()
260
+
261
+ RADIO_OPTIONS = [
262
+ "**Nullify goal difference**<br>Makes the model think the goal difference is always 0, so it doesn't have a bias towards one team.",
263
+ "**Ignore ties**<br>Makes the model pretend every situation is an overtime (e.g. ties are impossible)."
264
+ ]
265
+
266
+
267
  def gradio_app():
268
  import gradio as gr
269
 
270
  with TemporaryDirectory() as temp_dir:
271
  with gr.Blocks() as demo:
272
+ gr.Markdown(DESCRIPTION)
 
 
273
 
274
  # Use gr.Column to stack components vertically
275
  with gr.Column():
276
  file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
277
  checkboxes = gr.Radio(label="Options", choices=["Nullify goal difference", "Ignore ties"], type="index",
278
+ info=RADIO_INFO)
 
279
  submit_button = gr.Button("Generate Predictions")
280
  plot_output = gr.Plot(label="Predictions")
281
  download_button = gr.DownloadButton("Download Predictions", visible=False)
 
285
  nullify_goal_difference = radio_option == 0
286
  ignore_ties = radio_option == 1
287
  print(f"Processing file: {replay_file}")
288
+
289
  replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
290
  postfix = ""
291
  if nullify_goal_difference:
 
297
  print(f"Predictions file already exists: {preds_file}")
298
  preds = pd.read_csv(preds_file, dtype={"Touch": str})
299
  else:
300
+ preds = infer(MODEL, replay_file,
301
+ nullify_goal_difference=nullify_goal_difference,
302
+ ignore_ties=ignore_ties)
303
  plt = plot_plotly(preds)
304
  print(f"Plot generated for file: {replay_file}")
305
  preds.to_csv(preds_file)