Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -67,9 +67,11 @@ def infer(model, replay_file,
|
|
67 |
for i in range(0, len(serialized_states), bs):
|
68 |
batch = (serialized_states[i:i + bs].clone().to(DEVICE),
|
69 |
serialized_scoreboards[i:i + bs].clone().to(DEVICE))
|
70 |
-
if nullify_goal_difference:
|
71 |
-
batch[:, SB_BLUE_SCORE] = 0
|
72 |
-
batch[:, SB_ORANGE_SCORE] = 0
|
|
|
|
|
73 |
out = model(*batch)
|
74 |
it.update(len(batch[0]))
|
75 |
|
@@ -112,13 +114,17 @@ def infer(model, replay_file,
|
|
112 |
preds = preds[main_cols + [c for c in preds.columns if c not in main_cols]]
|
113 |
# Set index name
|
114 |
preds.index.name = "Frame"
|
|
|
|
|
|
|
|
|
115 |
|
116 |
return preds
|
117 |
|
118 |
|
119 |
def plot_plotly(preds: pd.DataFrame):
|
120 |
import plotly.graph_objects as go
|
121 |
-
preds_df = preds * 100
|
122 |
timer = preds["Timer"]
|
123 |
fig = go.Figure()
|
124 |
|
@@ -138,10 +144,11 @@ def plot_plotly(preds: pd.DataFrame):
|
|
138 |
go.Scatter(x=preds_df.index, y=preds_df["Orange"],
|
139 |
mode='lines', name='Orange', line=dict(color='orange'),
|
140 |
customdata=timer_text, hovertemplate=hovertemplate))
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
|
|
145 |
|
146 |
# Add the horizontal line at y=50%
|
147 |
fig.add_hline(y=50, line_dash="dash", line_color="black", name="50% Probability")
|
@@ -242,20 +249,24 @@ def gradio_app():
|
|
242 |
# Use gr.Column to stack components vertically
|
243 |
with gr.Column():
|
244 |
file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
|
|
|
245 |
submit_button = gr.Button("Generate Predictions")
|
246 |
plot_output = gr.Plot(label="Predictions")
|
247 |
download_button = gr.DownloadButton("Download Predictions", visible=False)
|
248 |
|
249 |
# Make plot on button click
|
250 |
-
def make_plot(replay_file, progress=gr.Progress(track_tqdm=True)):
|
|
|
251 |
print(f"Processing file: {replay_file}")
|
252 |
replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
|
253 |
-
preds_file = os.path.join(temp_dir, f"predictions_{replay_stem}.csv")
|
254 |
if os.path.exists(preds_file):
|
255 |
print(f"Predictions file already exists: {preds_file}")
|
256 |
preds = pd.read_csv(preds_file)
|
257 |
else:
|
258 |
-
preds = infer(MODEL, replay_file
|
|
|
|
|
259 |
plt = plot_plotly(preds)
|
260 |
print(f"Plot generated for file: {replay_file}")
|
261 |
preds.to_csv(preds_file)
|
@@ -267,7 +278,7 @@ def gradio_app():
|
|
267 |
|
268 |
submit_button.click(
|
269 |
fn=make_plot,
|
270 |
-
inputs=file_input,
|
271 |
outputs=[plot_output, download_button],
|
272 |
show_progress="full",
|
273 |
)
|
|
|
67 |
for i in range(0, len(serialized_states), bs):
|
68 |
batch = (serialized_states[i:i + bs].clone().to(DEVICE),
|
69 |
serialized_scoreboards[i:i + bs].clone().to(DEVICE))
|
70 |
+
if nullify_goal_difference or ignore_ties:
|
71 |
+
batch[1][:, SB_BLUE_SCORE] = 0
|
72 |
+
batch[1][:, SB_ORANGE_SCORE] = 0
|
73 |
+
if ignore_ties:
|
74 |
+
batch[1][:, SB_GAME_TIMER_SECONDS] = float("inf")
|
75 |
out = model(*batch)
|
76 |
it.update(len(batch[0]))
|
77 |
|
|
|
114 |
preds = preds[main_cols + [c for c in preds.columns if c not in main_cols]]
|
115 |
# Set index name
|
116 |
preds.index.name = "Frame"
|
117 |
+
if ignore_ties:
|
118 |
+
tie_probs = preds["Tie"]
|
119 |
+
preds = preds.drop("Tie")
|
120 |
+
preds[[c for c in preds.columns if c.startswith("Blue") or c.startswith("Orange")] /= (1 - tie_probs)
|
121 |
|
122 |
return preds
|
123 |
|
124 |
|
125 |
def plot_plotly(preds: pd.DataFrame):
|
126 |
import plotly.graph_objects as go
|
127 |
+
preds_df = preds.drop(["Touch", "Timer", "Goal"], axis=1) * 100
|
128 |
timer = preds["Timer"]
|
129 |
fig = go.Figure()
|
130 |
|
|
|
144 |
go.Scatter(x=preds_df.index, y=preds_df["Orange"],
|
145 |
mode='lines', name='Orange', line=dict(color='orange'),
|
146 |
customdata=timer_text, hovertemplate=hovertemplate))
|
147 |
+
if "Tie" in preds.columns:
|
148 |
+
fig.add_trace(
|
149 |
+
go.Scatter(x=preds_df.index, y=preds_df["Tie"],
|
150 |
+
mode='lines', name='Tie', line=dict(color='gray'),
|
151 |
+
customdata=timer_text, hovertemplate=hovertemplate))
|
152 |
|
153 |
# Add the horizontal line at y=50%
|
154 |
fig.add_hline(y=50, line_dash="dash", line_color="black", name="50% Probability")
|
|
|
249 |
# Use gr.Column to stack components vertically
|
250 |
with gr.Column():
|
251 |
file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
|
252 |
+
checkboxes = gr.CheckboxGroup(["Nullify goal difference", "Ignore ties"])
|
253 |
submit_button = gr.Button("Generate Predictions")
|
254 |
plot_output = gr.Plot(label="Predictions")
|
255 |
download_button = gr.DownloadButton("Download Predictions", visible=False)
|
256 |
|
257 |
# Make plot on button click
|
258 |
+
def make_plot(replay_file, checkbox_options, progress=gr.Progress(track_tqdm=True)):
|
259 |
+
nullify_goal_difference, ignore_ties = checkbox_options
|
260 |
print(f"Processing file: {replay_file}")
|
261 |
replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
|
262 |
+
preds_file = os.path.join(temp_dir, f"predictions_{replay_stem}_{options}.csv")
|
263 |
if os.path.exists(preds_file):
|
264 |
print(f"Predictions file already exists: {preds_file}")
|
265 |
preds = pd.read_csv(preds_file)
|
266 |
else:
|
267 |
+
preds = infer(MODEL, replay_file,
|
268 |
+
nullify_goal_difference=nullify_goal_difference,
|
269 |
+
ignore_ties=ignore_ties)
|
270 |
plt = plot_plotly(preds)
|
271 |
print(f"Plot generated for file: {replay_file}")
|
272 |
preds.to_csv(preds_file)
|
|
|
278 |
|
279 |
submit_button.click(
|
280 |
fn=make_plot,
|
281 |
+
inputs=[file_input, checkboxes],
|
282 |
outputs=[plot_output, download_button],
|
283 |
show_progress="full",
|
284 |
)
|