""" |
|
Processor class for Cosmos-Embed1 |
|
""" |
|
|
|
from typing import List, Optional, Tuple, Union |
|
|
|
import numpy as np |
|
import torch |
|
import torchvision |
|
from transformers import AutoProcessor, BatchFeature |
|
from transformers.processing_utils import ProcessorMixin |
|
from transformers.utils import TensorType |
|
|
|
from .configuration_embed1 import CosmosEmbed1Config |
|
|
|
|
|
class CosmosEmbed1Processor(ProcessorMixin):
    r"""
    Constructs a processor which wraps a BertTokenizer tokenizer and a fast video resize function.

    Args:
        tokenizer ([`BertTokenizerFast`], *optional*):
            The tokenizer is a required input for text processing.
        config ([`CosmosEmbed1Config`], *optional*):
            Needed for processing options.
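
    Example:
        A minimal usage sketch. The checkpoint path is a placeholder, and loading with
        ``trust_remote_code=True`` plus the dummy video tensor are illustrative assumptions,
        not pinned references:

        >>> import torch
        >>> from transformers import AutoProcessor
        >>> processor = AutoProcessor.from_pretrained("path/to/cosmos-embed1", trust_remote_code=True)
        >>> videos = torch.randint(0, 256, (1, 8, 3, 336, 336), dtype=torch.uint8)
        >>> batch = processor(text=["a robot arm picking up a cup"], videos=videos)
        >>> sorted(batch.keys())
        ['attention_mask', 'input_ids', 'videos']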
    """

    attributes = ["tokenizer"]
    tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
    config_class = CosmosEmbed1Config
    chat_template = None

    def __init__(
        self,
        tokenizer=None,
        resolution: Union[int, Tuple[int, int]] = 336,
        num_video_frames: int = 8,
        max_txt_len: int = 128,
        **kwargs,
    ) -> None:
        super().__init__(tokenizer, **kwargs)
        self.resolution = resolution
        self.num_video_frames = num_video_frames
        self.max_txt_len = max_txt_len

    def __call__(
        self,
        text: Optional[Union[str, List[str]]] = None,
        videos: Optional[Union[np.ndarray, torch.Tensor]] = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        resolution: Optional[Union[int, Tuple[int, int]]] = None,
        num_video_frames: Optional[int] = None,
        max_txt_len: Optional[int] = None,
        **kwargs,
    ) -> BatchFeature:
        inputs = {}

        if text is not None:
            # Tokenize captions to a fixed length, padding/truncating to `max_txt_len`.
            max_txt_len = max_txt_len if max_txt_len is not None else self.max_txt_len
            tokenized = self.tokenizer(
                text, return_tensors="pt", padding="max_length", truncation=True, max_length=max_txt_len, **kwargs
            )
            inputs["input_ids"] = tokenized.input_ids
            inputs["attention_mask"] = tokenized.attention_mask.float()

        if videos is not None:
            if isinstance(videos, np.ndarray):
                videos = torch.from_numpy(videos)
            if not isinstance(videos, torch.Tensor) or videos.ndim != 5:
                raise ValueError(
                    "Processor expects a numpy array or torch tensor of shape BTCHW with values in [0, 255]."
                )
            resolution = resolution if resolution is not None else self.resolution
            if isinstance(resolution, int):
                resolution = (resolution, resolution)
            _, t, c, h, w = videos.shape
            if c != 3:
                raise ValueError(f"Expected tensor of shape BTCHW with RGB channels, got channel size {c}.")
            num_video_frames = num_video_frames if num_video_frames is not None else self.num_video_frames
            if t != num_video_frames:
                raise ValueError(f"Expected tensor of shape BTCHW with {num_video_frames} frames, got {t}.")
            # Resize spatially if needed, then rescale pixel values from [0, 255] to [0, 1].
            if h != resolution[0] or w != resolution[1]:
                videos = resize_video(videos, resolution)
            if videos.dtype == torch.uint8:
                videos = videos.float()
            inputs["videos"] = videos / 255.0

        if not inputs:
            raise ValueError("Must pass either `text` or `videos` argument to __call__ function.")

        return BatchFeature(inputs, tensor_type=return_tensors)


def resize_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
    """Resize a video tensor (B, T, C, H, W) to a new height/width.

    Args:
        video (torch.Tensor): (B, T, C, H, W) uint8 or float32.
        size (tuple): target (H', W') size.

    Returns:
        torch.Tensor: resized video of shape (B, T, C, H', W').
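
    Example:
        A minimal sketch of the expected shapes (dummy data; the 336x336 target matches the
        processor default but is otherwise arbitrary):

        >>> vid = torch.zeros(2, 8, 3, 448, 448)
        >>> resize_video(vid, (336, 336)).shape
        torch.Size([2, 8, 3, 336, 336])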
    """
    h, w = size
    B, T, C, H, W = video.shape
    # Fold batch and time together so torchvision resizes all frames in one call.
    video = video.view(B * T, C, H, W)
    resize = torchvision.transforms.Resize(
        (h, w),
        antialias=True,
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
    )
    video = resize(video)
    new_H, new_W = video.shape[-2:]
    video = video.view(B, T, C, new_H, new_W)
    return video


AutoProcessor.register(CosmosEmbed1Config, CosmosEmbed1Processor)

__all__ = ["CosmosEmbed1Processor"]