Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -66,7 +66,7 @@ def infer(model, replay_file,
|
|
66 |
predictions = []
|
67 |
it = trange(len(serialized_states), desc="Running model")
|
68 |
for i in range(0, len(serialized_states), bs):
|
69 |
-
batch = (serialized_states[i:i + bs].clone().to(DEVICE),
|
70 |
serialized_scoreboards[i:i + bs].clone().to(DEVICE))
|
71 |
if nullify_goal_difference or ignore_ties:
|
72 |
batch[1][:, SB_BLUE_SCORE] = 0
|
@@ -115,13 +115,17 @@ def infer(model, replay_file,
|
|
115 |
preds = preds[main_cols + [c for c in preds.columns if c not in main_cols]]
|
116 |
# Set index name
|
117 |
preds.index.name = "Frame"
|
118 |
-
if ignore_ties
|
119 |
-
|
|
|
120 |
q = (1 - tie_probs)
|
121 |
-
preds = preds.drop("Tie", axis=1)
|
122 |
for c in preds.columns:
|
123 |
if c.startswith("Blue") or c.startswith("Orange"):
|
124 |
-
preds[c] /= q
|
|
|
|
|
|
|
|
|
125 |
|
126 |
return preds
|
127 |
|
@@ -242,21 +246,36 @@ def plot_plotly(preds: pd.DataFrame):
|
|
242 |
return fig
|
243 |
|
244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
def gradio_app():
|
246 |
import gradio as gr
|
247 |
|
248 |
with TemporaryDirectory() as temp_dir:
|
249 |
with gr.Blocks() as demo:
|
250 |
-
gr.Markdown(
|
251 |
-
gr.Markdown("Upload a replay file to get a plot of the next goal prediction.<br>"
|
252 |
-
"The model is trained on about 50 000 hours of SSL and RLCS replays in 1s, 2s, and 3s.")
|
253 |
|
254 |
# Use gr.Column to stack components vertically
|
255 |
with gr.Column():
|
256 |
file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
|
257 |
checkboxes = gr.Radio(label="Options", choices=["Nullify goal difference", "Ignore ties"], type="index",
|
258 |
-
|
259 |
-
"or make it pretend every situation is an overtime (e.g. ties are impossible).")
|
260 |
submit_button = gr.Button("Generate Predictions")
|
261 |
plot_output = gr.Plot(label="Predictions")
|
262 |
download_button = gr.DownloadButton("Download Predictions", visible=False)
|
@@ -266,7 +285,7 @@ def gradio_app():
|
|
266 |
nullify_goal_difference = radio_option == 0
|
267 |
ignore_ties = radio_option == 1
|
268 |
print(f"Processing file: {replay_file}")
|
269 |
-
|
270 |
replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
|
271 |
postfix = ""
|
272 |
if nullify_goal_difference:
|
@@ -278,9 +297,9 @@ def gradio_app():
|
|
278 |
print(f"Predictions file already exists: {preds_file}")
|
279 |
preds = pd.read_csv(preds_file, dtype={"Touch": str})
|
280 |
else:
|
281 |
-
preds = infer(MODEL, replay_file,
|
282 |
-
|
283 |
-
|
284 |
plt = plot_plotly(preds)
|
285 |
print(f"Plot generated for file: {replay_file}")
|
286 |
preds.to_csv(preds_file)
|
|
|
66 |
predictions = []
|
67 |
it = trange(len(serialized_states), desc="Running model")
|
68 |
for i in range(0, len(serialized_states), bs):
|
69 |
+
batch = (serialized_states[i:i + bs].clone().to(DEVICE),
|
70 |
serialized_scoreboards[i:i + bs].clone().to(DEVICE))
|
71 |
if nullify_goal_difference or ignore_ties:
|
72 |
batch[1][:, SB_BLUE_SCORE] = 0
|
|
|
115 |
preds = preds[main_cols + [c for c in preds.columns if c not in main_cols]]
|
116 |
# Set index name
|
117 |
preds.index.name = "Frame"
|
118 |
+
remove_ties_mask = is_ot if not ignore_ties else torch.ones(len(preds), dtype=torch.bool)
|
119 |
+
if remove_ties_mask.any():
|
120 |
+
tie_probs = preds[remove_ties_mask, "Tie"]
|
121 |
q = (1 - tie_probs)
|
|
|
122 |
for c in preds.columns:
|
123 |
if c.startswith("Blue") or c.startswith("Orange"):
|
124 |
+
preds[remove_ties_mask, c] /= q
|
125 |
+
if ignore_ties:
|
126 |
+
preds = preds.drop("Tie", axis=1)
|
127 |
+
else:
|
128 |
+
preds[remove_ties_mask, "Tie"] = 0.0
|
129 |
|
130 |
return preds
|
131 |
|
|
|
246 |
return fig
|
247 |
|
248 |
|
249 |
+
DESCRIPTION = """
|
250 |
+
# Next Goal Predictor
|
251 |
+
Upload a replay file to get a plot of the next goal prediction.
|
252 |
+
|
253 |
+
The model is trained on about 50,000 hours of SSL and RLCS replays in 1s, 2s, and 3s.
|
254 |
+
|
255 |
+
It predicts the probability of each team scoring the next goal, in addition to ties.
|
256 |
+
It also predicts the probability that each team will score in 1s, 2s, etc. up to 60+ seconds.
|
257 |
+
|
258 |
+
The plot only shows the team predictions, but you can download the full predictions.
|
259 |
+
""".strip()
|
260 |
+
|
261 |
+
RADIO_OPTIONS = [
|
262 |
+
"**Nullify goal difference**<br>Makes the model think the goal difference is always 0, so it doesn't have a bias towards one team.",
|
263 |
+
"**Ignore ties**<br>Makes the model pretend every situation is an overtime (e.g. ties are impossible)."
|
264 |
+
]
|
265 |
+
|
266 |
+
|
267 |
def gradio_app():
|
268 |
import gradio as gr
|
269 |
|
270 |
with TemporaryDirectory() as temp_dir:
|
271 |
with gr.Blocks() as demo:
|
272 |
+
gr.Markdown(DESCRIPTION)
|
|
|
|
|
273 |
|
274 |
# Use gr.Column to stack components vertically
|
275 |
with gr.Column():
|
276 |
file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
|
277 |
checkboxes = gr.Radio(label="Options", choices=["Nullify goal difference", "Ignore ties"], type="index",
|
278 |
+
info=RADIO_INFO)
|
|
|
279 |
submit_button = gr.Button("Generate Predictions")
|
280 |
plot_output = gr.Plot(label="Predictions")
|
281 |
download_button = gr.DownloadButton("Download Predictions", visible=False)
|
|
|
285 |
nullify_goal_difference = radio_option == 0
|
286 |
ignore_ties = radio_option == 1
|
287 |
print(f"Processing file: {replay_file}")
|
288 |
+
|
289 |
replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
|
290 |
postfix = ""
|
291 |
if nullify_goal_difference:
|
|
|
297 |
print(f"Predictions file already exists: {preds_file}")
|
298 |
preds = pd.read_csv(preds_file, dtype={"Touch": str})
|
299 |
else:
|
300 |
+
preds = infer(MODEL, replay_file,
|
301 |
+
nullify_goal_difference=nullify_goal_difference,
|
302 |
+
ignore_ties=ignore_ties)
|
303 |
plt = plot_plotly(preds)
|
304 |
print(f"Plot generated for file: {replay_file}")
|
305 |
preds.to_csv(preds_file)
|