Spaces: Running on Zero

added gpu-based drawing

app.py CHANGED
@@ -4,6 +4,8 @@
 import os
 import sys
 import uuid
+from concurrent.futures import ThreadPoolExecutor
+
 
 import gradio as gr
 import mediapy
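The new import backs the paint_point_track_parallel helper added below. As a minimal sketch of the submit-then-wait pattern that helper relies on (the work function here is a hypothetical stand-in for its per-point draw_point job):

from concurrent.futures import ThreadPoolExecutor

def work(i):
    # hypothetical stand-in for draw_point(image, i, t)
    return i * i

with ThreadPoolExecutor(max_workers=8) as executor:
    futures = [executor.submit(work, i) for i in range(4)]
    results = [f.result() for f in futures]  # block until every job finishes
print(results)  # [0, 1, 4, 9]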
@@ -93,6 +95,211 @@ def get_points_on_a_grid(
     )
     return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
 
+def paint_point_track_gpu_scatter(
+    frames: np.ndarray,
+    point_tracks: np.ndarray,
+    visibles: np.ndarray,
+    colormap: Optional[List[Tuple[int, int, int]]] = None,
+    radius: int = 2,
+    sharpness: float = 0.15,
+) -> np.ndarray:
+    print('starting vis')
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    frames_t = torch.from_numpy(frames).float().permute(0, 3, 1, 2).to(device)  # [T,C,H,W]
+    point_tracks_t = torch.from_numpy(point_tracks).to(device)  # [P,T,2]
+    visibles_t = torch.from_numpy(visibles).to(device)  # [P,T]
+    T, C, H, W = frames_t.shape
+    P = point_tracks.shape[0]
+    if colormap is None:
+        colormap = get_colors(P)
+    colors = torch.tensor(colormap, dtype=torch.float32, device=device)  # [P,3]
+    D = radius * 2 + 1
+    y = torch.arange(D, device=device).float()[:, None] - radius
+    x = torch.arange(D, device=device).float()[None, :] - radius
+    dist2 = x**2 + y**2
+    icon = torch.clamp(1 - (dist2 - (radius**2) / 2.0) / (radius * 2 * sharpness), 0, 1)  # [D,D]
+    icon = icon.view(1, D, D)
+    dx = torch.arange(-radius, radius + 1, device=device)
+    dy = torch.arange(-radius, radius + 1, device=device)
+    disp_y, disp_x = torch.meshgrid(dy, dx, indexing="ij")  # [D,D]
+    for t in range(T):
+        mask = visibles_t[:, t]  # [P]
+        if mask.sum() == 0:
+            continue
+        xy = point_tracks_t[mask, t] + 0.5  # [N,2]
+        xy[:, 0] = xy[:, 0].clamp(0, W - 1)
+        xy[:, 1] = xy[:, 1].clamp(0, H - 1)
+        colors_now = colors[mask]  # [N,3]
+        N = xy.shape[0]
+        cx = xy[:, 0].long()  # [N]
+        cy = xy[:, 1].long()
+        x_grid = cx[:, None, None] + disp_x  # [N,D,D]
+        y_grid = cy[:, None, None] + disp_y  # [N,D,D]
+        valid = (x_grid >= 0) & (x_grid < W) & (y_grid >= 0) & (y_grid < H)
+        x_valid = x_grid[valid]  # [K]
+        y_valid = y_grid[valid]
+        icon_weights = icon.expand(N, D, D)[valid]  # [K]
+        colors_valid = colors_now[:, :, None, None].expand(N, 3, D, D).permute(1, 0, 2, 3)[
+            :, valid
+        ]  # [3, K]
+        idx_flat = (y_valid * W + x_valid).long()  # [K]
+
+        accum = torch.zeros_like(frames_t[t])  # [3, H, W]
+        weight = torch.zeros(1, H * W, device=device)  # [1, H*W]
+        img_flat = accum.view(C, -1)  # [3, H*W]
+        weighted_colors = colors_valid * icon_weights  # [3, K]
+        img_flat.scatter_add_(1, idx_flat.unsqueeze(0).expand(C, -1), weighted_colors)
+        weight.scatter_add_(1, idx_flat.unsqueeze(0), icon_weights.unsqueeze(0))
+        weight = weight.view(1, H, W)
+        # accum = accum / (weight + 1e-6) # avoid division by 0
+        # frames_t[t] = torch.where(weight > 0, accum, frames_t[t])
+        # frames_t[t] = frames_t[t] * (1 - weight) + accum
+
+        # alpha = weight.clamp(0, 1)
+        alpha = weight.clamp(0, 1) * 0.75  # transparency
+        accum = accum / (weight + 1e-6)  # [3, H, W]
+        frames_t[t] = frames_t[t] * (1 - alpha) + accum * alpha
+
+        # img_flat = frames_t[t].view(C, -1) # [3, H*W]
+        # weighted_colors = colors_valid * icon_weights # [3, K]
+        # img_flat.scatter_add_(1, idx_flat.unsqueeze(0).expand(C, -1), weighted_colors)
+    print('done vis')
+    return frames_t.clamp(0, 255).byte().permute(0, 2, 3, 1).cpu().numpy()
+
+def paint_point_track_gpu(
+    frames: np.ndarray,
+    point_tracks: np.ndarray,
+    visibles: np.ndarray,
+    colormap: Optional[List[Tuple[int, int, int]]] = None,
+    radius: int = 2,
+    sharpness: float = 0.15,
+) -> np.ndarray:
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    # Setup
+    frames_t = torch.from_numpy(frames).float().permute(0, 3, 1, 2).to(device)  # [T,C,H,W]
+    point_tracks_t = torch.from_numpy(point_tracks).to(device)  # [P,T,2]
+    visibles_t = torch.from_numpy(visibles).to(device)  # [P,T]
+    T, C, H, W = frames_t.shape
+    P = point_tracks.shape[0]
+
+    # Colors
+    if colormap is None:
+        colormap = get_colors(P)  # or any fixed list of RGB
+    colors = torch.tensor(colormap, dtype=torch.float32, device=device)  # [P,3]
+
+    # Icon kernel [K,K]
+    D = radius * 2 + 1
+    y = torch.arange(D, device=device).float()[:, None] - radius - 1
+    x = torch.arange(D, device=device).float()[None, :] - radius - 1
+    dist2 = x**2 + y**2
+    icon = torch.clamp(1 - (dist2 - (radius**2) / 2.0) / (radius * 2 * sharpness), 0, 1)  # [D,D]
+    icon = icon.unsqueeze(0)  # [1,D,D] for broadcasting
+
+    # Create coordinate grids
+    for t in range(T):
+        image = frames_t[t]
+        # Select visible points
+        visible_mask = visibles_t[:, t]
+        pt_xy = point_tracks_t[visible_mask, t]  # [N,2]
+        colors_t = colors[visible_mask]  # [N,3]
+        N = pt_xy.shape[0]
+        if N == 0:
+            continue
+
+        # Integer centers
+        pt_xy = pt_xy + 0.5  # correct center offset
+        pt_xy[:, 0] = pt_xy[:, 0].clamp(0, W - 1)
+        pt_xy[:, 1] = pt_xy[:, 1].clamp(0, H - 1)
+        ix = pt_xy[:, 0].long()  # [N]
+        iy = pt_xy[:, 1].long()
+
+        # Build grid of indices for patch around each point
+        dx = torch.arange(-radius, radius + 1, device=device)
+        dy = torch.arange(-radius, radius + 1, device=device)
+        dx_grid, dy_grid = torch.meshgrid(dx, dy, indexing='ij')
+        dx_flat = dx_grid.reshape(-1)
+        dy_flat = dy_grid.reshape(-1)
+        patch_x = ix[:, None] + dx_flat[None, :]  # [N,K*K]
+        patch_y = iy[:, None] + dy_flat[None, :]  # [N,K*K]
+
+        # Mask out-of-bounds
+        valid = (patch_x >= 0) & (patch_x < W) & (patch_y >= 0) & (patch_y < H)
+        flat_idx = (patch_y * W + patch_x).long()  # [N,K*K]
+
+        # Flatten icon and colors
+        icon_flat = icon.view(1, -1)  # [1, K*K]
+        color_patches = colors_t[:, :, None] * icon_flat[:, None, :]  # [N,3,K*K]
+
+        # Flatten to write into 1D image
+        img_flat = image.view(C, -1)  # [3, H*W]
+        for i in range(N):
+            valid_mask = valid[i]
+            idxs = flat_idx[i][valid_mask]
+            vals = color_patches[i][:, valid_mask]  # [3, valid_count]
+            img_flat[:, idxs] += vals
+
+    out_frames = frames_t.clamp(0, 255).byte().permute(0, 2, 3, 1).cpu().numpy()
+    return out_frames
+
+
+def paint_point_track_parallel(
+    frames: np.ndarray,
+    point_tracks: np.ndarray,
+    visibles: np.ndarray,
+    colormap: Optional[List[Tuple[int, int, int]]] = None,
+    max_workers: int = 8,
+) -> np.ndarray:
+    num_points, num_frames = point_tracks.shape[:2]
+    if colormap is None:
+        colormap = get_colors(num_colors=num_points)
+    height, width = frames.shape[1:3]
+    radius = 2
+    print('radius', radius)
+    diam = radius * 2 + 1
+    # Precompute the icon and its bilinear components
+    quadratic_y = np.square(np.arange(diam)[:, np.newaxis] - radius - 1)
+    quadratic_x = np.square(np.arange(diam)[np.newaxis, :] - radius - 1)
+    icon = (quadratic_y + quadratic_x) - (radius**2) / 2.0
+    sharpness = 0.15
+    icon = np.clip(icon / (radius * 2 * sharpness), 0, 1)
+    icon = 1 - icon[:, :, np.newaxis]
+    icon1 = np.pad(icon, [(0, 1), (0, 1), (0, 0)])
+    icon2 = np.pad(icon, [(1, 0), (0, 1), (0, 0)])
+    icon3 = np.pad(icon, [(0, 1), (1, 0), (0, 0)])
+    icon4 = np.pad(icon, [(1, 0), (1, 0), (0, 0)])
+
+    def draw_point(image, i, t):
+        if not visibles[i, t]:
+            return
+        x, y = point_tracks[i, t, :] + 0.5
+        x = min(max(x, 0.0), width)
+        y = min(max(y, 0.0), height)
+        x1, y1 = np.floor(x).astype(np.int32), np.floor(y).astype(np.int32)
+        x2, y2 = x1 + 1, y1 + 1
+        patch = (
+            icon1 * (x2 - x) * (y2 - y)
+            + icon2 * (x2 - x) * (y - y1)
+            + icon3 * (x - x1) * (y2 - y)
+            + icon4 * (x - x1) * (y - y1)
+        )
+        x_ub = x1 + 2 * radius + 2
+        y_ub = y1 + 2 * radius + 2
+        image[y1:y_ub, x1:x_ub, :] = (1 - patch) * image[y1:y_ub, x1:x_ub, :] + patch * np.array(colormap[i])[np.newaxis, np.newaxis, :]
+
+    video = frames.copy()
+    for t in range(num_frames):
+        image = np.pad(
+            video[t],
+            [(radius + 1, radius + 1), (radius + 1, radius + 1), (0, 0)],
+        )
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = [executor.submit(draw_point, image, i, t) for i in range(num_points)]
+            _ = [f.result() for f in futures]  # wait for all threads
+        video[t] = image[radius + 1 : -radius - 1, radius + 1 : -radius - 1].astype(np.uint8)
+
+    return video
+
+
 def paint_point_track(
     frames: np.ndarray,
     point_tracks: np.ndarray,
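A minimal sketch of how the new scatter-based painter can be exercised on dummy data, assuming it runs inside app.py where get_colors is in scope; shapes follow the comments above (frames [T,H,W,3] uint8, tracks [P,T,2] float, visibility [P,T] bool):

import numpy as np

T, H, W, P = 8, 64, 96, 5  # frames, height, width, points
frames = np.zeros((T, H, W, 3), dtype=np.uint8)
rng = np.random.default_rng(0)
point_tracks = rng.uniform([0, 0], [W - 1, H - 1], size=(P, T, 2)).astype(np.float32)
visibles = np.ones((P, T), dtype=bool)

painted = paint_point_track_gpu_scatter(frames, point_tracks, visibles)
assert painted.shape == (T, H, W, 3) and painted.dtype == np.uint8

The same positional arguments work for paint_point_track_gpu and paint_point_track_parallel; only the scatter version is actually called by track below.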
@@ -330,7 +537,8 @@ def preprocess_video_input(video_path):
 def track(
     video_preview,
     video_input,
-    video_fps,
+    video_fps,
+    query_frame,
     query_points,
     query_points_color,
     query_count,
@@ -338,6 +546,10 @@ def track(
     # tracking_mode = 'selected'
     # if query_count == 0:
     #     tracking_mode = 'grid'
+
+    # print('query_frames', query_frames)
+    # query_frame = int(query_frames[0])
+    # # query_frame = 0
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
     dtype = torch.float if device == "cuda" else torch.float
# add_support_grid=True
|
620 |
|
621 |
|
622 |
+
# query_frame = 0
|
623 |
|
624 |
torch.cuda.empty_cache()
|
625 |
|
|
|
@@ -444,11 +656,17 @@ def track(
     # colors.extend(frame_colors)
     # colors = np.array(colors)
 
-    traj_maps_e = traj_maps_e[:,:,:,::4,::4] # subsample
-    visconf_maps_e = visconf_maps_e[:,:,:,::4,::4] # subsample
+    # traj_maps_e = traj_maps_e[:,:,:,::4,::4] # subsample
+    # visconf_maps_e = visconf_maps_e[:,:,:,::4,::4] # subsample
+    traj_maps_e = traj_maps_e[:,:,:,::2,::2] # subsample
+    visconf_maps_e = visconf_maps_e[:,:,:,::2,::2] # subsample
 
     tracks = traj_maps_e.permute(0,3,4,1,2).reshape(-1,T,2).numpy()
-    visibs = visconf_maps_e.permute(0,3,4,1,2).reshape(-1,T,2)[:,:,0].numpy()
+    visibs = visconf_maps_e.permute(0,3,4,1,2).reshape(-1,T,2)[:,:,0].numpy()
+    confs = visconf_maps_e.permute(0,3,4,1,2).reshape(-1,T,2)[:,:,1].numpy()
+
+    visibs = (visibs * confs) > 0.9 # N,T
+
 
     # sc = (np.array([video_preview.shape[2], video_preview.shape[1]]) / np.array([VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]])).reshape(1,1,2)
     # print('sc', sc)
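The gating above is stricter than it may look, since visibility and confidence are multiplied before thresholding. A hedged illustration with made-up values:

import numpy as np

vis  = np.array([[0.99, 0.95, 0.40]])  # one track over three frames
conf = np.array([[0.98, 0.90, 0.99]])
print((vis * conf) > 0.9)  # [[ True False False]] -- 0.95 * 0.90 = 0.855 misses the cut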
@@ -467,8 +685,15 @@ def track(
     for frame_colors in query_points_color:
         colors.extend(frame_colors)
     colors = np.array(colors)
+
+    inds = np.sum(visibs * 1.0, axis=1) >= min(T//4,3)
+    tracks = tracks[inds]
+    visibs = visibs[inds]
+    colors = colors[inds]
 
-    painted_video = paint_point_track(video_preview,tracks,visibs,colors)
+    # painted_video = paint_point_track_parallel(video_preview,tracks,visibs,colors)
+    # painted_video = paint_point_track_gpu(video_preview,tracks,visibs,colors)
+    painted_video = paint_point_track_gpu_scatter(video_preview,tracks,visibs,colors)
     print("7 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
 
     # save video
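A small illustration of the new filter with made-up data: a track must be visible in at least min(T // 4, 3) frames to survive to the painting stage, which keeps the painter from spending work on points that barely appear:

import numpy as np

T = 16  # threshold is min(16 // 4, 3) = 3 frames
visibs = np.array([
    [1, 1, 1, 1] + [0] * 12,  # visible in 4 frames -> kept
    [1, 0, 0, 0] + [0] * 12,  # visible in 1 frame  -> dropped
], dtype=bool)
keep = np.sum(visibs * 1.0, axis=1) >= min(T // 4, 3)
print(keep)  # [ True False]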
@@ -546,7 +771,8 @@ with gr.Blocks() as demo:
 
         with gr.Row():
             current_frame = gr.Image(
-                label="Click to add query points",
+                # label="Click to add query points",
+                label="Query frame",
                 type="numpy",
                 interactive=False
             )
@@ -679,6 +905,7 @@ with gr.Blocks() as demo:
             video_preview,
             video_input,
             video_fps,
+            query_frames,
             query_points,
             query_points_color,
             query_count,