import os
import json
import base64
from io import BytesIO

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from torchvision.transforms import Compose, Lambda, ToTensor
from torchvision.transforms._transforms_video import (
    NormalizeVideo,
    RandomCropVideo,
    RandomHorizontalFlipVideo,
    CenterCropVideo,
)
from transformers import ProcessorMixin, BatchEncoding, TextStreamer
from transformers.image_processing_utils import BatchFeature
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample

def load_frames(frames_dir):
    """Return the sorted list of frame image paths in a directory."""
    results = []
    frame_names = os.listdir(frames_dir)
    frame_names.sort()
    for frame_name in frame_names:
        image_path = f"{frames_dir}/{frame_name}"
        results.append(image_path)
    return results

def sample_frames(frames, num_segments):
    """Uniformly sample `num_segments` frame paths from the full list."""
    duration = len(frames)
    frame_id_array = np.linspace(0, duration - 1, num_segments, dtype=int)
    frame_id_list = frame_id_array.tolist()
    sampled_frames = []
    for frame_idx in frame_id_list:
        single_frame_path = frames[frame_idx]
        sampled_frames.append(single_frame_path)
    return sampled_frames

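# A minimal usage sketch for the two helpers above, assuming a hypothetical
# directory "data/video_0" of pre-extracted frame images (the path is
# illustrative, not part of the original file): collect the sorted frame
# paths, then uniformly sample 8 of them.
#
# frame_paths = load_frames("data/video_0")
# sampled = sample_frames(frame_paths, num_segments=8)
# assert len(sampled) == 8
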
class VideoProcessor:
    def __init__(self, image_transform):
        self.image_transform = image_transform

    def __call__(self, video_path, transform=None,
                 video_decode_backend='opencv',
                 clip_start_sec=0.0, clip_end_sec=None,
                 num_frames=50, **kwargs):
        if transform is None:
            transform = self.image_transform
        if video_decode_backend == 'pytorchvideo':
            # EncodedVideo also supports the "pyav" decoder.
            video = EncodedVideo.from_path(video_path, decoder="decord", decode_audio=False)
            duration = video.duration
            start_sec = clip_start_sec  # secs
            end_sec = clip_end_sec if clip_end_sec is not None else duration  # secs
            # get_clip returns a dict with a 'video' key, so the transform is
            # expected to handle that (e.g. an ApplyTransformToKey pipeline).
            video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)
            video_outputs = transform(video_data)
        elif video_decode_backend == 'decord':
            import decord
            from decord import VideoReader, cpu
            decord.bridge.set_bridge('torch')
            decord_vr = VideoReader(video_path, ctx=cpu(0))
            ori_duration = len(decord_vr)
            fps_vid = decord_vr.get_avg_fps()
            # Cap sampling to at most the first 10 seconds of the video.
            valid_duration = min(int(fps_vid * 10), ori_duration)
            frame_id_list = np.linspace(0, valid_duration - 1, num_frames, dtype=int)
            video_data = decord_vr.get_batch(frame_id_list)
            video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)
            video_outputs = transform(video_data)
        elif video_decode_backend == 'opencv':
            import cv2
            cv2_vr = cv2.VideoCapture(video_path)
            duration = int(cv2_vr.get(cv2.CAP_PROP_FRAME_COUNT))
            frame_id_list = np.linspace(0, duration - 1, num_frames, dtype=int)
            video_data = []
            for frame_idx in frame_id_list:
                cv2_vr.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                ret, frame = cv2_vr.read()
                if not ret:
                    continue  # skip frames that fail to decode
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                video_data.append(torch.from_numpy(frame).permute(2, 0, 1))  # (H, W, C) -> (C, H, W)
            cv2_vr.release()
            video_data = torch.stack(video_data, dim=1)  # stack along time -> (C, T, H, W)
            video_outputs = transform(video_data)
        elif video_decode_backend == 'frames':
            # FIXME: this backend ignores clip_start_sec/clip_end_sec; duration
            # info would be needed to support them.
            frames = load_frames(video_path)
            frames = sample_frames(frames, num_frames)
            to_tensor = ToTensor()
            # Open each sampled frame before converting to a tensor (the raw
            # paths cannot be passed to ToTensor directly).
            video_data = torch.stack(
                [to_tensor(Image.open(p).convert('RGB')) for p in frames]
            ).permute(1, 0, 2, 3)  # (T, C, H, W) -> (C, T, H, W)
            video_outputs = transform(video_data)  # keep consistent with the other backends
        else:
            raise ValueError('video_decode_backend should be one of (pytorchvideo, decord, opencv, frames)')
        return video_outputs
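
# A minimal usage sketch for VideoProcessor, assuming a CLIP-style video
# transform; the Compose below and "example.mp4" are illustrative placeholders,
# not part of the original file. Note that the 'pytorchvideo' backend passes a
# dict to the transform, so it would need an ApplyTransformToKey wrapper instead.
#
# video_transform = Compose([
#     Lambda(lambda x: x / 255.0),  # uint8 -> float in [0, 1]
#     ShortSideScale(size=224),
#     CenterCropVideo(224),
#     NormalizeVideo(mean=(0.48145466, 0.4578275, 0.40821073),
#                    std=(0.26862954, 0.26130258, 0.27577711)),
# ])
# processor = VideoProcessor(video_transform)
# pixel_values = processor("example.mp4", video_decode_backend='decord', num_frames=8)
# print(pixel_values.shape)  # (C, T, H, W), e.g. (3, 8, 224, 224)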