Rolv-Arild commited on
Commit
482611e
·
verified ·
1 Parent(s): 9075389

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +263 -144
app.py CHANGED
@@ -1,154 +1,273 @@
1
- import gradio as gr
2
- import numpy as np
3
- import random
4
 
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
  import torch
 
 
 
 
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
-
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
-
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
-
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
-
23
-
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
35
- ):
36
- if randomize_seed:
37
- seed = random.randint(0, MAX_SEED)
38
-
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
50
-
51
- return image, seed
52
-
53
-
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
-
60
- css = """
61
- #col-container {
62
- margin: 0 auto;
63
- max-width: 640px;
64
- }
65
- """
66
-
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
-
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
79
 
80
- run_button = gr.Button("Run", scale=0, variant="primary")
81
 
82
- result = gr.Image(label="Result", show_label=False)
 
83
 
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
98
- )
99
 
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
-
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
-
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
118
-
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
- )
127
-
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
- )
135
-
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
140
- inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
- ],
150
- outputs=[result, seed],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  )
152
 
153
- if __name__ == "__main__":
154
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from tempfile import TemporaryDirectory
 
3
 
4
+ import numpy as np
5
+ import pandas as pd
6
  import torch
7
+ from rlgym_tools.rocket_league.misc.serialize import serialize_replay_frame, RF_SCOREBOARD, SB_GAME_TIMER_SECONDS,
8
+ from rlgym_tools.rocket_league.replays.convert import replay_to_rlgym
9
+ from rlgym_tools.rocket_league.replays.parsed_replay import ParsedReplay
10
+ from tqdm import trange, tqdm
11
 
12
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ from huggingface_hub import Repository
15
 
16
+ repo = Repository(local_dir="vortex-ngp", clone_from="Rolv-Arild/vortex-ngp")
17
+ repo.git_pull()
18
 
