Spaces:
Runtime error
Runtime error
File size: 5,837 Bytes
9867d34 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
"""Feature extraction utilities for video and text processing."""
import os
import numpy as np
import torch
import av
from PIL import Image
from einops import rearrange
from typing import Any, Dict, List, Union, Tuple
from loguru import logger
from .config_utils import AttributeDict
from ..constants import FPS_VISUAL, MAX_VIDEO_DURATION_SECONDS
class FeatureExtractionError(Exception):
    """Error type reserved for failures during feature extraction.

    Behaves exactly like ``Exception``; it exists only so callers can catch
    feature-extraction problems separately from other errors.
    """
def get_frames_av(
    video_path: str,
    fps: float,
    max_length: Union[float, None] = None,
) -> Tuple[np.ndarray, float]:
    """Decode a video and resample its frames to a fixed frame rate.

    Output slot ``k`` sits at roughly ``k / fps`` seconds and is filled with
    the first decoded frame whose timestamp is at or after that slot. A source
    frame is therefore repeated when the source rate is below ``fps`` and
    skipped when it is above.

    Args:
        video_path: Location of the video, passed straight to ``av.open``.
        fps: Target sampling rate in frames per second; must be positive.
        max_length: Maximum duration to read, in seconds. ``None`` applies a
            15 second cap.

    Returns:
        A tuple ``(frames, vid_len_in_s)``. ``frames`` stacks the selected
        RGB frames along a new leading axis (``[T, H, W, 3]``).
        ``vid_len_in_s`` is ``T / fps``, or ``max_length`` when the frame list
        had to be truncated to ``int(max_length * fps)`` entries.

    Raises:
        ValueError: If ``fps`` is not positive, or if no frame with a usable
            timestamp could be decoded.
    """
    # A zero rate would divide by zero and a negative one would make the
    # slot-filling loop below run forever.
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps!r}")

    end_sec = 15 if max_length is None else max_length
    slot_duration = 1 / fps
    next_slot_time = 0.0
    output_frames = []

    with av.open(video_path) as container:
        stream = container.streams.video[0]
        stream.thread_type = "AUTO"
        # Iterating frames directly (instead of packets, then frames) lets the
        # ``break`` below stop decoding for good once the cap is passed.
        for frame in container.decode(stream):
            frame_time = frame.time
            # Frames without a timestamp cannot be placed on the timeline.
            if frame_time is None or frame_time < 0:
                continue
            if frame_time > end_sec:
                break

            frame_np = None
            while frame_time >= next_slot_time:
                # Convert lazily so skipped frames never pay for the copy.
                if frame_np is None:
                    frame_np = frame.to_ndarray(format="rgb24")
                output_frames.append(frame_np)
                next_slot_time += slot_duration

    if not output_frames:
        raise ValueError(f"No video frames could be decoded from {video_path!r}")

    frames = np.stack(output_frames)
    vid_len_in_s = len(frames) / fps

    if max_length is not None:
        max_frames = int(max_length * fps)
        if len(frames) > max_frames:
            frames = frames[:max_frames]
            vid_len_in_s = max_length

    return frames, vid_len_in_s
@torch.inference_mode()
def encode_video_with_siglip2(x: torch.Tensor, model_dict, batch_size: int = -1):
    """Embed every frame of a batch of clips with the SigLIP2 image encoder.

    Args:
        x: Frame tensor of shape ``[B, T, C, H, W]``.
        model_dict: Object exposing ``siglip2_model``, whose
            ``get_image_features(pixel_values=...)`` is called per chunk.
        batch_size: Number of frames per forward pass. A negative value
            processes all ``B * T`` frames in a single call.

    Returns:
        Tensor of shape ``[B, T, D]`` holding one feature vector per frame.
    """
    num_clips, num_frames, _channels, _height, _width = x.shape
    total_frames = num_clips * num_frames
    chunk = total_frames if batch_size < 0 else batch_size

    # Fold the time axis into the batch axis so the encoder sees plain images.
    flat_frames = rearrange(x, "b t c h w -> (b t) c h w")
    chunk_features = [
        model_dict.siglip2_model.get_image_features(
            pixel_values=flat_frames[start : start + chunk]
        )
        for start in range(0, total_frames, chunk)
    ]

    merged = torch.cat(chunk_features, dim=0)
    return rearrange(merged, "(b t) d -> b t d", b=num_clips)
