Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,154 +1,273 @@
|
|
1 |
-
import
|
2 |
-
|
3 |
-
import random
|
4 |
|
5 |
-
|
6 |
-
|
7 |
import torch
|
|
|
|
|
|
|
|
|
8 |
|
9 |
-
|
10 |
-
model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
|
11 |
-
|
12 |
-
if torch.cuda.is_available():
|
13 |
-
torch_dtype = torch.float16
|
14 |
-
else:
|
15 |
-
torch_dtype = torch.float32
|
16 |
-
|
17 |
-
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
|
18 |
-
pipe = pipe.to(device)
|
19 |
-
|
20 |
-
MAX_SEED = np.iinfo(np.int32).max
|
21 |
-
MAX_IMAGE_SIZE = 1024
|
22 |
-
|
23 |
-
|
24 |
-
# @spaces.GPU #[uncomment to use ZeroGPU]
|
25 |
-
def infer(
|
26 |
-
prompt,
|
27 |
-
negative_prompt,
|
28 |
-
seed,
|
29 |
-
randomize_seed,
|
30 |
-
width,
|
31 |
-
height,
|
32 |
-
guidance_scale,
|
33 |
-
num_inference_steps,
|
34 |
-
progress=gr.Progress(track_tqdm=True),
|
35 |
-
):
|
36 |
-
if randomize_seed:
|
37 |
-
seed = random.randint(0, MAX_SEED)
|
38 |
-
|
39 |
-
generator = torch.Generator().manual_seed(seed)
|
40 |
-
|
41 |
-
image = pipe(
|
42 |
-
prompt=prompt,
|
43 |
-
negative_prompt=negative_prompt,
|
44 |
-
guidance_scale=guidance_scale,
|
45 |
-
num_inference_steps=num_inference_steps,
|
46 |
-
width=width,
|
47 |
-
height=height,
|
48 |
-
generator=generator,
|
49 |
-
).images[0]
|
50 |
-
|
51 |
-
return image, seed
|
52 |
-
|
53 |
-
|
54 |
-
examples = [
|
55 |
-
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
|
56 |
-
"An astronaut riding a green horse",
|
57 |
-
"A delicious ceviche cheesecake slice",
|
58 |
-
]
|
59 |
-
|
60 |
-
css = """
|
61 |
-
#col-container {
|
62 |
-
margin: 0 auto;
|
63 |
-
max-width: 640px;
|
64 |
-
}
|
65 |
-
"""
|
66 |
-
|
67 |
-
with gr.Blocks(css=css) as demo:
|
68 |
-
with gr.Column(elem_id="col-container"):
|
69 |
-
gr.Markdown(" # Text-to-Image Gradio Template")
|
70 |
-
|
71 |
-
with gr.Row():
|
72 |
-
prompt = gr.Text(
|
73 |
-
label="Prompt",
|
74 |
-
show_label=False,
|
75 |
-
max_lines=1,
|
76 |
-
placeholder="Enter your prompt",
|
77 |
-
container=False,
|
78 |
-
)
|
79 |
|
80 |
-
|
81 |
|
82 |
-
|
|
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
-
seed = gr.Slider(
|
93 |
-
label="Seed",
|
94 |
-
minimum=0,
|
95 |
-
maximum=MAX_SEED,
|
96 |
-
step=1,
|
97 |
-
value=0,
|
98 |
-
)
|
99 |
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
]
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
)
|
152 |
|
153 |
-
|
154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from tempfile import TemporaryDirectory
|
|
|
3 |
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
import torch
|
7 |
+
from rlgym_tools.rocket_league.misc.serialize import serialize_replay_frame, RF_SCOREBOARD, SB_GAME_TIMER_SECONDS,
|
8 |
+
from rlgym_tools.rocket_league.replays.convert import replay_to_rlgym
|
9 |
+
from rlgym_tools.rocket_league.replays.parsed_replay import ParsedReplay
|
10 |
+
from tqdm import trange, tqdm
|
11 |
|
12 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
+
from huggingface_hub import Repository
|
15 |
|
16 |
+
repo = Repository(local_dir="vortex-ngp", clone_from="Rolv-Arild/vortex-ngp")
|
17 |
+
repo.git_pull()
|
18 |
|
19 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
20 |
+
MODEL = torch.jit.load("vortex-ngp/vortex_ngp-avid-frog.pt", map_location=DEVICE)
|
21 |
+
MODEL.eval()
|
22 |
+
|
23 |
+
|
24 |
+
@torch.inference_mode()
|
25 |
+
def infer(model, replay_file,
|
26 |
+
nullify_goal_difference=False,
|
27 |
+
predict_with_all_quadrants=True,
|
28 |
+
ignore_ties=False):
|
29 |
+
swap_team_idx = torch.arange(model.num_outputs)
|
30 |
+
mid = model.num_outputs // 2
|
31 |
+
swap_team_idx[mid:-1] = swap_team_idx[:mid]
|
32 |
+
swap_team_idx[:mid] += model.num_outputs // 2
|
33 |
+
|
34 |
+
replay = ParsedReplay.load(replay_file)
|
35 |
+
it = tqdm(replay_to_rlgym(replay), desc="Loading replay", total=len(replay.game_df))
|
36 |
+
replay_frames = []
|
37 |
+
serialized = []
|
38 |
+
for replay_frame in it:
|
39 |
+
replay_frames.append(replay_frame)
|
40 |
+
s = serialize_replay_frame(replay_frame)
|
41 |
+
serialized.append(s)
|
42 |
+
serialized = np.stack(serialized)
|
43 |
+
it.close()
|
44 |
+
|
45 |
+
scoreboard = serialized[:, RF_SCOREBOARD]
|
46 |
+
|
47 |
+
timer = scoreboard[:, SB_GAME_TIMER_SECONDS]
|
48 |
+
is_ot = timer > 450
|
49 |
+
ot_time_remaining = serialized[is_ot, RF_EPISODE_SECONDS_REMAINING]
|
50 |
+
if len(ot_time_remaining) > 0:
|
51 |
+
ot_timer = ot_time_remaining[0] - ot_time_remaining
|
52 |
+
timer[is_ot] = -ot_timer # Negate to indicate overtime
|
53 |
+
|
54 |
+
goal_diff = scoreboard[:, SB_BLUE_SCORE] - scoreboard[:, SB_ORANGE_SCORE]
|
55 |
+
goal_diff_diff = goal_diff.diff(prepend=torch.Tensor([0]))
|
56 |
+
|
57 |
+
# if nullify_goal_difference:
|
58 |
+
# features.scoreboard[..., 2] = 0
|
59 |
+
|
60 |
+
bs = 512
|
61 |
+
predictions = []
|
62 |
+
it = trange(len(serialized), desc="Running model")
|
63 |
+
for i in range(0, len(serialized), bs):
|
64 |
+
batch = serialized[i:i + bs].clone().to(DEVICE)
|
65 |
+
if nullify_goal_difference:
|
66 |
+
batch[:, SB_BLUE_SCORE] = 0
|
67 |
+
batch[:, SB_ORANGE_SCORE] = 0
|
68 |
+
out = model(batch)
|
69 |
+
it.update(len(batch))
|
70 |
+
|
71 |
+
predictions.append(out)
|
72 |
+
|
73 |
+
predictions = torch.cat(predictions, dim=0)
|
74 |
+
probs = predictions.softmax(dim=-1)
|
75 |
+
|
76 |
+
bin_seconds = torch.linspace(0, 60, model.num_outputs // 2)
|
77 |
+
class_names = [
|
78 |
+
f"{t}: {s:g}s" for t in ["Blue", "Orange"]
|
79 |
+
for s in bin_seconds.tolist()
|
80 |
+
]
|
81 |
+
class_names.append("Tie")
|
82 |
+
|
83 |
+
preds = probs.cpu().numpy()
|
84 |
+
preds = pd.DataFrame(data=preds, columns=class_names)
|
85 |
+
preds["Blue"] = preds[[c for c in preds.columns if c.startswith("Blue")]].sum(axis=1)
|
86 |
+
preds["Orange"] = preds[[c for c in preds.columns if c.startswith("Orange")]].sum(axis=1)
|
87 |
+
preds["Timer"] = timer
|
88 |
+
preds["Goal"] = goal_diff_diff
|
89 |
+
preds["Touch"] = ""
|
90 |
+
|
91 |
+
pid_to_name = {int(p["unique_id"]): p["name"]
|
92 |
+
for p in replay.metadata["players"]
|
93 |
+
if p["unique_id"] in replay.player_dfs}
|
94 |
+
for i, replay_frame in enumerate(replay_frames):
|
95 |
+
state = replay_frame.state
|
96 |
+
for aid, car in state.cars.items():
|
97 |
+
if car.ball_touches > 0:
|
98 |
+
team = "Blue" if car.is_blue else "Orange"
|
99 |
+
name = pid_to_name[aid]
|
100 |
+
name = name.replace("|", " ") # Replace pipe with space to not conflict with sep
|
101 |
+
if preds.at[i, "Touch"] != "":
|
102 |
+
preds.at[i, "Touch"] += "|"
|
103 |
+
preds.at[i, "Touch"] += f"{team}|{name}"
|
104 |
+
|
105 |
+
# Sort columns
|
106 |
+
main_cols = ["Timer", "Blue", "Orange", "Tie", "Goal", "Touch"]
|
107 |
+
preds = preds[main_cols + [c for c in preds.columns if c not in main_cols]]
|
108 |
+
# Set index name
|
109 |
+
preds.index.name = "Frame"
|
110 |
+
|
111 |
+
return preds
|
112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
+
def plot_plotly(preds: pd.DataFrame):
|
115 |
+
import plotly.graph_objects as go
|
116 |
+
preds_df = preds * 100
|
117 |
+
timer = preds["Timer"]
|
118 |
+
fig = go.Figure()
|
119 |
+
|
120 |
+
def format_timer(t):
|
121 |
+
sign = '+' if t < 0 else ''
|
122 |
+
return f"{sign}{abs(t) // 60:01.0f}:{abs(t) % 60:02.0f}"
|
123 |
+
|
124 |
+
timer_text = [format_timer(t.item()) for t in timer.values]
|
125 |
+
|
126 |
+
hovertemplate = '<b>Frame %{x}</b><br>Prob: %{y:.3g}%<br>Timer: %{customdata}<extra></extra>'
|
127 |
+
# Add traces for Blue, Orange, and Tie probabilities from the DataFrame
|
128 |
+
fig.add_trace(
|
129 |
+
go.Scatter(x=preds_df.index, y=preds_df["Blue"],
|
130 |
+
mode='lines', name='Blue', line=dict(color='blue'),
|
131 |
+
customdata=timer_text, hovertemplate=hovertemplate))
|
132 |
+
fig.add_trace(
|
133 |
+
go.Scatter(x=preds_df.index, y=preds_df["Orange"],
|
134 |
+
mode='lines', name='Orange', line=dict(color='orange'),
|
135 |
+
customdata=timer_text, hovertemplate=hovertemplate))
|
136 |
+
fig.add_trace(
|
137 |
+
go.Scatter(x=preds_df.index, y=preds_df["Tie"],
|
138 |
+
mode='lines', name='Tie', line=dict(color='gray'),
|
139 |
+
customdata=timer_text, hovertemplate=hovertemplate))
|
140 |
+
|
141 |
+
# Add the horizontal line at y=50%
|
142 |
+
fig.add_hline(y=50, line_dash="dash", line_color="black", name="50% Probability")
|
143 |
+
|
144 |
+
# Add goal indicators
|
145 |
+
b = o = 0
|
146 |
+
for goal_frame in preds["Goal"].index[preds["Goal"] != 0]:
|
147 |
+
if preds["Goal"][goal_frame] > 0:
|
148 |
+
b += 1
|
149 |
+
elif preds["Goal"][goal_frame] < 0:
|
150 |
+
o += 1
|
151 |
+
fig.add_vline(x=goal_frame, line_dash="dash", line_color="red",
|
152 |
+
annotation_text=f"{b}-{o}", annotation_position="top right")
|
153 |
+
|
154 |
+
# Add touch indicators as points
|
155 |
+
touches = {}
|
156 |
+
for touch_frame in preds.index[preds["Touch"] != ""]:
|
157 |
+
teams_players = preds["Touch"][touch_frame].split('|')
|
158 |
+
for team, player in zip(teams_players[::2], teams_players[1::2]):
|
159 |
+
team = team.strip()
|
160 |
+
player = player.strip()
|
161 |
+
touches.setdefault(team, []).append((touch_frame, player))
|
162 |
+
for team in "Blue", "Orange":
|
163 |
+
team_touches = touches.get(team, [])
|
164 |
+
if not team_touches:
|
165 |
+
continue
|
166 |
+
x = [t[0] for t in team_touches]
|
167 |
+
y = [preds_df.at[t[0], team] for t in team_touches]
|
168 |
+
touch_players = [t[1] for t in team_touches]
|
169 |
+
custom_data = [f"{timer_text[f]}<br>Touch by {p}"
|
170 |
+
for f, p in zip(x, touch_players)]
|
171 |
+
fig.add_trace(
|
172 |
+
go.Scatter(x=x, y=y,
|
173 |
+
mode='markers',
|
174 |
+
name=f'{team} touches',
|
175 |
+
marker=dict(size=5, color=team.lower(), symbol='circle-open-dot'),
|
176 |
+
customdata=custom_data,
|
177 |
+
hovertemplate=hovertemplate
|
178 |
+
))
|
179 |
+
|
180 |
+
# Define the formatting function for the secondary x-axis labels
|
181 |
+
def format_timer_ticks(x):
|
182 |
+
"""Converts a frame number to a formatted time string."""
|
183 |
+
x = int(x)
|
184 |
+
# Ensure the index is within the bounds of the timer series
|
185 |
+
x = max(0, min(x, len(timer) - 1))
|
186 |
+
|
187 |
+
# Calculate the time value
|
188 |
+
t = timer.iloc[x] * 300
|
189 |
+
|
190 |
+
# Format the time as MM:SS, with a '+' for negative values (representing overtime)
|
191 |
+
sign = '+' if t < 0 else ''
|
192 |
+
minutes = int(abs(t) // 60)
|
193 |
+
seconds = int(abs(t) % 60)
|
194 |
+
return f"{sign}{minutes:01}:{seconds:02}"
|
195 |
+
|
196 |
+
# Generate positions and labels for the secondary axis ticks
|
197 |
+
# Creates 10 evenly spaced ticks for clarity
|
198 |
+
tick_positions = np.linspace(0, len(preds_df) - 1, 10)
|
199 |
+
tick_labels = [format_timer_ticks(val) for val in tick_positions]
|
200 |
+
|
201 |
+
# Configure the figure's layout, titles, and both x-axes
|
202 |
+
fig.update_layout(
|
203 |
+
title="Interactive Probability Plot",
|
204 |
+
xaxis=dict(
|
205 |
+
title="Frame",
|
206 |
+
gridcolor='#e5e7eb' # A light gray grid for a modern look
|
207 |
+
),
|
208 |
+
yaxis=dict(
|
209 |
+
title="Probability",
|
210 |
+
gridcolor='#e5e7eb'
|
211 |
+
),
|
212 |
+
# --- Secondary X-Axis Configuration ---
|
213 |
+
xaxis2=dict(
|
214 |
+
title="Timer",
|
215 |
+
overlaying='x', # This makes it a secondary axis
|
216 |
+
side='top', # Position it at the top
|
217 |
+
tickmode='array',
|
218 |
+
tickvals=tick_positions,
|
219 |
+
ticktext=tick_labels
|
220 |
+
),
|
221 |
+
legend=dict(x=0.01, y=0.99, yanchor="top", xanchor="left"), # Position legend inside plot
|
222 |
+
plot_bgcolor='white' # A clean white background
|
223 |
)
|
224 |
|
225 |
+
# fig.show()
|
226 |
+
return fig
|
227 |
+
|
228 |
+
|
229 |
+
def gradio_app():
|
230 |
+
import gradio as gr
|
231 |
+
|
232 |
+
with TemporaryDirectory() as temp_dir:
|
233 |
+
with gr.Blocks() as demo:
|
234 |
+
gr.Markdown("# Next Goal Predictor")
|
235 |
+
gr.Markdown("Upload a replay file to get a plot of the next goal prediction.")
|
236 |
+
|
237 |
+
# Use gr.Column to stack components vertically
|
238 |
+
with gr.Column():
|
239 |
+
file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
|
240 |
+
submit_button = gr.Button("Generate Predictions")
|
241 |
+
plot_output = gr.Plot(label="Predictions")
|
242 |
+
download_button = gr.DownloadButton("Download Predictions", visible=False)
|
243 |
+
|
244 |
+
# Make plot on button click
|
245 |
+
def make_plot(replay_file, progress=gr.Progress(track_tqdm=True)):
|
246 |
+
print(f"Processing file: {replay_file}")
|
247 |
+
replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
|
248 |
+
preds_file = os.path.join(temp_dir, f"predictions_{replay_stem}.csv")
|
249 |
+
if os.path.exists(preds_file):
|
250 |
+
print(f"Predictions file already exists: {preds_file}")
|
251 |
+
preds = pd.read_csv(preds_file)
|
252 |
+
else:
|
253 |
+
preds = infer(MODEL, replay_file)
|
254 |
+
plt = plot_plotly(preds)
|
255 |
+
print(f"Plot generated for file: {replay_file}")
|
256 |
+
preds.to_csv(preds_file)
|
257 |
+
if len(os.listdir(temp_dir)) > 100:
|
258 |
+
# Delete least recent file
|
259 |
+
oldest_file = min(os.listdir(temp_dir), key=lambda f: os.path.getctime(os.path.join(temp_dir, f)))
|
260 |
+
os.remove(os.path.join(temp_dir, oldest_file))
|
261 |
+
return plt, gr.DownloadButton(value=preds_file, visible=True)
|
262 |
+
|
263 |
+
submit_button.click(
|
264 |
+
fn=make_plot,
|
265 |
+
inputs=file_input,
|
266 |
+
outputs=[plot_output, download_button],
|
267 |
+
show_progress="full",
|
268 |
+
)
|
269 |
+
demo.launch()
|
270 |
+
|
271 |
+
|
272 |
+
if __name__ == '__main__':
|
273 |
+
gradio_app()
|