# this app is built from https://huggingface.co/spaces/facebook/cotracker/blob/main/app.py
# which was built from https://github.com/cvlab-kaist/locotrack/blob/main/demo/demo.py
import os
import sys
import uuid
from concurrent.futures import ThreadPoolExecutor
import subprocess
from nets.blocks import InputPadder
import gradio as gr
import mediapy
import numpy as np
import cv2
import matplotlib
import torch
import colorsys
import random
from typing import List, Optional, Sequence, Tuple
import spaces
import utils.basic
import utils.improc
import PIL.Image
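
# App overview: preprocess_video_input() loads and resizes an uploaded video,
# track() runs AllTracker to estimate a dense trajectory map, paint_video()
# renders the tracks (via paint_point_track_gpu_scatter) and encodes an mp4,
# and the gr.Blocks section at the bottom wires these functions into the UI.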

# Generate random colormaps for visualizing different points.
def get_colors(num_colors: int) -> List[Tuple[int, int, int]]:
    """Gets colormap for points."""
    colors = []
    for i in np.arange(0.0, 360.0, 360.0 / num_colors):
        hue = i / 360.0
        lightness = (50 + np.random.rand() * 10) / 100.0
        saturation = (90 + np.random.rand() * 10) / 100.0
        color = colorsys.hls_to_rgb(hue, lightness, saturation)
        colors.append(
            (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
        )
    random.shuffle(colors)
    return colors
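
# Note: the hues are evenly spaced and then shuffled, so consecutively indexed
# points do not receive neighboring hues. get_colors() is the default colormap
# for the paint_point_track* functions below; paint_video() passes gist_rainbow
# colors instead.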

# def get_points_on_a_grid(
#     size: int,
#     extent: Tuple[float, ...],
#     center: Optional[Tuple[float, ...]] = None,
#     device: Optional[torch.device] = torch.device("cpu"),
# ):
#     r"""Get a grid of points covering a rectangular region
#     `get_points_on_a_grid(size, extent)` generates a :attr:`size` by
#     :attr:`size` grid of points distributed to cover a rectangular area
#     specified by `extent`.
#     The `extent` is a pair of integer :math:`(H,W)` specifying the height
#     and width of the rectangle.
#     Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
#     specifying the vertical and horizontal center coordinates. The center
#     defaults to the middle of the extent.
#     Points are distributed uniformly within the rectangle leaving a margin
#     :math:`m=W/64` from the border.
#     It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
#     points :math:`P_{ij}=(x_i, y_i)` where
#     .. math::
#         P_{ij} = \left(
#             c_x + m - \frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
#             c_y + m - \frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
#         \right)
#     Points are returned in row-major order.
#     Args:
#         size (int): grid size.
#         extent (tuple): height and width of the grid extent.
#         center (tuple, optional): grid center.
#         device (str, optional): Defaults to `"cpu"`.
#     Returns:
#         Tensor: grid.
#     """
#     if size == 1:
#         return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
#     if center is None:
#         center = [extent[0] / 2, extent[1] / 2]
#     margin = extent[1] / 64
#     range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
#     range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
#     grid_y, grid_x = torch.meshgrid(
#         torch.linspace(*range_y, size, device=device),
#         torch.linspace(*range_x, size, device=device),
#         indexing="ij",
#     )
#     return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)

def paint_point_track_gpu_scatter(
    frames: np.ndarray,
    point_tracks: np.ndarray,
    visibles: np.ndarray,
    colormap: Optional[List[Tuple[int, int, int]]] = None,
    rate: int = 1,
    # sharpness: float = 0.1,
) -> np.ndarray:
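    """Paints dense point tracks onto video frames with GPU scatter-add.

    Args:
      frames: [T, H, W, 3] uint8 video.
      point_tracks: [P, T, 2] float (x, y) track coordinates.
      visibles: [P, T] bool visibility mask.
      colormap: optional list of per-point RGB tuples; random if None.
      rate: spatial subsampling rate of the tracks; higher rates get
        larger, softer dots.

    Returns:
      [T, H, W, 3] uint8 video with the tracks painted.
    """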
    print('starting vis')
    device = "cuda" if torch.cuda.is_available() else "cpu"
    frames_t = torch.from_numpy(frames).float().permute(0, 3, 1, 2).to(device)  # [T,C,H,W]
    frames_t = frames_t * 0.5  # darken, to see the point tracks better
    point_tracks_t = torch.from_numpy(point_tracks).to(device)  # [P,T,2]
    visibles_t = torch.from_numpy(visibles).to(device)  # [P,T]
    T, C, H, W = frames_t.shape
    P = point_tracks.shape[0]

    if colormap is None:
        colormap = get_colors(P)
    colors = torch.tensor(colormap, dtype=torch.float32, device=device)  # [P,3]

    # Dot radius grows with the subsampling rate, so sparser tracks stay visible.
    if rate == 1:
        radius = 1
    elif rate == 2:
        radius = 1
    elif rate == 4:
        radius = 2
    elif rate == 8:
        radius = 4
    else:
        radius = 6
    # radius = max(1, int(np.sqrt(rate)))
    sharpness = 0.15 + 0.05 * np.log2(rate)

    D = radius * 2 + 1
    y = torch.arange(D, device=device).float()[:, None] - radius
    x = torch.arange(D, device=device).float()[None, :] - radius
    dist2 = x**2 + y**2
    icon = torch.clamp(1 - (dist2 - (radius**2) / 2.0) / (radius * 2 * sharpness), 0, 1)  # [D,D]
    icon = icon.view(1, D, D)
    dx = torch.arange(-radius, radius + 1, device=device)
    dy = torch.arange(-radius, radius + 1, device=device)
    disp_y, disp_x = torch.meshgrid(dy, dx, indexing="ij")  # [D,D]
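
    # Per frame: gather the D x D patch coordinates around each visible point,
    # scatter-add color and weight contributions into flat [H*W] buffers, then
    # alpha-composite the normalized colors over the (darkened) frame.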
    for t in range(T):
        mask = visibles_t[:, t]  # [P]
        if mask.sum() == 0:
            continue
        xy = point_tracks_t[mask, t] + 0.5  # [N,2]
        xy[:, 0] = xy[:, 0].clamp(0, W - 1)
        xy[:, 1] = xy[:, 1].clamp(0, H - 1)
        colors_now = colors[mask]  # [N,3]
        N = xy.shape[0]
        cx = xy[:, 0].long()  # [N]
        cy = xy[:, 1].long()
        x_grid = cx[:, None, None] + disp_x  # [N,D,D]
        y_grid = cy[:, None, None] + disp_y  # [N,D,D]
        valid = (x_grid >= 0) & (x_grid < W) & (y_grid >= 0) & (y_grid < H)
        x_valid = x_grid[valid]  # [K]
        y_valid = y_grid[valid]
        icon_weights = icon.expand(N, D, D)[valid]  # [K]
        colors_valid = colors_now[:, :, None, None].expand(N, 3, D, D).permute(1, 0, 2, 3)[
            :, valid
        ]  # [3, K]
        idx_flat = (y_valid * W + x_valid).long()  # [K]

        accum = torch.zeros_like(frames_t[t])  # [3, H, W]
        weight = torch.zeros(1, H * W, device=device)  # [1, H*W]
        img_flat = accum.view(C, -1)  # [3, H*W]
        weighted_colors = colors_valid * icon_weights  # [3, K]
        img_flat.scatter_add_(1, idx_flat.unsqueeze(0).expand(C, -1), weighted_colors)
        weight.scatter_add_(1, idx_flat.unsqueeze(0), icon_weights.unsqueeze(0))
        weight = weight.view(1, H, W)
        # accum = accum / (weight + 1e-6)  # avoid division by 0
        # frames_t[t] = torch.where(weight > 0, accum, frames_t[t])
        # frames_t[t] = frames_t[t] * (1 - weight) + accum
        # alpha = weight.clamp(0, 1)
        # alpha = weight.clamp(0, 1) * 0.9  # transparency
        alpha = weight.clamp(0, 1)  # transparency
        accum = accum / (weight + 1e-6)  # [3, H, W]
        frames_t[t] = frames_t[t] * (1 - alpha) + accum * alpha
        # img_flat = frames_t[t].view(C, -1)  # [3, H*W]
        # weighted_colors = colors_valid * icon_weights  # [3, K]
        # img_flat.scatter_add_(1, idx_flat.unsqueeze(0).expand(C, -1), weighted_colors)
    print('done vis')
    return frames_t.clamp(0, 255).byte().permute(0, 2, 3, 1).cpu().numpy()

def paint_point_track_gpu(
    frames: np.ndarray,
    point_tracks: np.ndarray,
    visibles: np.ndarray,
    colormap: Optional[List[Tuple[int, int, int]]] = None,
    radius: int = 2,
    sharpness: float = 0.15,
) -> np.ndarray:
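    """GPU variant that splats a soft disc per point, looping over points
    within each frame. The app uses paint_point_track_gpu_scatter instead;
    this simpler version is kept for reference.
    """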
device = "cuda" if torch.cuda.is_available() else "cpu" | |
# Setup | |
frames_t = torch.from_numpy(frames).float().permute(0, 3, 1, 2).to(device) # [T,C,H,W] | |
point_tracks_t = torch.from_numpy(point_tracks).to(device) # [P,T,2] | |
visibles_t = torch.from_numpy(visibles).to(device) # [P,T] | |
T, C, H, W = frames_t.shape | |
P = point_tracks.shape[0] | |
# Colors | |
if colormap is None: | |
colormap = get_colors(P) # or any fixed list of RGB | |
colors = torch.tensor(colormap, dtype=torch.float32, device=device) # [P,3] | |
# Icon kernel [K,K] | |
D = radius * 2 + 1 | |
y = torch.arange(D, device=device).float()[:, None] - radius - 1 | |
x = torch.arange(D, device=device).float()[None, :] - radius - 1 | |
dist2 = x**2 + y**2 | |
icon = torch.clamp(1 - (dist2 - (radius**2) / 2.0) / (radius * 2 * sharpness), 0, 1) # [D,D] | |
icon = icon.unsqueeze(0) # [1,D,D] for broadcasting | |
# Create coordinate grids | |
for t in range(T): | |
image = frames_t[t] | |
# Select visible points | |
visible_mask = visibles_t[:, t] | |
pt_xy = point_tracks_t[visible_mask, t] # [N,2] | |
colors_t = colors[visible_mask] # [N,3] | |
N = pt_xy.shape[0] | |
if N == 0: | |
continue | |
# Integer centers | |
pt_xy = pt_xy + 0.5 # correct center offset | |
pt_xy[:, 0] = pt_xy[:, 0].clamp(0, W - 1) | |
pt_xy[:, 1] = pt_xy[:, 1].clamp(0, H - 1) | |
ix = pt_xy[:, 0].long() # [N] | |
iy = pt_xy[:, 1].long() | |
# Build grid of indices for patch around each point | |
dx = torch.arange(-radius, radius + 1, device=device) | |
dy = torch.arange(-radius, radius + 1, device=device) | |
dx_grid, dy_grid = torch.meshgrid(dx, dy, indexing='ij') | |
dx_flat = dx_grid.reshape(-1) | |
dy_flat = dy_grid.reshape(-1) | |
patch_x = ix[:, None] + dx_flat[None, :] # [N,K*K] | |
patch_y = iy[:, None] + dy_flat[None, :] # [N,K*K] | |
# Mask out-of-bounds | |
valid = (patch_x >= 0) & (patch_x < W) & (patch_y >= 0) & (patch_y < H) | |
flat_idx = (patch_y * W + patch_x).long() # [N,K*K] | |
# Flatten icon and colors | |
icon_flat = icon.view(1, -1) # [1, K*K] | |
color_patches = colors_t[:, :, None] * icon_flat[:, None, :] # [N,3,K*K] | |
# Flatten to write into 1D image | |
img_flat = image.view(C, -1) # [3, H*W] | |
for i in range(N): | |
valid_mask = valid[i] | |
idxs = flat_idx[i][valid_mask] | |
vals = color_patches[i, :, valid_mask] # [3, valid_count] | |
img_flat[:, idxs] += vals | |
out_frames = frames_t.clamp(0, 255).byte().permute(0, 2, 3, 1).cpu().numpy() | |
return out_frames | |

def paint_point_track_parallel(
    frames: np.ndarray,
    point_tracks: np.ndarray,
    visibles: np.ndarray,
    colormap: Optional[List[Tuple[int, int, int]]] = None,
    max_workers: int = 8,
) -> np.ndarray:
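    """CPU variant: same bilinear dot rendering as paint_point_track, with the
    per-point splats dispatched across a thread pool, one frame at a time.
    """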
    num_points, num_frames = point_tracks.shape[:2]
    if colormap is None:
        colormap = get_colors(num_colors=num_points)
    height, width = frames.shape[1:3]
    radius = 1
    print('radius', radius)
    diam = radius * 2 + 1

    # Precompute the icon and its bilinear components
    quadratic_y = np.square(np.arange(diam)[:, np.newaxis] - radius - 1)
    quadratic_x = np.square(np.arange(diam)[np.newaxis, :] - radius - 1)
    icon = (quadratic_y + quadratic_x) - (radius**2) / 2.0
    sharpness = 0.15
    icon = np.clip(icon / (radius * 2 * sharpness), 0, 1)
    icon = 1 - icon[:, :, np.newaxis]
    icon1 = np.pad(icon, [(0, 1), (0, 1), (0, 0)])
    icon2 = np.pad(icon, [(1, 0), (0, 1), (0, 0)])
    icon3 = np.pad(icon, [(0, 1), (1, 0), (0, 0)])
    icon4 = np.pad(icon, [(1, 0), (1, 0), (0, 0)])

    def draw_point(image, i, t):
        if not visibles[i, t]:
            return
        x, y = point_tracks[i, t, :] + 0.5
        x = min(max(x, 0.0), width)
        y = min(max(y, 0.0), height)
        x1, y1 = np.floor(x).astype(np.int32), np.floor(y).astype(np.int32)
        x2, y2 = x1 + 1, y1 + 1
        patch = (
            icon1 * (x2 - x) * (y2 - y)
            + icon2 * (x2 - x) * (y - y1)
            + icon3 * (x - x1) * (y2 - y)
            + icon4 * (x - x1) * (y - y1)
        )
        x_ub = x1 + 2 * radius + 2
        y_ub = y1 + 2 * radius + 2
        image[y1:y_ub, x1:x_ub, :] = (1 - patch) * image[y1:y_ub, x1:x_ub, :] + patch * np.array(colormap[i])[np.newaxis, np.newaxis, :]

    video = frames.copy()
    for t in range(num_frames):
        image = np.pad(
            video[t],
            [(radius + 1, radius + 1), (radius + 1, radius + 1), (0, 0)],
        )
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(draw_point, image, i, t) for i in range(num_points)]
            _ = [f.result() for f in futures]  # wait for all threads
        video[t] = image[radius + 1 : -radius - 1, radius + 1 : -radius - 1].astype(np.uint8)
    return video

def paint_point_track(
    frames: np.ndarray,
    point_tracks: np.ndarray,
    visibles: np.ndarray,
    colormap: Optional[List[Tuple[int, int, int]]] = None,
) -> np.ndarray:
    """Converts a sequence of points to color code video.

    Args:
      frames: [num_frames, height, width, 3], np.uint8, [0, 255]
      point_tracks: [num_points, num_frames, 2], np.float32, [0, width / height]
      visibles: [num_points, num_frames], bool
      colormap: colormap for points, each point has a different RGB color.

    Returns:
      video: [num_frames, height, width, 3], np.uint8, [0, 255]
    """
    num_points, num_frames = point_tracks.shape[0:2]
    if colormap is None:
        colormap = get_colors(num_colors=num_points)
    height, width = frames.shape[1:3]
    dot_size_as_fraction_of_min_edge = 0.015
    # radius = int(round(min(height, width) * dot_size_as_fraction_of_min_edge))
    radius = 2
    # print('radius', radius)
    diam = radius * 2 + 1
    quadratic_y = np.square(np.arange(diam)[:, np.newaxis] - radius - 1)
    quadratic_x = np.square(np.arange(diam)[np.newaxis, :] - radius - 1)
    icon = (quadratic_y + quadratic_x) - (radius**2) / 2.0
    sharpness = 0.15
    icon = np.clip(icon / (radius * 2 * sharpness), 0, 1)
    icon = 1 - icon[:, :, np.newaxis]
    icon1 = np.pad(icon, [(0, 1), (0, 1), (0, 0)])
    icon2 = np.pad(icon, [(1, 0), (0, 1), (0, 0)])
    icon3 = np.pad(icon, [(0, 1), (1, 0), (0, 0)])
    icon4 = np.pad(icon, [(1, 0), (1, 0), (0, 0)])
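    # icon1..icon4 are copies of the icon shifted by one pixel in y and x
    # within a (diam+1)-sized support; blending them with bilinear weights
    # renders a dot at a fractional pixel position.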
    video = frames.copy()
    for t in range(num_frames):
        # Pad so that points that extend outside the image frame don't crash us
        image = np.pad(
            video[t],
            [
                (radius + 1, radius + 1),
                (radius + 1, radius + 1),
                (0, 0),
            ],
        )
        for i in range(num_points):
            # The icon is centered at the center of a pixel, but the input coordinates
            # are raster coordinates. Therefore, to render a point at (1,1) (which
            # lies on the corner between four pixels), we need 1/4 of the icon placed
            # centered on the 0'th row, 0'th column, etc. We need to subtract
            # 0.5 to make the fractional position come out right.
            x, y = point_tracks[i, t, :] + 0.5
            x = min(max(x, 0.0), width)
            y = min(max(y, 0.0), height)
            if visibles[i, t]:
                x1, y1 = np.floor(x).astype(np.int32), np.floor(y).astype(np.int32)
                x2, y2 = x1 + 1, y1 + 1
                # bilinear interpolation
                patch = (
                    icon1 * (x2 - x) * (y2 - y)
                    + icon2 * (x2 - x) * (y - y1)
                    + icon3 * (x - x1) * (y2 - y)
                    + icon4 * (x - x1) * (y - y1)
                )
                x_ub = x1 + 2 * radius + 2
                y_ub = y1 + 2 * radius + 2
                image[y1:y_ub, x1:x_ub, :] = (1 - patch) * image[
                    y1:y_ub, x1:x_ub, :
                ] + patch * np.array(colormap[i])[np.newaxis, np.newaxis, :]
        # Remove the pad
        video[t] = image[
            radius + 1 : -radius - 1, radius + 1 : -radius - 1
        ].astype(np.uint8)
    return video

PREVIEW_WIDTH = 1024  # Width of the preview video
PREVIEW_HEIGHT = 1024  # Height of the preview video
# VIDEO_INPUT_RESO = (384, 512)  # Resolution of the input video
POINT_SIZE = 1  # Size of the query point in the preview video
FRAME_LIMIT = 600  # Limit the number of frames to process

# def get_point(frame_num, video_queried_preview, query_points, query_points_color, query_count, evt: gr.SelectData):
#     print(f"You selected {(evt.index[0], evt.index[1], frame_num)}")
#     current_frame = video_queried_preview[int(frame_num)]
#     # Get the mouse click
#     query_points[int(frame_num)].append((evt.index[0], evt.index[1], frame_num))
#     # Choose the color for the point from matplotlib colormap
#     color = matplotlib.colormaps.get_cmap("gist_rainbow")(query_count % 20 / 20)
#     color = (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
#     # print(f"Color: {color}")
#     query_points_color[int(frame_num)].append(color)
#     # Draw the point on the frame
#     x, y = evt.index
#     current_frame_draw = cv2.circle(current_frame, (x, y), POINT_SIZE, color, -1)
#     # Update the frame
#     video_queried_preview[int(frame_num)] = current_frame_draw
#     # Update the query count
#     query_count += 1
#     return (
#         current_frame_draw,  # Updated frame for preview
#         video_queried_preview,  # Updated preview video
#         query_points,  # Updated query points
#         query_points_color,  # Updated query points color
#         query_count  # Updated query count
#     )

# def undo_point(frame_num, video_preview, video_queried_preview, query_points, query_points_color, query_count):
#     if len(query_points[int(frame_num)]) == 0:
#         return (
#             video_queried_preview[int(frame_num)],
#             video_queried_preview,
#             query_points,
#             query_points_color,
#             query_count
#         )
#     # Get the last point
#     query_points[int(frame_num)].pop(-1)
#     query_points_color[int(frame_num)].pop(-1)
#     # Redraw the frame
#     current_frame_draw = video_preview[int(frame_num)].copy()
#     for point, color in zip(query_points[int(frame_num)], query_points_color[int(frame_num)]):
#         x, y, _ = point
#         current_frame_draw = cv2.circle(current_frame_draw, (x, y), POINT_SIZE, color, -1)
#     # Update the query count
#     query_count -= 1
#     # Update the frame
#     video_queried_preview[int(frame_num)] = current_frame_draw
#     return (
#         current_frame_draw,  # Updated frame for preview
#         video_queried_preview,  # Updated preview video
#         query_points,  # Updated query points
#         query_points_color,  # Updated query points color
#         query_count  # Updated query count
#     )

# def clear_frame_fn(frame_num, video_preview, video_queried_preview, query_points, query_points_color, query_count):
#     query_count -= len(query_points[int(frame_num)])
#     query_points[int(frame_num)] = []
#     query_points_color[int(frame_num)] = []
#     video_queried_preview[int(frame_num)] = video_preview[int(frame_num)].copy()
#     return (
#         video_preview[int(frame_num)],  # Set the preview frame to the original frame
#         video_queried_preview,
#         query_points,  # Cleared query points
#         query_points_color,  # Cleared query points color
#         query_count  # New query count
#     )

# def clear_all_fn(frame_num, video_preview):
#     return (
#         video_preview[int(frame_num)],
#         video_preview.copy(),
#         [[] for _ in range(len(video_preview))],
#         [[] for _ in range(len(video_preview))],
#         0
#     )

def choose_frame(frame_num, video_preview_array):
    return video_preview_array[int(frame_num)]

def choose_rate1(video_preview, video_fps, tracks, visibs):
    return choose_rate(1, video_preview, video_fps, tracks, visibs)

def choose_rate2(video_preview, video_fps, tracks, visibs):
    return choose_rate(2, video_preview, video_fps, tracks, visibs)

def choose_rate4(video_preview, video_fps, tracks, visibs):
    return choose_rate(4, video_preview, video_fps, tracks, visibs)

def choose_rate8(video_preview, video_fps, tracks, visibs):
    return choose_rate(8, video_preview, video_fps, tracks, visibs)

# def choose_rate16(video_preview, video_fps, tracks, visibs):
#     return choose_rate(16, video_preview, video_fps, tracks, visibs)

def choose_rate(rate, video_preview, video_fps, tracks, visibs):
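    # The tracks produced by track() are dense (one per preview pixel, in
    # row-major order), so they can be reshaped onto the (H, W) grid and
    # strided by `rate` before repainting.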
    print('rate', rate)
    print('video_preview', video_preview.shape)
    T, H, W, _ = video_preview.shape
    tracks_ = tracks.reshape(H, W, T, 2)[::rate, ::rate].reshape(-1, T, 2)
    visibs_ = visibs.reshape(H, W, T)[::rate, ::rate].reshape(-1, T)
    return paint_video(video_preview, video_fps, tracks_, visibs_, rate=rate)
    # return video_preview_array[int(frame_num)]

def preprocess_video_input(video_path):
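    """Reads the video, truncates it to FRAME_LIMIT frames, and resizes it so
    the long edge is at most PREVIEW_WIDTH (3/4 of that for inputs larger than
    768x1024), rounded down to a multiple of 16. Returns the tuple of values
    expected by the submit.click() outputs below.
    """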
    video_arr = mediapy.read_video(video_path)
    video_fps = video_arr.metadata.fps
    num_frames = video_arr.shape[0]
    if num_frames > FRAME_LIMIT:
        gr.Warning(f"The video is too long. Only the first {FRAME_LIMIT} frames will be used.", duration=5)
        video_arr = video_arr[:FRAME_LIMIT]
        num_frames = FRAME_LIMIT

    height, width = video_arr.shape[1:3]
    if height > width:
        new_height, new_width = PREVIEW_HEIGHT, int(PREVIEW_HEIGHT * width / height)
    else:
        new_height, new_width = int(PREVIEW_WIDTH * height / width), PREVIEW_WIDTH
    if height * width > 768 * 1024:
        new_height = new_height * 3 // 4
        new_width = new_width * 3 // 4
    new_height, new_width = new_height // 16 * 16, new_width // 16 * 16  # make it divisible by 16, partly to satisfy ffmpeg
    preview_video = mediapy.resize_video(video_arr, (new_height, new_width))
    # input_video = mediapy.resize_video(video_arr, VIDEO_INPUT_RESO)
    # input_video = video_arr
    input_video = preview_video

    preview_video = np.array(preview_video)
    input_video = np.array(input_video)

    interactive = True
    return (
        video_arr,  # Original video
        preview_video,  # Original preview video, resized for faster processing
        preview_video.copy(),  # Copy of preview video for visualization
        input_video,  # Resized video input for model
        # None,  # video_feature,  # Extracted feature
        video_fps,  # Set the video FPS
        # gr.update(open=True),  # open/close the video input drawer
        # tracking_mode,  # Set the tracking mode
        preview_video[0],  # Set the preview frame to the first frame
        gr.update(minimum=0, maximum=num_frames - 1, value=0, interactive=interactive),  # Set slider interactive
        [[] for _ in range(num_frames)],  # Set query_points to empty
        [[] for _ in range(num_frames)],  # Set query_points_color to empty
        [[] for _ in range(num_frames)],  # Set is_tracked_query to empty
        0,  # Set query count to 0
        gr.update(interactive=True),  # Make the track button interactive
        # these matched the undo/clear buttons, which are commented out in the outputs list:
        # gr.update(interactive=interactive),
        # gr.update(interactive=interactive),
        # gr.update(interactive=interactive),
    )

def paint_video(video_preview, video_fps, tracks, visibs, rate=1):
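    """Colors each track with matplotlib's gist_rainbow colormap, paints the
    tracks onto the preview frames, writes the frames as JPEGs, encodes them
    into an mp4 with ffmpeg, and returns the mp4 path.
    """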
    print('video_preview', video_preview.shape)
    T, H, W, _ = video_preview.shape
    query_count = tracks.shape[0]

    cmap = matplotlib.colormaps.get_cmap("gist_rainbow")
    query_points_color = [[]]
    for i in range(query_count):
        # Choose the color for the point from matplotlib colormap
        color = cmap(i / float(query_count))
        color = (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
        query_points_color[0].append(color)

    # make color array
    colors = []
    for frame_colors in query_points_color:
        colors.extend(frame_colors)
    colors = np.array(colors)

    painted_video = paint_point_track_gpu_scatter(video_preview, tracks, visibs, colors, rate=rate)  # rate=max(rate//2,1)

    # save video
    video_file_name = uuid.uuid4().hex + ".mp4"
    video_path = os.path.join(os.path.dirname(__file__), "tmp")
    video_file_path = os.path.join(video_path, video_file_name)
    os.makedirs(video_path, exist_ok=True)

    if False:
        mediapy.write_video(video_file_path, painted_video, fps=video_fps)
    else:
        for ti in range(T):
            temp_out_f = '%s/%03d.jpg' % (video_path, ti)
            # temp_out_f = '%s/%03d.png' % (video_path, ti)
            im = PIL.Image.fromarray(painted_video[ti])
            # im.save(temp_out_f, "PNG", subsampling=0, quality=80)
            im.save(temp_out_f)
            print('saved', temp_out_f)
        # os.system('/usr/bin/ffmpeg -y -hide_banner -loglevel error -f image2 -framerate %d -pattern_type glob -i "%s/*.png" -c:v libx264 -crf 20 -pix_fmt yuv420p %s' % (video_fps, video_path, video_file_path))
        os.system('/usr/bin/ffmpeg -y -hide_banner -loglevel error -f image2 -framerate %d -pattern_type glob -i "%s/*.jpg" -c:v libx264 -crf 20 -pix_fmt yuv420p %s' % (video_fps, video_path, video_file_path))
        print('saved', video_file_path)
        for ti in range(T):
            # temp_out_f = '%s/%03d.png' % (video_path, ti)
            temp_out_f = '%s/%03d.jpg' % (video_path, ti)
            os.remove(temp_out_f)
            print('deleted', temp_out_f)
    return video_file_path

def track(
    video_preview,
    video_input,
    video_fps,
    query_frame,
    query_points,
    query_points_color,
    query_count,
):
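    """Loads AllTracker, runs dense tracking on video_input from query_frame
    (forward and backward in time), and returns the painted video path, the
    dense tracks [N,T,2], the visibility mask [N,T], and an update that
    enables the subsampling radio.
    """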
    # tracking_mode = 'selected'
    # if query_count == 0:
    #     tracking_mode = 'grid'

    # print('query_frames', query_frames)
    # query_frame = int(query_frames[0])
    # # query_frame = 0
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float

    def print_gpu_mem(tag):
        # only query CUDA memory stats when CUDA is actually available
        if torch.cuda.is_available():
            print("%s torch.cuda.memory_allocated: %.1fGB" % (tag, torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024))

    print_gpu_mem("0")

    # # Convert query points to tensor, normalize to input resolution
    # if tracking_mode != 'grid':
    #     query_points_tensor = []
    #     for frame_points in query_points:
    #         query_points_tensor.extend(frame_points)
    #     query_points_tensor = torch.tensor(query_points_tensor).float()
    #     query_points_tensor *= torch.tensor([
    #         VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0], 1
    #     ]) / torch.tensor([
    #         [video_preview.shape[2], video_preview.shape[1], 1]
    #     ])
    #     query_points_tensor = query_points_tensor[None].flip(-1).to(device, dtype)  # xyt -> tyx
    #     query_points_tensor = query_points_tensor[:, :, [0, 2, 1]]  # tyx -> txy
    # keep the clip on the same device as the model
    video_input = torch.tensor(video_input).unsqueeze(0).to(device, dtype)
    print_gpu_mem("1")

    # model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_online")
    # model = model.to(device)
    from nets.alltracker import Net
    model = Net(16)
    url = "https://huggingface.co/aharley/alltracker/resolve/main/alltracker.pth"
    state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu')
    model.load_state_dict(state_dict['model'], strict=True)
    print('loaded weights from', url)
    model = model.to(device)
    print_gpu_mem("2")

    video_input = video_input.permute(0, 1, 4, 2, 3)
    print('video_input', video_input.shape)
    # model(video_input, iters=4, sw=None, is_training=False)
    # # model(video_chunk=video_input, is_first_step=True, grid_size=0, queries=queries, add_support_grid=add_support_grid)
    _, T, _, H, W = video_input.shape
    utils.basic.print_stats('video_input', video_input)
    print_gpu_mem("3")

    grid_xy = utils.basic.gridcloud2d(1, H, W, norm=False, device='cpu:0').float()  # 1,H*W,2
    grid_xy = grid_xy.permute(0, 2, 1).reshape(1, 1, 2, H, W)  # 1,1,2,H,W
    print_gpu_mem("4")

    # if tracking_mode == 'grid':
    #     xy = get_points_on_a_grid(15, video_input.shape[3:], device=device)
    #     queries = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)
    #     add_support_grid = False
    #     cmap = matplotlib.colormaps.get_cmap("gist_rainbow")
    #     query_points_color = [[]]
    #     query_count = queries.shape[1]
    #     for i in range(query_count):
    #         # Choose the color for the point from matplotlib colormap
    #         color = cmap(i / float(query_count))
    #         color = (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
    #         query_points_color[0].append(color)
    # else:
    #     queries = query_points_tensor
    #     add_support_grid = True
    #     # query_frame = 0

    torch.cuda.empty_cache()
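
    # Bidirectional tracking: run the model forward from query_frame to the
    # end of the video, then on the time-reversed clip that ends at
    # query_frame; flip the backward result in time, drop the duplicated
    # query frame, and concatenate so the trajectory maps cover all T frames.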
    with torch.no_grad():
        utils.basic.print_stats('video_input', video_input)
        if query_frame < T - 1:
            flows_e, visconf_maps_e, _, _ = \
                model(video_input[:, query_frame:], iters=4, sw=None, is_training=False)
            traj_maps_e = flows_e.cpu() + grid_xy  # B,Tf,2,H,W
            visconf_maps_e = visconf_maps_e.cpu()
        else:
            traj_maps_e = torch.zeros((1, 0, 2, H, W), dtype=torch.float32)
            visconf_maps_e = torch.zeros((1, 0, 2, H, W), dtype=torch.float32)
        if query_frame > 0:
            backward_flows_e, backward_visconf_maps_e, _, _ = \
                model(video_input[:, :query_frame + 1].flip([1]), iters=4, sw=None, is_training=False)
            backward_traj_maps_e = backward_flows_e.cpu() + grid_xy  # B,Tb,2,H,W, reversed
            backward_visconf_maps_e = backward_visconf_maps_e.cpu()
            backward_traj_maps_e = backward_traj_maps_e.flip([1])  # flip time
            backward_visconf_maps_e = backward_visconf_maps_e.flip([1])  # flip time
            if query_frame < T - 1:
                backward_traj_maps_e = backward_traj_maps_e[:, :-1]  # drop the overlapped frame
                backward_visconf_maps_e = backward_visconf_maps_e[:, :-1]  # drop the overlapped frame
            traj_maps_e = torch.cat([backward_traj_maps_e, traj_maps_e], dim=1)  # B,T,2,H,W
            visconf_maps_e = torch.cat([backward_visconf_maps_e, visconf_maps_e], dim=1)  # B,T,2,H,W

        # if query_frame < T-1:
        #     flows_e, visconf_maps_e, _, _ = \
        #         model.forward_sliding(video_input[:, query_frame:], iters=4, sw=None, is_training=False)
        #     traj_maps_e = flows_e + grid_xy  # B,Tf,2,H,W
        #     print_gpu_mem("5")
        # else:
        #     traj_maps_e = torch.zeros((1,0,2,H,W), dtype=torch.float32)
        #     visconf_maps_e = torch.zeros((1,0,2,H,W), dtype=torch.float32)
        # if query_frame > 0:
        #     backward_flows_e, backward_visconf_maps_e, _, _ = \
        #         model.forward_sliding(video_input[:, :query_frame+1].flip([1]), iters=4, sw=None, is_training=False)
        #     backward_traj_maps_e = backward_flows_e + grid_xy  # B,Tb,2,H,W, reversed
        #     backward_traj_maps_e = backward_traj_maps_e.flip([1])  # flip time
        #     backward_visconf_maps_e = backward_visconf_maps_e.flip([1])  # flip time
        #     if query_frame < T-1:
        #         backward_traj_maps_e = backward_traj_maps_e[:, :-1]  # drop the overlapped frame
        #         backward_visconf_maps_e = backward_visconf_maps_e[:, :-1]  # drop the overlapped frame
        #     traj_maps_e = torch.cat([backward_traj_maps_e, traj_maps_e], dim=1)  # B,T,2,H,W
        #     visconf_maps_e = torch.cat([backward_visconf_maps_e, visconf_maps_e], dim=1)  # B,T,2,H,W
    print_gpu_mem("6")

    # for ind in range(0, video_input.shape[1] - model.step, model.step):
    #     pred_tracks, pred_visibility = model(
    #         video_chunk=video_input[:, ind : ind + model.step * 2],
    #         grid_size=0,
    #         queries=queries,
    #         add_support_grid=add_support_grid
    #     )  # B T N 2, B T N 1
    # tracks = (pred_tracks * torch.tensor([video_preview.shape[2], video_preview.shape[1]]).to(device) / torch.tensor([VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]]).to(device))[0].permute(1, 0, 2).cpu().numpy()
    # pred_occ = pred_visibility[0].permute(1, 0).cpu().numpy()

    # # make color array
    # colors = []
    # for frame_colors in query_points_color:
    #     colors.extend(frame_colors)
    # colors = np.array(colors)

    # traj_maps_e = traj_maps_e[:,:,:,::4,::4]  # subsample
    # visconf_maps_e = visconf_maps_e[:,:,:,::4,::4]  # subsample
    # traj_maps_e = traj_maps_e[:,:,:,::2,::2]  # subsample
    # visconf_maps_e = visconf_maps_e[:,:,:,::2,::2]  # subsample

    tracks = traj_maps_e.permute(0, 3, 4, 1, 2).reshape(-1, T, 2).numpy()
    visibs = visconf_maps_e.permute(0, 3, 4, 1, 2).reshape(-1, T, 2)[:, :, 0].numpy()  # channel 0: visibility
    confs = visconf_maps_e.permute(0, 3, 4, 1, 2).reshape(-1, T, 2)[:, :, 1].numpy()  # channel 1: confidence
    visibs = (visibs * confs) > 0.3  # N,T
    # visibs = (confs) > 0.1  # N,T

    # sc = (np.array([video_preview.shape[2], video_preview.shape[1]]) / np.array([VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]])).reshape(1,1,2)
    # print('sc', sc)
    # tracks = tracks * sc
    return paint_video(video_preview, video_fps, tracks, visibs), tracks, visibs, gr.update(interactive=True, value=1)
    # gr.update(interactive=True),
    # gr.update(interactive=True),
    # gr.update(interactive=True),
    # gr.update(interactive=True),
    # gr.update(interactive=True))

    # query_count = tracks.shape[0]
    # cmap = matplotlib.colormaps.get_cmap("gist_rainbow")
    # query_points_color = [[]]
    # for i in range(query_count):
    #     # Choose the color for the point from matplotlib colormap
    #     color = cmap(i / float(query_count))
    #     color = (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
    #     query_points_color[0].append(color)

    # # make color array
    # colors = []
    # for frame_colors in query_points_color:
    #     colors.extend(frame_colors)
    # colors = np.array(colors)

    # # visibs_ = visibs * 1.0
    # # visibs_ = visibs_[:,1:] * visibs_[:,:-1]
    # # inds = np.sum(visibs_, axis=1) >= min(T//4,8)
    # # tracks = tracks[inds]
    # # visibs = visibs[inds]
    # # colors = colors[inds]

    # # painted_video = paint_point_track_parallel(video_preview,tracks,visibs,colors)
    # # painted_video = paint_point_track_gpu(video_preview,tracks,visibs,colors)
    # painted_video = paint_point_track_gpu_scatter(video_preview,tracks,visibs,colors)
    # print("7 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))

    # # save video
    # video_file_name = uuid.uuid4().hex + ".mp4"
    # video_path = os.path.join(os.path.dirname(__file__), "tmp")
    # video_file_path = os.path.join(video_path, video_file_name)
    # os.makedirs(video_path, exist_ok=True)
    # if False:
    #     mediapy.write_video(video_file_path, painted_video, fps=video_fps)
    # else:
    #     for ti in range(T):
    #         temp_out_f = '%s/%03d.jpg' % (video_path, ti)
    #         # temp_out_f = '%s/%03d.png' % (video_path, ti)
    #         im = PIL.Image.fromarray(painted_video[ti])
    #         # im.save(temp_out_f, "PNG", subsampling=0, quality=80)
    #         im.save(temp_out_f)
    #         print('saved', temp_out_f)
    #     # os.system('/usr/bin/ffmpeg -y -hide_banner -loglevel error -f image2 -framerate %d -pattern_type glob -i "%s/*.png" -c:v libx264 -crf 20 -pix_fmt yuv420p %s' % (video_fps, video_path, video_file_path))
    #     os.system('/usr/bin/ffmpeg -y -hide_banner -loglevel error -f image2 -framerate %d -pattern_type glob -i "%s/*.jpg" -c:v libx264 -crf 20 -pix_fmt yuv420p %s' % (video_fps, video_path, video_file_path))
    #     print('saved', video_file_path)
    #     for ti in range(T):
    #         # temp_out_f = '%s/%03d.png' % (video_path, ti)
    #         temp_out_f = '%s/%03d.jpg' % (video_path, ti)
    #         os.remove(temp_out_f)
    #         print('deleted', temp_out_f)
    # # out_file = tempfile.NamedTemporaryFile(suffix="out.mp4", delete=False)
    # # subprocess.run(f"ffmpeg -y -loglevel quiet -stats -i {painted_video} -c:v libx264 {out_file.name}".split())
    # return video_file_path

with gr.Blocks() as demo:
    video = gr.State()
    video_queried_preview = gr.State()
    video_preview = gr.State()
    video_input = gr.State()
    video_fps = gr.State(24)

    query_points = gr.State([])
    query_points_color = gr.State([])
    is_tracked_query = gr.State([])
    query_count = gr.State(0)
    # rate = gr.State([])
    tracks = gr.State([])
    visibs = gr.State([])

    gr.Markdown("# ⚡ AllTracker: Efficient Dense Point Tracking at High Resolution")
    gr.Markdown("<div style='text-align: left;'> \
    <p>Welcome to <a href='https://alltracker.github.io/' target='_blank'>AllTracker</a>! This demo runs our model to perform all-pixel tracking in a video of your choice.</p> \
    <p>To get started, simply upload your <b>.mp4</b> video, or click on one of the example videos. The shorter the video, the faster the processing; we recommend videos under 20 seconds long.</p> \
    <p>After picking a video, click \"Submit\" to load the frames into the app, optionally choose a query frame using the slider, and then click \"Track\".</p> \
    <p>For full info on how this works, check out our <a href='https://github.com/aharley/alltracker/' target='_blank'>GitHub repo</a>!</p> \
    <p>Initial code for this Gradio app came from LocoTrack and CoTracker -- big thanks to those authors!</p> \
    </div>"
    )
gr.Markdown("## Step 1: Select a video, and click \"Submit\".") | |
with gr.Row(): | |
with gr.Column(): | |
with gr.Row(): | |
video_in = gr.Video(label="Video input", format="mp4") | |
with gr.Row(): | |
submit = gr.Button("Submit") | |
with gr.Column(): | |
# with gr.Accordion("Sample videos", open=True) as video_in_drawer: | |
with gr.Row(): | |
butterfly = os.path.join(os.path.dirname(__file__), "videos", "butterfly_800.mp4") | |
monkey = os.path.join(os.path.dirname(__file__), "videos", "monkey_800.mp4") | |
groundbox = os.path.join(os.path.dirname(__file__), "videos", "groundbox_800.mp4") | |
apple = os.path.join(os.path.dirname(__file__), "videos", "apple.mp4") | |
grasp_sponge_800 = os.path.join(os.path.dirname(__file__), "videos", "grasp_sponge_800.mp4") | |
twist = os.path.join(os.path.dirname(__file__), "videos", "twist_800.mp4") | |
# dog = os.path.join(os.path.dirname(__file__), "videos", "dog.mp4") | |
bear = os.path.join(os.path.dirname(__file__), "videos", "bear.mp4") | |
paragliding_launch = os.path.join(os.path.dirname(__file__), "videos", "paragliding-launch.mp4") | |
paragliding = os.path.join(os.path.dirname(__file__), "videos", "paragliding.mp4") | |
cat = os.path.join(os.path.dirname(__file__), "videos", "cat.mp4") | |
pillow = os.path.join(os.path.dirname(__file__), "videos", "pillow.mp4") | |
teddy = os.path.join(os.path.dirname(__file__), "videos", "teddy.mp4") | |
backpack = os.path.join(os.path.dirname(__file__), "videos", "backpack.mp4") | |
gr.Examples(examples=[butterfly, groundbox, monkey, grasp_sponge_800, bear, apple, paragliding, paragliding_launch, cat, pillow, teddy, backpack, twist], | |
inputs = [ | |
video_in | |
], | |
examples_per_page=20, | |
) | |
# with gr.Column(): | |
# gr.Markdown("Choose a video or upload one of your own.") | |
gr.Markdown("## Step 2: Select a frame, and click \"Track\".") | |
with gr.Row(): | |
with gr.Column(): | |
with gr.Row(): | |
query_frame_slider = gr.Slider( | |
minimum=0, maximum=100, value=0, step=1, label="Choose frame", interactive=False) | |
# with gr.Row(): | |
# undo = gr.Button("Undo", interactive=False) | |
# clear_frame = gr.Button("Clear Frame", interactive=False) | |
# clear_all = gr.Button("Clear All", interactive=False) | |
with gr.Row(): | |
current_frame = gr.Image( | |
# label="Click to add query points", | |
label="Query frame", | |
type="numpy", | |
interactive=False | |
) | |
with gr.Row(): | |
track_button = gr.Button("Track", interactive=False) | |
with gr.Column(): | |
# with gr.Row(): | |
# rate1_button = gr.Button("Subsampling", interactive=False) | |
# rate2_button = gr.Button("Stride 2", interactive=False) | |
# rate4_button = gr.Button("Rate 4", interactive=False) | |
# rate8_button = gr.Button("Rate 8", interactive=False) | |
# # rate16_button = gr.Button("Rate 16", interactive=False) | |
with gr.Row(): | |
# rate_slider = gr.Slider( | |
# minimum=1, maximum=16, value=1, step=1, label="Choose subsampling rate", interactive=False) | |
rate_radio = gr.Radio([1, 2, 4, 8, 16], value=1, label="Choose visualization subsampling", interactive=False) | |
with gr.Row(): | |
output_video = gr.Video( | |
label="Output video", | |
interactive=False, | |
autoplay=True, | |
loop=True, | |
) | |

    submit.click(
        fn=preprocess_video_input,
        inputs=[video_in],
        outputs=[
            video,
            video_preview,
            video_queried_preview,
            video_input,
            video_fps,
            # video_in_drawer,
            current_frame,
            query_frame_slider,
            query_points,
            query_points_color,
            is_tracked_query,
            query_count,
            track_button,
            # undo,
            # clear_frame,
            # clear_all,
        ],
        queue=False,
    )

    query_frame_slider.change(
        fn=choose_frame,
        inputs=[query_frame_slider, video_queried_preview],
        outputs=[
            current_frame,
        ],
        queue=False,
    )

    # current_frame.select(
    #     fn=get_point,
    #     inputs=[
    #         query_frames,
    #         video_queried_preview,
    #         query_points,
    #         query_points_color,
    #         query_count,
    #     ],
    #     outputs=[
    #         current_frame,
    #         video_queried_preview,
    #         query_points,
    #         query_points_color,
    #         query_count
    #     ],
    #     queue=False
    # )

    # undo.click(
    #     fn=undo_point,
    #     inputs=[
    #         query_frames,
    #         video_preview,
    #         video_queried_preview,
    #         query_points,
    #         query_points_color,
    #         query_count
    #     ],
    #     outputs=[
    #         current_frame,
    #         video_queried_preview,
    #         query_points,
    #         query_points_color,
    #         query_count
    #     ],
    #     queue=False
    # )

    # clear_frame.click(
    #     fn=clear_frame_fn,
    #     inputs=[
    #         query_frames,
    #         video_preview,
    #         video_queried_preview,
    #         query_points,
    #         query_points_color,
    #         query_count
    #     ],
    #     outputs=[
    #         current_frame,
    #         video_queried_preview,
    #         query_points,
    #         query_points_color,
    #         query_count
    #     ],
    #     queue=False
    # )

    # clear_all.click(
    #     fn=clear_all_fn,
    #     inputs=[
    #         query_frames,
    #         video_preview,
    #     ],
    #     outputs=[
    #         current_frame,
    #         video_queried_preview,
    #         query_points,
    #         query_points_color,
    #         query_count
    #     ],
    #     queue=False
    # )

    # output_video = None
    track_button.click(
        fn=track,
        inputs=[
            video_preview,
            video_input,
            video_fps,
            query_frame_slider,
            query_points,
            query_points_color,
            query_count,
        ],
        outputs=[
            output_video,
            tracks,
            visibs,
            rate_radio,
            # rate1_button,
            # rate2_button,
            # rate4_button,
            # rate8_button,
            # rate16_button,
        ],
        queue=True,
    )

    # rate_slider.change(
    #     fn=choose_rate,
    #     inputs=[rate_slider, video_preview, video_fps, tracks, visibs],
    #     outputs=[
    #         output_video,
    #     ],
    #     queue=False
    # )
    rate_radio.change(
        fn=choose_rate,
        inputs=[rate_radio, video_preview, video_fps, tracks, visibs],
        outputs=[
            output_video,
        ],
        queue=False,
    )

    # rate1_button.click(
    #     fn=choose_rate1,
    #     inputs=[video_preview, video_fps, tracks, visibs],
    #     outputs=[output_video],
    #     queue=False,
    # )
    # rate2_button.click(
    #     fn=choose_rate2,
    #     inputs=[video_preview, video_fps, tracks, visibs],
    #     outputs=[output_video],
    #     queue=False,
    # )
    # rate4_button.click(
    #     fn=choose_rate4,
    #     inputs=[video_preview, video_fps, tracks, visibs],
    #     outputs=[output_video],
    #     queue=False,
    # )
    # rate8_button.click(
    #     fn=choose_rate8,
    #     inputs=[video_preview, video_fps, tracks, visibs],
    #     outputs=[output_video],
    #     queue=False,
    # )
    # rate16_button.click(
    #     fn=choose_rate16,
    #     inputs=[video_preview, video_fps, tracks, visibs],
    #     outputs=[output_video],
    #     queue=False,
    # )

# demo.launch(show_api=False, show_error=True, debug=False, share=True)
demo.launch(show_api=False, show_error=True, debug=False, share=False)