Rolv-Arild commited on
Commit
43e1328
·
verified ·
1 Parent(s): 59ac7e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -12
app.py CHANGED
@@ -67,9 +67,11 @@ def infer(model, replay_file,
67
  for i in range(0, len(serialized_states), bs):
68
  batch = (serialized_states[i:i + bs].clone().to(DEVICE),
69
  serialized_scoreboards[i:i + bs].clone().to(DEVICE))
70
- if nullify_goal_difference:
71
- batch[:, SB_BLUE_SCORE] = 0
72
- batch[:, SB_ORANGE_SCORE] = 0
 
 
73
  out = model(*batch)
74
  it.update(len(batch[0]))
75
 
@@ -112,13 +114,17 @@ def infer(model, replay_file,
112
  preds = preds[main_cols + [c for c in preds.columns if c not in main_cols]]
113
  # Set index name
114
  preds.index.name = "Frame"
 
 
 
 
115
 
116
  return preds
117
 
118
 
119
  def plot_plotly(preds: pd.DataFrame):
120
  import plotly.graph_objects as go
121
- preds_df = preds * 100
122
  timer = preds["Timer"]
123
  fig = go.Figure()
124
 
@@ -138,10 +144,11 @@ def plot_plotly(preds: pd.DataFrame):
138
  go.Scatter(x=preds_df.index, y=preds_df["Orange"],
139
  mode='lines', name='Orange', line=dict(color='orange'),
140
  customdata=timer_text, hovertemplate=hovertemplate))
141
- fig.add_trace(
142
- go.Scatter(x=preds_df.index, y=preds_df["Tie"],
143
- mode='lines', name='Tie', line=dict(color='gray'),
144
- customdata=timer_text, hovertemplate=hovertemplate))
 
145
 
146
  # Add the horizontal line at y=50%
147
  fig.add_hline(y=50, line_dash="dash", line_color="black", name="50% Probability")
@@ -242,20 +249,24 @@ def gradio_app():
242
  # Use gr.Column to stack components vertically
243
  with gr.Column():
244
  file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
 
245
  submit_button = gr.Button("Generate Predictions")
246
  plot_output = gr.Plot(label="Predictions")
247
  download_button = gr.DownloadButton("Download Predictions", visible=False)
248
 
249
  # Make plot on button click
250
- def make_plot(replay_file, progress=gr.Progress(track_tqdm=True)):
 
251
  print(f"Processing file: {replay_file}")
252
  replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
253
- preds_file = os.path.join(temp_dir, f"predictions_{replay_stem}.csv")
254
  if os.path.exists(preds_file):
255
  print(f"Predictions file already exists: {preds_file}")
256
  preds = pd.read_csv(preds_file)
257
  else:
258
- preds = infer(MODEL, replay_file)
 
 
259
  plt = plot_plotly(preds)
260
  print(f"Plot generated for file: {replay_file}")
261
  preds.to_csv(preds_file)
@@ -267,7 +278,7 @@ def gradio_app():
267
 
268
  submit_button.click(
269
  fn=make_plot,
270
- inputs=file_input,
271
  outputs=[plot_output, download_button],
272
  show_progress="full",
273
  )
 
67
  for i in range(0, len(serialized_states), bs):
68
  batch = (serialized_states[i:i + bs].clone().to(DEVICE),
69
  serialized_scoreboards[i:i + bs].clone().to(DEVICE))
70
+ if nullify_goal_difference or ignore_ties:
71
+ batch[1][:, SB_BLUE_SCORE] = 0
72
+ batch[1][:, SB_ORANGE_SCORE] = 0
73
+ if ignore_ties:
74
+ batch[1][:, SB_GAME_TIMER_SECONDS] = float("inf")
75
  out = model(*batch)
76
  it.update(len(batch[0]))
77
 
 
114
  preds = preds[main_cols + [c for c in preds.columns if c not in main_cols]]
115
  # Set index name
116
  preds.index.name = "Frame"
117
+ if ignore_ties:
118
+ tie_probs = preds["Tie"]
119
+ preds = preds.drop("Tie")
120
+ preds[[c for c in preds.columns if c.startswith("Blue") or c.startswith("Orange")] /= (1 - tie_probs)
121
 
122
  return preds
123
 
124
 
125
  def plot_plotly(preds: pd.DataFrame):
126
  import plotly.graph_objects as go
127
+ preds_df = preds.drop(["Touch", "Timer", "Goal"], axis=1) * 100
128
  timer = preds["Timer"]
129
  fig = go.Figure()
130
 
 
144
  go.Scatter(x=preds_df.index, y=preds_df["Orange"],
145
  mode='lines', name='Orange', line=dict(color='orange'),
146
  customdata=timer_text, hovertemplate=hovertemplate))
147
+ if "Tie" in preds.columns:
148
+ fig.add_trace(
149
+ go.Scatter(x=preds_df.index, y=preds_df["Tie"],
150
+ mode='lines', name='Tie', line=dict(color='gray'),
151
+ customdata=timer_text, hovertemplate=hovertemplate))
152
 
153
  # Add the horizontal line at y=50%
154
  fig.add_hline(y=50, line_dash="dash", line_color="black", name="50% Probability")
 
249
  # Use gr.Column to stack components vertically
250
  with gr.Column():
251
  file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
252
+ checkboxes = gr.CheckboxGroup(["Nullify goal difference", "Ignore ties"])
253
  submit_button = gr.Button("Generate Predictions")
254
  plot_output = gr.Plot(label="Predictions")
255
  download_button = gr.DownloadButton("Download Predictions", visible=False)
256
 
257
  # Make plot on button click
258
+ def make_plot(replay_file, checkbox_options, progress=gr.Progress(track_tqdm=True)):
259
+ nullify_goal_difference, ignore_ties = checkbox_options
260
  print(f"Processing file: {replay_file}")
261
  replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
262
+ preds_file = os.path.join(temp_dir, f"predictions_{replay_stem}_{options}.csv")
263
  if os.path.exists(preds_file):
264
  print(f"Predictions file already exists: {preds_file}")
265
  preds = pd.read_csv(preds_file)
266
  else:
267
+ preds = infer(MODEL, replay_file,
268
+ nullify_goal_difference=nullify_goal_difference,
269
+ ignore_ties=ignore_ties)
270
  plt = plot_plotly(preds)
271
  print(f"Plot generated for file: {replay_file}")
272
  preds.to_csv(preds_file)
 
278
 
279
  submit_button.click(
280
  fn=make_plot,
281
+ inputs=[file_input, checkboxes],
282
  outputs=[plot_output, download_button],
283
  show_progress="full",
284
  )