@torch.inference_mode()
def encode_video_with_sync(x: torch.Tensor, model_dict, batch_size: int = -1):
    """Encode clips with Synchformer over overlapping 16-frame windows.

    The time axis is cut into windows of 16 frames taken every 8 frames;
    trailing frames that do not fill a whole window are dropped. Each window
    is encoded independently and the per-window outputs are concatenated
    along the time axis.

    The input video is best sampled at 24 fps or higher.

    Args:
        x: Tensor of shape ``[B, T, C, H, W]`` with ``C == 3`` and
            ``H == W == 224``; ``T`` must be at least 16.
        model_dict: Object exposing ``syncformer_model``. Its ``device``
            attribute, when present, selects where inference runs; otherwise
            ``"cuda"`` is used.
        batch_size: Number of windows per forward pass. A non-positive value
            processes all ``B * num_segments`` windows at once.

    Returns:
        Tensor of shape ``[B, num_segments * T', D]``, where ``T'`` and ``D``
        are whatever the model emits per window (e.g. 8 and 768).

    Raises:
        ValueError: If the frames are not ``3 x 224 x 224`` or the clip has
            fewer than 16 frames.
    """
    from contextlib import nullcontext

    segment_size = 16
    step_size = 8

    b, t, c, h, w = x.shape
    # Explicit checks instead of ``assert`` so they survive ``python -O``.
    if (c, h, w) != (3, 224, 224):
        raise ValueError(
            f"expected frames of shape (3, 224, 224), got ({c}, {h}, {w})"
        )
    if t < segment_size:
        raise ValueError(
            f"need at least {segment_size} frames to build one segment, got {t}"
        )

    num_segments = (t - segment_size) // step_size + 1
    total_segments = b * num_segments
    if batch_size <= 0:
        batch_size = total_segments

    # Follow the device used by the rest of the pipeline; "cuda" keeps the
    # previous behaviour for a model_dict that does not expose a device.
    device = torch.device(getattr(model_dict, "device", "cuda"))

    segments = [
        x[:, start : start + segment_size]
        for start in range(0, num_segments * step_size, step_size)
    ]
    # (B, num_segments, segment_size, 3, 224, 224)
    x = torch.stack(segments, dim=1).to(device)
    x = rearrange(x, "b s t c h w -> (b s) 1 t c h w")

    # Half-precision autocast is only applied on CUDA; other devices run the
    # model in its native precision.
    if device.type == "cuda":
        autocast_ctx = torch.autocast(device_type="cuda", enabled=True, dtype=torch.half)
    else:
        autocast_ctx = nullcontext()

    outputs = []
    with autocast_ctx:
        for i in range(0, total_segments, batch_size):
            outputs.append(model_dict.syncformer_model(x[i : i + batch_size]))

    x = torch.cat(outputs, dim=0)  # [b * num_segments, 1, T', D]
    x = rearrange(x, "(b s) 1 t d -> b (s t) d", b=b)
    return x
