File size: 5,837 Bytes
9867d34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""Feature extraction utilities for video and text processing."""

import os
import numpy as np
import torch
import av
from PIL import Image
from einops import rearrange
from typing import Any, Dict, List, Union, Tuple
from loguru import logger

from .config_utils import AttributeDict
from ..constants import FPS_VISUAL, MAX_VIDEO_DURATION_SECONDS


class FeatureExtractionError(Exception):
    """Signals that visual or text features could not be extracted."""

def get_frames_av(
    video_path: str,
    fps: float,
    max_length: Union[float, None] = None,
) -> Tuple[np.ndarray, float]:
    """Decode a video and resample its frames onto a constant ``fps`` grid.

    Frames of the first video stream are decoded in order. For every grid
    time ``k / fps`` the first decoded frame whose timestamp is at or after
    that time is emitted. A source frame is therefore repeated when the
    source frame rate is lower than ``fps`` and skipped when it is higher.

    Args:
        video_path: Path (or any URL PyAV can open) of the video.
        fps: Target sampling rate in frames per second; must be positive.
        max_length: Optional cap in seconds. Decoding stops after this
            timestamp and the output is truncated to ``int(max_length * fps)``
            frames. When ``None``, decoding stops after 15 seconds.

    Returns:
        ``(frames, duration)``: ``frames`` is a stacked array of ``rgb24``
        frames and ``duration`` is ``len(frames) / fps``, or ``max_length``
        when the output was truncated.

    Raises:
        ValueError: If ``fps`` is not positive.
        FeatureExtractionError: If no frame could be decoded.
    """
    if fps <= 0:
        # A non-positive step would divide by zero or never advance the grid.
        raise ValueError(f"fps must be positive, got {fps}")

    # NOTE(review): 15 s is a hard-coded decode cap when max_length is None;
    # this module also imports MAX_VIDEO_DURATION_SECONDS, which may be the
    # intended value -- confirm before switching.
    end_sec = max_length if max_length is not None else 15
    time_delta = 1 / fps
    next_frame_time = 0.0
    output_frames = []

    with av.open(video_path) as container:
        stream = container.streams.video[0]
        stream.thread_type = "AUTO"
        # container.decode() demuxes and decodes in a single iterator, so the
        # ``break`` below stops decoding entirely. With the nested
        # demux/decode loops a break only left the current packet.
        for frame in container.decode(stream):
            frame_time = frame.time
            # Frames without a usable timestamp cannot be placed on the grid.
            if frame_time is None or frame_time < 0:
                continue
            if frame_time > end_sec:
                break

            frame_np = None  # converted lazily: only if the frame is emitted
            # Emit this frame for every grid slot it has reached.
            while frame_time >= next_frame_time:
                if frame_np is None:
                    frame_np = frame.to_ndarray(format="rgb24")
                output_frames.append(frame_np)
                next_frame_time += time_delta

    if not output_frames:
        raise FeatureExtractionError(
            f"No decodable video frames found in {video_path!r}"
        )

    frames = np.stack(output_frames)
    vid_len_in_s = len(frames) / fps
    if max_length is not None and len(frames) > int(max_length * fps):
        frames = frames[: int(max_length * fps)]
        vid_len_in_s = max_length

    return frames, vid_len_in_s

@torch.inference_mode()
def encode_video_with_siglip2(x: torch.Tensor, model_dict, batch_size: int = -1):
    """Embed every frame of a batch of clips with the SigLIP2 image encoder.

    Args:
        x: Frames laid out as ``[B, T, C, H, W]``.
        model_dict: Object exposing ``siglip2_model.get_image_features``.
        batch_size: Frames per forward pass; a negative value encodes all
            ``B * T`` frames in a single pass.

    Returns:
        Per-frame features regrouped to ``[B, T, D]``.
    """
    n_clips, n_frames, _, _, _ = x.shape
    total = n_clips * n_frames
    chunk = total if batch_size < 0 else batch_size

    # Merge the clip and time axes so every frame is an independent image.
    flat = x.flatten(0, 1)
    pieces = [
        model_dict.siglip2_model.get_image_features(pixel_values=flat[start : start + chunk])
        for start in range(0, total, chunk)
    ]
    stacked = torch.cat(pieces, dim=0)
    # Split the merged axis back into (clip, frame).
    return stacked.reshape(n_clips, n_frames, -1)

@torch.inference_mode()
def encode_video_with_sync(x: torch.Tensor, model_dict, batch_size: int = -1):
    """Encode a clip with the Synchformer model over overlapping windows.

    The clip is cut into windows of 16 frames with a stride of 8 frames.
    Trailing frames that do not fill a whole window are dropped. The windows
    are moved to CUDA and encoded under fp16 autocast. The input video is
    best sampled at 24 fps or higher.

    Args:
        x: Frames shaped ``[B, T, 3, 224, 224]`` with ``T >= 16``.
        model_dict: Object exposing ``syncformer_model``.
        batch_size: Number of windows per forward pass; a non-positive value
            processes all windows in a single pass.

    Returns:
        Features shaped ``[B, num_segments * t', D]``, where ``[1, t', D]``
        is the per-window model output (``[1, 8, 768]`` per the shape note
        below).

    Raises:
        ValueError: If frames are not ``3 x 224 x 224`` or ``T < 16``.
    """
    b, t, c, h, w = x.shape
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if (c, h, w) != (3, 224, 224):
        raise ValueError(f"expected frames of shape (3, 224, 224), got ({c}, {h}, {w})")

    segment_size = 16
    step_size = 8
    num_segments = (t - segment_size) // step_size + 1
    if num_segments < 1:
        raise ValueError(f"need at least {segment_size} frames for Synchformer, got {t}")

    segments = [x[:, i * step_size : i * step_size + segment_size] for i in range(num_segments)]
    x = torch.stack(segments, dim=1).cuda()  # (B, num_segments, segment_size, 3, 224, 224)

    if batch_size <= 0:
        # Zero would make range() raise; treat it like "all at once".
        batch_size = b * num_segments
    x = rearrange(x, "b s t c h w -> (b s) 1 t c h w")

    outputs = []
    with torch.autocast(device_type="cuda", enabled=True, dtype=torch.half):
        for i in range(0, b * num_segments, batch_size):
            outputs.append(model_dict.syncformer_model(x[i : i + batch_size]))
    x = torch.cat(outputs, dim=0)  # [b * num_segments, 1, 8, 768]
    x = rearrange(x, "(b s) 1 t d -> b (s t) d", b=b)
    return x


@torch.inference_mode()
def encode_video_features(video_path, model_dict):
    """Extract SigLIP2 and Synchformer visual features from one video file.

    The file is decoded twice, once at each encoder's rate from ``FPS_VISUAL``.

    Args:
        video_path: Path of the video to read.
        model_dict: Object exposing both encoders, their preprocessors and
            ``device``.

    Returns:
        ``(features, duration)``: ``features`` is an ``AttributeDict`` with
        ``siglip2_feat`` and ``syncformer_feat``. ``duration`` is the
        Synchformer frame count divided by the Synchformer fps.
    """
    # SigLIP2: per-frame PIL preprocessing, then one batched encode.
    siglip_np, _ = get_frames_av(video_path, FPS_VISUAL["siglip2"])
    pil_frames = [Image.fromarray(arr).convert('RGB') for arr in siglip_np]
    siglip_batch = torch.stack([model_dict.siglip2_preprocess(img) for img in pil_frames])
    siglip_batch = siglip_batch.to(model_dict.device).unsqueeze(0)  # [1, T, C, H, W]
    siglip_feat = encode_video_with_siglip2(siglip_batch, model_dict).to(model_dict.device)

    # Synchformer: tensor preprocessing on channel-first frames.
    sync_np, _ = get_frames_av(video_path, FPS_VISUAL["synchformer"])
    channel_first = torch.from_numpy(sync_np).permute(0, 3, 1, 2)  # [T, C, H, W]
    sync_frames = model_dict.syncformer_preprocess(channel_first).unsqueeze(0)
    sync_feat = encode_video_with_sync(sync_frames, model_dict)

    duration = sync_frames.shape[1] / FPS_VISUAL["synchformer"]
    features = AttributeDict({
        'siglip2_feat': siglip_feat,
        'syncformer_feat': sync_feat,
    })
    return features, duration

@torch.inference_mode()
def encode_text_feat(text: List[str], model_dict):
    """Encode a batch of prompts with the CLAP text model.

    Args:
        text: Prompts to encode; the tokenizer pads them to a common length.
        model_dict: Object exposing ``clap_tokenizer``, ``clap_model`` and
            ``device``.

    Returns:
        ``(last_hidden_state, attention_mask)``: the model's final-layer token
        features (presumably ``[B, L, D]`` -- confirm against the model) and
        the tokenizer's padding mask, or ``None`` if the tokenizer did not
        produce one.
    """
    inputs = model_dict.clap_tokenizer(text, padding=True, return_tensors="pt").to(model_dict.device)
    outputs = model_dict.clap_model(**inputs, output_hidden_states=True, return_dict=True)
    # Return the padding mask rather than ``outputs.attentions``. Attention
    # weights are never requested here (no ``output_attentions=True``), so
    # that field would always be None.
    return outputs.last_hidden_state, inputs.get("attention_mask")


def feature_process(video_path, prompt, model_dict, cfg):
    """Build visual and text conditioning features for one video/prompt pair.

    Args:
        video_path: Path of the input video.
        prompt: Text prompt describing the desired audio.
        model_dict: Object exposing the visual and CLAP encoders.
        cfg: Config providing ``model_config.model_kwargs.text_length``.

    Returns:
        ``(visual_feats, text_feats, audio_len_in_s)``: ``text_feats`` holds
        ``text_feat`` (the prompt) and ``uncond_text_feat`` (a fixed negative
        prompt), both truncated to ``text_length`` tokens.
    """
    visual_feats, audio_len_in_s = encode_video_features(video_path, model_dict)

    # Encode the negative prompt (row 0) together with the user prompt (row 1)
    # so that both share the same padded sequence length.
    hidden_states, _ = encode_text_feat(["noisy, harsh", prompt], model_dict)
    uncond_text_feat, text_feat = hidden_states[:1], hidden_states[1:]

    max_tokens = cfg.model_config.model_kwargs.text_length
    if text_feat.shape[1] > max_tokens:
        text_feat = text_feat[:, :max_tokens]
        uncond_text_feat = uncond_text_feat[:, :max_tokens]

    text_feats = AttributeDict({
        'text_feat': text_feat,
        'uncond_text_feat': uncond_text_feat,
    })
    return visual_feats, text_feats, audio_len_in_s