import logging
from pathlib import Path
from typing import List, Tuple

import cv2
import numpy as np
import torch
from einops import rearrange, repeat
from PIL import Image
from torchvision.transforms.functional import resize

# Must be imported after torch: importing decord first can occasionally trigger
# a segmentation fault or stack-smashing error. Bug reports are rare; see the
# decord GitHub issues for more information.
import decord  # isort:skip

decord.bridge.set_bridge("torch")


########## loaders ##########


def load_prompts(prompt_path: Path) -> List[str]:
    with open(prompt_path, "r", encoding="utf-8") as file:
        return [line.strip() for line in file if len(line.strip()) > 0]


def load_videos(video_path: Path) -> List[Path]:
    with open(video_path, "r", encoding="utf-8") as file:
        return [video_path.parent / line.strip() for line in file if len(line.strip()) > 0]


def load_images(image_path: Path) -> List[Path]:
    with open(image_path, "r", encoding="utf-8") as file:
        return [image_path.parent / line.strip() for line in file if len(line.strip()) > 0]


def load_images_from_videos(videos_path: List[Path]) -> List[Path]:
    first_frames_dir = videos_path[0].parent.parent / "first_frames"
    first_frames_dir.mkdir(parents=True, exist_ok=True)

    first_frame_paths = []
    for video_path in videos_path:
        frame_path = first_frames_dir / f"{video_path.stem}.png"
        if frame_path.exists():
            first_frame_paths.append(frame_path)
            continue

        # Open the video and read its first frame
        cap = cv2.VideoCapture(str(video_path))
        ret, frame = cap.read()
        if not ret:
            cap.release()
            raise RuntimeError(f"Failed to read video: {video_path}")

        # Save the frame as a PNG with the same stem as the video, then release the capture
        cv2.imwrite(str(frame_path), frame)
        logging.info(f"Saved first frame to {frame_path}")
        cap.release()

        first_frame_paths.append(frame_path)

    return first_frame_paths
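

# A sketch of how the loaders above might be composed. The file names
# ("prompts.txt", "videos.txt") and the flat layout are assumptions for
# illustration, not a convention enforced by this module.
def load_dataset_manifest(root: Path) -> List[Tuple[str, Path, Path]]:
    prompts = load_prompts(root / "prompts.txt")
    videos = load_videos(root / "videos.txt")
    first_frames = load_images_from_videos(videos)  # caches PNGs on first call
    return list(zip(prompts, videos, first_frames))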


def load_binary_mask_compressed(path, shape, device, dtype):
    # shape: (F, C, H, W) with C == 1; the file stores the mask bit-packed,
    # one bit per pixel, as produced by np.packbits.
    with open(path, "rb") as f:
        packed = np.frombuffer(f.read(), dtype=np.uint8)
    unpacked = np.unpackbits(packed)[: np.prod(shape)]
    mask_loaded = torch.from_numpy(unpacked).to(device, dtype).reshape(shape)

    # Downsample to a coarser grid (temporal F//4 + 1, spatial H//8 x W//8)
    # via trilinear interpolation, then re-binarize at 0.5.
    mask_interp = torch.nn.functional.interpolate(
        rearrange(mask_loaded, "f c h w -> c f h w").unsqueeze(0),
        size=(shape[0] // 4 + 1, shape[2] // 8, shape[3] // 8),
        mode="trilinear",
        align_corners=False,
    ).squeeze(0)  # [C, F, H, W]
    mask_interp[mask_interp >= 0.5] = 1.0
    mask_interp[mask_interp < 0.5] = 0.0

    return rearrange(mask_loaded, "f c h w -> c f h w"), mask_interp
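

# A sketch of the inverse writer (hypothetical helper, not used elsewhere in
# this module): load_binary_mask_compressed expects exactly the raw bytes that
# np.packbits produces from a flattened 0/1 mask of shape (F, 1, H, W).
def save_binary_mask_compressed(mask: torch.Tensor, path: Path) -> None:
    bits = mask.to(torch.uint8).cpu().numpy().reshape(-1)
    with open(path, "wb") as f:
        f.write(np.packbits(bits).tobytes())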


########## preprocessors ##########


def preprocess_image_with_resize(
    image_path: Path | str,
    height: int,
    width: int,
) -> torch.Tensor:
    """
    Loads and resizes a single image.

    Args:
        image_path: Path to the image file.
        height: Target height for resizing.
        width: Target width for resizing.

    Returns:
        torch.Tensor: Image tensor with shape [C, H, W] where:
            C = number of channels (3 for RGB)
            H = height
            W = width
    """
    if isinstance(image_path, str):
        image_path = Path(image_path)
    # Force RGB so the returned tensor always has 3 channels, even for
    # grayscale or RGBA inputs.
    image = np.array(Image.open(image_path.as_posix()).convert("RGB").resize((width, height)))
    image = torch.from_numpy(image).float()
    image = image.permute(2, 0, 1).contiguous()
    return image
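

# A sketch (hypothetical helper) of turning one conditioning image into a
# constant "video" tensor [F, C, H, W] with einops.repeat, e.g. for
# image-to-video pipelines that expect clip-shaped input.
def image_to_clip(image_path: Path | str, height: int, width: int, num_frames: int) -> torch.Tensor:
    image = preprocess_image_with_resize(image_path, height, width)  # [C, H, W]
    return repeat(image, "c h w -> f c h w", f=num_frames)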


def preprocess_video_with_resize(
    video_path: Path | str,
    max_num_frames: int,
    height: int,
    width: int,
) -> torch.Tensor:
    """
    Loads and resizes a single video.

    The function processes the video through these steps:
        1. If the video has more than max_num_frames frames, sample frames evenly.
        2. If the video has fewer, repeat the last frame until max_num_frames is reached.
        3. If the video dimensions don't match (height, width), resize frames on decode.

    Args:
        video_path: Path to the video file.
        max_num_frames: Maximum number of frames to keep.
        height: Target height for resizing.
        width: Target width for resizing.

    Returns:
        A torch.Tensor with shape [F, C, H, W] where:
            F = number of frames
            C = number of channels (3 for RGB)
            H = height
            W = width
    """
    if isinstance(video_path, str):
        video_path = Path(video_path)
    video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height)
    video_num_frames = len(video_reader)

    if video_num_frames < max_num_frames:
        # Get all frames, then repeat the last one until we reach max_num_frames
        frames = video_reader.get_batch(list(range(video_num_frames)))
        last_frame = frames[-1:]
        num_repeats = max_num_frames - video_num_frames
        repeated_frames = last_frame.repeat(num_repeats, 1, 1, 1)
        frames = torch.cat([frames, repeated_frames], dim=0)
        return frames.float().permute(0, 3, 1, 2).contiguous()
    else:
        # Sample evenly, then trim to exactly max_num_frames
        indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
        frames = video_reader.get_batch(indices)
        frames = frames[:max_num_frames].float()
        frames = frames.permute(0, 3, 1, 2).contiguous()
        return frames
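

# A sketch of typical downstream use. The [-1, 1] normalization and the
# 49x480x720 defaults are assumptions; match them to your training config.
def load_normalized_clip(path: Path | str, num_frames: int = 49, height: int = 480, width: int = 720) -> torch.Tensor:
    clip = preprocess_video_with_resize(path, num_frames, height, width)  # values in [0, 255]
    return clip / 127.5 - 1.0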


def preprocess_video_with_buckets(
    video_path: Path,
    resolution_buckets: List[Tuple[int, int, int]],
) -> torch.Tensor:
    """
    Args:
        video_path: Path to the video file.
        resolution_buckets: List of tuples (num_frames, height, width) representing
            available resolution buckets.

    Returns:
        torch.Tensor: Video tensor with shape [F, C, H, W] where:
            F = number of frames
            C = number of channels (3 for RGB)
            H = height
            W = width

    The function processes the video through these steps:
        1. Finds the nearest frame bucket <= video frame count.
        2. Downsamples frames evenly to match the bucket size.
        3. Finds the nearest resolution bucket based on dimensions.
        4. Resizes frames to match the bucket resolution.
    """
    video_reader = decord.VideoReader(uri=video_path.as_posix())
    video_num_frames = len(video_reader)
    eligible_buckets = [bucket for bucket in resolution_buckets if bucket[0] <= video_num_frames]
    if len(eligible_buckets) == 0:
        raise ValueError(f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}")

    # The largest bucket frame count that still fits within the video
    nearest_frame_bucket = min(
        eligible_buckets,
        key=lambda bucket: video_num_frames - bucket[0],
    )[0]
    frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
    frames = video_reader.get_batch(frame_indices)
    frames = frames[:nearest_frame_bucket].float()
    frames = frames.permute(0, 3, 1, 2).contiguous()

    # Pick the bucket whose (H, W) is closest in L1 distance, then resize
    nearest_res = min(eligible_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3]))
    nearest_res = (nearest_res[1], nearest_res[2])
    frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0)
    return frames
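

# Minimal smoke test (the bucket values below are illustrative only); run the
# module directly on a local video to exercise the bucketed preprocessor.
if __name__ == "__main__":
    import sys

    logging.basicConfig(level=logging.INFO)
    buckets = [(49, 480, 720), (49, 720, 480)]
    clip = preprocess_video_with_buckets(Path(sys.argv[1]), buckets)
    logging.info("bucketed clip shape: %s", tuple(clip.shape))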