"""Feature extraction utilities for video and text processing."""
import os
import numpy as np
import torch
import av
from PIL import Image
from einops import rearrange
from typing import Any, Dict, List, Optional, Tuple, Union
from loguru import logger
from .config_utils import AttributeDict
from ..constants import FPS_VISUAL, MAX_VIDEO_DURATION_SECONDS
class FeatureExtractionError(Exception):
"""Exception raised for feature extraction errors."""
pass
def get_frames_av(
video_path: str,
fps: float,
    max_length: Optional[float] = None,
) -> Tuple[np.ndarray, float]:
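    """Decode a video with PyAV and resample it to the requested frame rate.

    Args:
        video_path: path to the input video file.
        fps: target frame rate of the returned frames.
        max_length: optional cap on the decoded duration in seconds.

    Returns:
        A tuple of (frames as a [T, H, W, 3] uint8 array, decoded length in seconds).
    """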
    end_sec = max_length if max_length is not None else MAX_VIDEO_DURATION_SECONDS
next_frame_time_for_each_fps = 0.0
time_delta_for_each_fps = 1 / fps
all_frames = []
output_frames = []
with av.open(video_path) as container:
stream = container.streams.video[0]
ori_fps = stream.guessed_rate
stream.thread_type = "AUTO"
for packet in container.demux(stream):
for frame in packet.decode():
frame_time = frame.time
if frame_time < 0:
continue
if frame_time > end_sec:
break
frame_np = None
this_time = frame_time
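                # Emit this source frame once per target-fps tick it covers: frames are
                # duplicated when the source fps is below the target and skipped when above.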
while this_time >= next_frame_time_for_each_fps:
if frame_np is None:
frame_np = frame.to_ndarray(format="rgb24")
output_frames.append(frame_np)
next_frame_time_for_each_fps += time_delta_for_each_fps
    if not output_frames:
        raise FeatureExtractionError(f"No decodable frames found in {video_path}")
    output_frames = np.stack(output_frames)
vid_len_in_s = len(output_frames) / fps
if max_length is not None and len(output_frames) > int(max_length * fps):
output_frames = output_frames[: int(max_length * fps)]
vid_len_in_s = max_length
return output_frames, vid_len_in_s
@torch.inference_mode()
def encode_video_with_siglip2(x: torch.Tensor, model_dict, batch_size: int = -1):
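    """Encode preprocessed video frames with SigLIP2.

    Args:
        x: tensor of shape [B, T, C, H, W].
        batch_size: batch size for SigLIP2 inference (-1 processes all frames at once).

    Returns:
        Tensor of shape [B, T, D] of per-frame image features.
    """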
b, t, c, h, w = x.shape
if batch_size < 0:
batch_size = b * t
x = rearrange(x, "b t c h w -> (b t) c h w")
outputs = []
for i in range(0, b * t, batch_size):
outputs.append(model_dict.siglip2_model.get_image_features(pixel_values=x[i : i + batch_size]))
res = torch.cat(outputs, dim=0)
res = rearrange(res, "(b t) d -> b t d", b=b)
return res
@torch.inference_mode()
def encode_video_with_sync(x: torch.Tensor, model_dict, batch_size: int = -1):
"""
The input video of x is best to be in fps of 24 of greater than 24.
Input:
x: tensor in shape of [B, T, C, H, W]
batch_size: the batch_size for synchformer inference
"""
b, t, c, h, w = x.shape
assert c == 3 and h == 224 and w == 224
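    # Synchformer consumes overlapping 16-frame segments with a stride of 8 frames.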
segment_size = 16
step_size = 8
    num_segments = (t - segment_size) // step_size + 1
    if num_segments <= 0:
        raise FeatureExtractionError(f"Video too short for Synchformer: {t} frames, need at least {segment_size}")
segments = []
for i in range(num_segments):
segments.append(x[:, i * step_size : i * step_size + segment_size])
x = torch.stack(segments, dim=1).cuda() # (B, num_segments, segment_size, 3, 224, 224)
outputs = []
if batch_size < 0:
batch_size = b * num_segments
x = rearrange(x, "b s t c h w -> (b s) 1 t c h w")
for i in range(0, b * num_segments, batch_size):
with torch.autocast(device_type="cuda", enabled=True, dtype=torch.half):
outputs.append(model_dict.syncformer_model(x[i : i + batch_size]))
x = torch.cat(outputs, dim=0) # [b * num_segments, 1, 8, 768]
x = rearrange(x, "(b s) 1 t d -> b (s t) d", b=b)
return x
@torch.inference_mode()
def encode_video_features(video_path, model_dict):
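    """Extract SigLIP2 and Synchformer visual features from a video file.

    Returns:
        A tuple of (AttributeDict with 'siglip2_feat' and 'syncformer_feat',
        video length in seconds as seen by the Synchformer branch).
    """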
visual_features = {}
# siglip2 visual features
frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["siglip2"])
images = [Image.fromarray(frame).convert('RGB') for frame in frames]
    images = [model_dict.siglip2_preprocess(image) for image in images]  # list of [C, H, W] tensors
clip_frames = torch.stack(images).to(model_dict.device).unsqueeze(0)
visual_features['siglip2_feat'] = encode_video_with_siglip2(clip_frames, model_dict).to(model_dict.device)
# synchformer visual features
frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["synchformer"])
images = torch.from_numpy(frames).permute(0, 3, 1, 2) # [T, C, H, W]
sync_frames = model_dict.syncformer_preprocess(images).unsqueeze(0) # [1, T, 3, 224, 224]
# [1, num_segments * 8, channel_dim], e.g. [1, 240, 768] for 10s video
visual_features['syncformer_feat'] = encode_video_with_sync(sync_frames, model_dict)
vid_len_in_s = sync_frames.shape[1] / FPS_VISUAL["synchformer"]
visual_features = AttributeDict(visual_features)
return visual_features, vid_len_in_s
@torch.inference_mode()
def encode_text_feat(text: List[str], model_dict):
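    """Encode a list of prompts with the CLAP text encoder.

    Returns:
        A tuple of (last hidden state of shape [B, L, D], attention weights if the
        model is configured to output them, otherwise None).
    """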
    # Tokenize prompts into padded input_ids / attention_mask of shape (B, L).
inputs = model_dict.clap_tokenizer(text, padding=True, return_tensors="pt").to(model_dict.device)
outputs = model_dict.clap_model(**inputs, output_hidden_states=True, return_dict=True)
return outputs.last_hidden_state, outputs.attentions
def feature_process(video_path, prompt, model_dict, cfg):
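    """Extract visual and text conditioning features for a video/prompt pair.

    Encodes the video with SigLIP2 and Synchformer, encodes the prompt together
    with a fixed negative prompt via CLAP, and truncates the text features to the
    configured maximum text length.
    """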
visual_feats, audio_len_in_s = encode_video_features(video_path, model_dict)
neg_prompt = "noisy, harsh"
prompts = [neg_prompt, prompt]
text_feat_res, text_feat_mask = encode_text_feat(prompts, model_dict)
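    # Index 0 holds the negative (unconditional) prompt, index 1 the positive prompt.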
text_feat = text_feat_res[1:]
uncond_text_feat = text_feat_res[:1]
if cfg.model_config.model_kwargs.text_length < text_feat.shape[1]:
text_seq_length = cfg.model_config.model_kwargs.text_length
text_feat = text_feat[:, :text_seq_length]
uncond_text_feat = uncond_text_feat[:, :text_seq_length]
text_feats = AttributeDict({
'text_feat': text_feat,
'uncond_text_feat': uncond_text_feat,
})
return visual_feats, text_feats, audio_len_in_s