@torch.inference_mode()
def encode_video_features(video_path, model_dict):
    """Extract SigLIP2 and Synchformer features for one video file.

    The video is decoded twice, once at each rate listed in ``FPS_VISUAL``.
    ``get_frames_av`` is called without ``max_length``, so its default
    duration cap applies to both passes.

    Args:
        video_path: Path of the video to process.
        model_dict: Object exposing ``siglip2_preprocess``,
            ``syncformer_preprocess``, ``device`` and the models used by the
            two encoder functions.

    Returns:
        A tuple ``(visual_features, vid_len_in_s)``. ``visual_features`` is an
        ``AttributeDict`` with keys ``siglip2_feat`` and ``syncformer_feat``;
        ``vid_len_in_s`` is the number of Synchformer input frames divided by
        the Synchformer frame rate.
    """
    # --- SigLIP2 branch: per-frame PIL preprocessing, then a batch of one clip.
    siglip2_raw, _ = get_frames_av(video_path, FPS_VISUAL["siglip2"])
    siglip2_inputs = [
        model_dict.siglip2_preprocess(Image.fromarray(frame).convert('RGB'))
        for frame in siglip2_raw
    ]  # T tensors of [C, H, W]
    clip_frames = torch.stack(siglip2_inputs).to(model_dict.device).unsqueeze(0)
    siglip2_feat = encode_video_with_siglip2(clip_frames, model_dict).to(model_dict.device)

    # --- Synchformer branch: channels-first tensor, preprocessed as a whole.
    sync_raw, _ = get_frames_av(video_path, FPS_VISUAL["synchformer"])
    sync_chw = torch.from_numpy(sync_raw).permute(0, 3, 1, 2)  # [T, C, H, W]
    sync_frames = model_dict.syncformer_preprocess(sync_chw).unsqueeze(0)  # [1, T, 3, 224, 224]
    # [1, num_segments * 8, channel_dim], e.g. [1, 240, 768] for a 10s video
    syncformer_feat = encode_video_with_sync(sync_frames, model_dict)

    vid_len_in_s = sync_frames.shape[1] / FPS_VISUAL["synchformer"]
    visual_features = AttributeDict({
        'siglip2_feat': siglip2_feat,
        'syncformer_feat': syncformer_feat,
    })
    return visual_features, vid_len_in_s
@torch.inference_mode()
def encode_text_feat(text: List[str], model_dict):
    """Encode a batch of prompts with the CLAP text model.

    Args:
        text: Prompts to encode; they are padded to a common length.
        model_dict: Object exposing ``clap_tokenizer``, ``clap_model`` and
            ``device``.

    Returns:
        A tuple ``(last_hidden_state, attentions)`` taken from the model
        output.
    """
    tokenized = model_dict.clap_tokenizer(text, padding=True, return_tensors="pt")
    tokenized = tokenized.to(model_dict.device)
    clap_out = model_dict.clap_model(
        **tokenized,
        output_hidden_states=True,
        return_dict=True,
    )
    # NOTE(review): the caller binds the second value to a name ending in
    # "mask", but this is the model's `attentions` field, not the tokenizer's
    # attention mask, and `output_attentions` is never requested — confirm
    # which one is intended.
    return clap_out.last_hidden_state, clap_out.attentions
def feature_process(video_path, prompt, model_dict, cfg):
    """Build the visual and text conditioning features for one video/prompt pair.

    Args:
        video_path: Path of the video to process.
        prompt: Text prompt describing the desired result.
        model_dict: Model container forwarded to the encoder helpers.
        cfg: Configuration; ``cfg.model_config.model_kwargs.text_length`` caps
            the text sequence length.

    Returns:
        A tuple ``(visual_feats, text_feats, audio_len_in_s)``. ``text_feats``
        is an ``AttributeDict`` with ``text_feat`` (the prompt) and
        ``uncond_text_feat`` (the fixed negative prompt); ``audio_len_in_s``
        is the duration reported by ``encode_video_features``.
    """
    visual_feats, audio_len_in_s = encode_video_features(video_path, model_dict)

    # The negative prompt and the user prompt are encoded in one padded batch:
    # row 0 is the negative prompt, row 1 the user prompt.
    encoded, _unused_mask = encode_text_feat(["noisy, harsh", prompt], model_dict)
    uncond_text_feat = encoded[:1]
    text_feat = encoded[1:]

    # Truncate both sequences when they exceed the configured text length.
    max_text_len = cfg.model_config.model_kwargs.text_length
    if max_text_len < text_feat.shape[1]:
        text_feat = text_feat[:, :max_text_len]
        uncond_text_feat = uncond_text_feat[:, :max_text_len]

    text_feats = AttributeDict({
        'text_feat': text_feat,
        'uncond_text_feat': uncond_text_feat,
    })
    return visual_feats, text_feats, audio_len_in_s
|