19
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ MODEL = torch.jit.load("vortex-ngp/vortex_ngp-avid-frog.pt", map_location=DEVICE)
21
+ MODEL.eval()
22
+
23
+
24
+ @torch.inference_mode()
25
+ def infer(model, replay_file,
26
+ nullify_goal_difference=False,
27
+ predict_with_all_quadrants=True,
28
+ ignore_ties=False):
29
+ swap_team_idx = torch.arange(model.num_outputs)
30
+ mid = model.num_outputs // 2
31
+ swap_team_idx[mid:-1] = swap_team_idx[:mid]
32
+ swap_team_idx[:mid] += model.num_outputs // 2
33
+
34
+ replay = ParsedReplay.load(replay_file)
35
+ it = tqdm(replay_to_rlgym(replay), desc="Loading replay", total=len(replay.game_df))
36
+ replay_frames = []
37
+ serialized = []
38
+ for replay_frame in it:
39
+ replay_frames.append(replay_frame)
40
+ s = serialize_replay_frame(replay_frame)
41
+ serialized.append(s)
42
+ serialized = np.stack(serialized)
43
+ it.close()
44
+
45
+ scoreboard = serialized[:, RF_SCOREBOARD]
46
+
47
+ timer = scoreboard[:, SB_GAME_TIMER_SECONDS]
48
+ is_ot = timer > 450
49
+ ot_time_remaining = serialized[is_ot, RF_EPISODE_SECONDS_REMAINING]
50
+ if len(ot_time_remaining) > 0:
51
+ ot_timer = ot_time_remaining[0] - ot_time_remaining
52
+ timer[is_ot] = -ot_timer # Negate to indicate overtime
53
+
54
+ goal_diff = scoreboard[:, SB_BLUE_SCORE] - scoreboard[:, SB_ORANGE_SCORE]
55
+ goal_diff_diff = goal_diff.diff(prepend=torch.Tensor([0]))
56
+
57
+ # if nullify_goal_difference:
58
+ # features.scoreboard[..., 2] = 0
59
+
60
+ bs = 512
61
+ predictions = []
62
+ it = trange(len(serialized), desc="Running model")
63
+ for i in range(0, len(serialized), bs):
64
+ batch = serialized[i:i + bs].clone().to(DEVICE)
65
+ if nullify_goal_difference:
66
+ batch[:, SB_BLUE_SCORE] = 0
67
+ batch[:, SB_ORANGE_SCORE] = 0
68
+ out = model(batch)
69
+ it.update(len(batch))
70
+
71
+ predictions.append(out)
72
+
73
+ predictions = torch.cat(predictions, dim=0)
74
+ probs = predictions.softmax(dim=-1)
75
+
76
+ bin_seconds = torch.linspace(0, 60, model.num_outputs // 2)
77
+ class_names = [
78
+ f"{t}: {s:g}s" for t in ["Blue", "Orange"]
79
+ for s in bin_seconds.tolist()
80
+ ]
81
+ class_names.append("Tie")
82
+
83
+ preds = probs.cpu().numpy()
84
+ preds = pd.DataFrame(data=preds, columns=class_names)
85
+ preds["Blue"] = preds[[c for c in preds.columns if c.startswith("Blue")]].sum(axis=1)
86
+ preds["Orange"] = preds[[c for c in preds.columns if c.startswith("Orange")]].sum(axis=1)
87
+ preds["Timer"] = timer
88
+ preds["Goal"] = goal_diff_diff
89
+ preds["Touch"] = ""
90
+
91
+ pid_to_name = {int(p["unique_id"]): p["name"]
92
+ for p in replay.metadata["players"]
93
+ if p["unique_id"] in replay.player_dfs}
94
+ for i, replay_frame in enumerate(replay_frames):
95
+ state = replay_frame.state
96
+ for aid, car in state.cars.items():
97
+ if car.ball_touches > 0:
98
+ team = "Blue" if car.is_blue else "Orange"
99
+ name = pid_to_name[aid]
100
+ name = name.replace("|", " ") # Replace pipe with space to not conflict with sep
101
+ if preds.at[i, "Touch"] != "":
102
+ preds.at[i, "Touch"] += "|"
103
+ preds.at[i, "Touch"] += f"{team}|{name}"
104
+
105
+ # Sort columns
106
+ main_cols = ["Timer", "Blue", "Orange", "Tie", "Goal", "Touch"]
107
+ preds = preds[main_cols + [c for c in preds.columns if c not in main_cols]]
108
+ # Set index name
109
+ preds.index.name = "Frame"
110
+
111
+ return preds
112
 
 
 
 
 
 
 
 
113
 
114
+ def plot_plotly(preds: pd.DataFrame):
115
+ import plotly.graph_objects as go
116
+ preds_df = preds * 100
117
+ timer = preds["Timer"]
118
+ fig = go.Figure()
119
+
120
+ def format_timer(t):
121
+ sign = '+' if t < 0 else ''
122
+ return f"{sign}{abs(t) // 60:01.0f}:{abs(t) % 60:02.0f}"
123
+
124
+ timer_text = [format_timer(t.item()) for t in timer.values]
125
+
126
+ hovertemplate = '<b>Frame %{x}</b><br>Prob: %{y:.3g}%<br>Timer: %{customdata}<extra></extra>'
127
+ # Add traces for Blue, Orange, and Tie probabilities from the DataFrame
128
+ fig.add_trace(
129
+ go.Scatter(x=preds_df.index, y=preds_df["Blue"],
130
+ mode='lines', name='Blue', line=dict(color='blue'),
131
+ customdata=timer_text, hovertemplate=hovertemplate))
132
+ fig.add_trace(
133
+ go.Scatter(x=preds_df.index, y=preds_df["Orange"],
134
+ mode='lines', name='Orange', line=dict(color='orange'),
135
+ customdata=timer_text, hovertemplate=hovertemplate))
136
+ fig.add_trace(
137
+ go.Scatter(x=preds_df.index, y=preds_df["Tie"],
138
+ mode='lines', name='Tie', line=dict(color='gray'),
139
+ customdata=timer_text, hovertemplate=hovertemplate))
140
+
141
+ # Add the horizontal line at y=50%
142
+ fig.add_hline(y=50, line_dash="dash", line_color="black", name="50% Probability")
143
+
144
+ # Add goal indicators
145
+ b = o = 0
146
+ for goal_frame in preds["Goal"].index[preds["Goal"] != 0]:
147
+ if preds["Goal"][goal_frame] > 0:
148
+ b += 1
149
+ elif preds["Goal"][goal_frame] < 0:
150
+ o += 1
151
+ fig.add_vline(x=goal_frame, line_dash="dash", line_color="red",
152
+ annotation_text=f"{b}-{o}", annotation_position="top right")
153
+
154
+ # Add touch indicators as points
155
+ touches = {}
156
+ for touch_frame in preds.index[preds["Touch"] != ""]:
157
+ teams_players = preds["Touch"][touch_frame].split('|')
158
+ for team, player in zip(teams_players[::2], teams_players[1::2]):
159
+ team = team.strip()
160
+ player = player.strip()
161
+ touches.setdefault(team, []).append((touch_frame, player))
162
+ for team in "Blue", "Orange":
163
+ team_touches = touches.get(team, [])
164
+ if not team_touches:
165
+ continue
166
+ x = [t[0] for t in team_touches]
167
+ y = [preds_df.at[t[0], team] for t in team_touches]
168
+ touch_players = [t[1] for t in team_touches]
169
+ custom_data = [f"{timer_text[f]}<br>Touch by {p}"
170
+ for f, p in zip(x, touch_players)]
171
+ fig.add_trace(
172
+ go.Scatter(x=x, y=y,
173
+ mode='markers',
174
+ name=f'{team} touches',
175
+ marker=dict(size=5, color=team.lower(), symbol='circle-open-dot'),
176
+ customdata=custom_data,
177
+ hovertemplate=hovertemplate
178
+ ))
179
+
180
+ # Define the formatting function for the secondary x-axis labels
181
+ def format_timer_ticks(x):
182
+ """Converts a frame number to a formatted time string."""
183
+ x = int(x)
184
+ # Ensure the index is within the bounds of the timer series
185
+ x = max(0, min(x, len(timer) - 1))
186
+
187
+ # Calculate the time value
188
+ t = timer.iloc[x] * 300
189
+
190
+ # Format the time as MM:SS, with a '+' for negative values (representing overtime)
191
+ sign = '+' if t < 0 else ''
192
+ minutes = int(abs(t) // 60)
193
+ seconds = int(abs(t) % 60)
194
+ return f"{sign}{minutes:01}:{seconds:02}"
195
+
196
+ # Generate positions and labels for the secondary axis ticks
197
+ # Creates 10 evenly spaced ticks for clarity
198
+ tick_positions = np.linspace(0, len(preds_df) - 1, 10)
199
+ tick_labels = [format_timer_ticks(val) for val in tick_positions]
200
+
201
+ # Configure the figure's layout, titles, and both x-axes
202
+ fig.update_layout(
203
+ title="Interactive Probability Plot",
204
+ xaxis=dict(
205
+ title="Frame",
206
+ gridcolor='#e5e7eb' # A light gray grid for a modern look
207
+ ),
208
+ yaxis=dict(
209
+ title="Probability",
210
+ gridcolor='#e5e7eb'
211
+ ),
212
+ # --- Secondary X-Axis Configuration ---
213
+ xaxis2=dict(
214
+ title="Timer",
215
+ overlaying='x', # This makes it a secondary axis
216
+ side='top', # Position it at the top
217
+ tickmode='array',
218
+ tickvals=tick_positions,
219
+ ticktext=tick_labels
220
+ ),
221
+ legend=dict(x=0.01, y=0.99, yanchor="top", xanchor="left"), # Position legend inside plot
222
+ plot_bgcolor='white' # A clean white background
223
  )
224
 
225
+ # fig.show()
226
+ return fig
227
+
228
+
229
+ def gradio_app():
230
+ import gradio as gr
231
+
232
+ with TemporaryDirectory() as temp_dir:
233
+ with gr.Blocks() as demo:
234
+ gr.Markdown("# Next Goal Predictor")
235
+ gr.Markdown("Upload a replay file to get a plot of the next goal prediction.")
236
+
237
+ # Use gr.Column to stack components vertically
238
+ with gr.Column():
239
+ file_input = gr.File(label="Upload Replay File", type="filepath", file_types=[".replay"])
240
+ submit_button = gr.Button("Generate Predictions")
241
+ plot_output = gr.Plot(label="Predictions")
242
+ download_button = gr.DownloadButton("Download Predictions", visible=False)
243
+
244
+ # Make plot on button click
245
+ def make_plot(replay_file, progress=gr.Progress(track_tqdm=True)):
246
+ print(f"Processing file: {replay_file}")
247
+ replay_stem = os.path.splitext(os.path.basename(replay_file))[0]
248
+ preds_file = os.path.join(temp_dir, f"predictions_{replay_stem}.csv")
249
+ if os.path.exists(preds_file):
250
+ print(f"Predictions file already exists: {preds_file}")
251
+ preds = pd.read_csv(preds_file)
252
+ else:
253
+ preds = infer(MODEL, replay_file)
254
+ plt = plot_plotly(preds)
255
+ print(f"Plot generated for file: {replay_file}")
256
+ preds.to_csv(preds_file)
257
+ if len(os.listdir(temp_dir)) > 100:
258
+ # Delete least recent file
259
+ oldest_file = min(os.listdir(temp_dir), key=lambda f: os.path.getctime(os.path.join(temp_dir, f)))
260
+ os.remove(os.path.join(temp_dir, oldest_file))
261
+ return plt, gr.DownloadButton(value=preds_file, visible=True)
262
+
263
+ submit_button.click(
264
+ fn=make_plot,
265
+ inputs=file_input,
266
+ outputs=[plot_output, download_button],
267
+ show_progress="full",
268
+ )
269
+ demo.launch()
270
+
271
+
272
+ if __name__ == '__main__':
273
+ gradio_app()