{ "cells": [ { "cell_type": "markdown", "id": "136b43b6", "metadata": {}, "source": [ "## Setup\n", "\n", "We need `transformers`, `torchvision` and `einops` as basic dependencies for the model. \n", "For this example, we also use `wget` for fetching data remotely, `decord` for decoding video frames, and `mediapy` for saving videos." ] }, { "cell_type": "code", "execution_count": 1, "id": "4363e953", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: transformers in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (4.51.3)\n", "Requirement already satisfied: torchvision in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (0.22.0)\n", "Requirement already satisfied: einops in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (0.8.1)\n", "Requirement already satisfied: decord in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (0.6.0)\n", "Requirement already satisfied: mediapy in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (1.2.4)\n", "Requirement already satisfied: filelock in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (3.18.0)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.30.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (0.30.2)\n", "Requirement already satisfied: numpy>=1.17 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (2.2.5)\n", "Requirement already satisfied: packaging>=20.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (25.0)\n", "Requirement already satisfied: pyyaml>=5.1 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (6.0.2)\n", "Requirement already satisfied: regex!=2019.12.17 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (2024.11.6)\n", "Requirement already satisfied: requests in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (2.32.3)\n", "Requirement already satisfied: tokenizers<0.22,>=0.21 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (0.21.1)\n", "Requirement already satisfied: safetensors>=0.4.3 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (0.5.3)\n", "Requirement already satisfied: tqdm>=4.27 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (4.67.1)\n", "Requirement already satisfied: torch==2.7.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torchvision) (2.7.0)\n", "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torchvision) (11.2.1)\n", "Requirement already satisfied: typing-extensions>=4.10.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (4.13.2)\n", "Requirement already satisfied: sympy>=1.13.3 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (1.14.0)\n", "Requirement already satisfied: networkx in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (3.4.2)\n", "Requirement already satisfied: jinja2 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (3.1.6)\n", "Requirement 
already satisfied: fsspec in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (2025.3.2)\n", "...\n", "Requirement already satisfied: MarkupSafe>=2.0 in 
/data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from jinja2->torch==2.7.0->torchvision) (3.0.2)\n", "Requirement already satisfied: executing>=1.2.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from stack_data->ipython->mediapy) (2.2.0)\n", "Requirement already satisfied: asttokens>=2.1.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from stack_data->ipython->mediapy) (3.0.0)\n", "Requirement already satisfied: pure_eval in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from stack_data->ipython->mediapy) (0.2.3)\n" ] } ], "source": [ "!pip install transformers torchvision einops decord mediapy" ] }, { "cell_type": "code", "execution_count": 2, "id": "54c2ac81-3389-4c8d-bc08-4834eb88fa73", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import decord\n", "import numpy as np\n", "import torch\n", "from transformers import AutoConfig, AutoModel, AutoProcessor\n", "from IPython.display import Video\n", "import subprocess\n", "import io" ] }, { "cell_type": "markdown", "id": "fa84e4fa", "metadata": {}, "source": [ "## Instantiate model\n", "\n", "We use `AutoModel` and `AutoProcessor` to download the weights and inference code for Cosmos-Embed1. The model has been trained with bfloat16, so we should cast if the GPU supports it. The preprocessor tokenizes text and resizes/rescales batched video frames. We also override the default resolution to a non-square example." ] }, { "cell_type": "code", "execution_count": 3, "id": "7438262f-f1dc-4f33-a941-a40d4e43cda6", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00, 14.57it/s]\n" ] } ], "source": [ "path = \"nvidia/Cosmos-Embed1-336p\"\n", "resolution_override = (588, 672) # ideally divisible by patch size 14\n", "\n", "config = AutoConfig.from_pretrained(path, trust_remote_code=True)\n", "config.resolution = resolution_override\n", "\n", "model = AutoModel.from_pretrained(path, trust_remote_code=True, config=config).to(\"cuda\", dtype=torch.bfloat16)\n", "model.eval()\n", "preprocess = AutoProcessor.from_pretrained(path, resolution=resolution_override, trust_remote_code=True)" ] }, { "cell_type": "markdown", "id": "bb9065d6", "metadata": {}, "source": [ "## Fetch data" ] }, { "cell_type": "code", "execution_count": 4, "id": "6d2287cf-badb-4608-9b4c-701c08e8217f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "--2025-05-28 20:44:26-- https://upload.wikimedia.org/wikipedia/commons/3/3d/Branko_Paukovic%2C_javelin_throw.webm\n", "Resolving upload.wikimedia.org (upload.wikimedia.org)... 198.35.26.112, 2620:0:861:ed1a::2:b\n", "Connecting to upload.wikimedia.org (upload.wikimedia.org)|198.35.26.112|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 159119 (155K) [video/webm]\n", "Saving to: ‘/tmp/output.mp4’\n", "\n", " 0K .......... .......... .......... .......... .......... 32% 6.95M 0s\n", " 50K .......... .......... .......... .......... .......... 64% 10.8M 0s\n", " 100K .......... .......... .......... .......... .......... 96% 27.0M 0s\n", " 150K ..... 
100% 10.0T=0.01s\n", "\n", "2025-05-28 20:44:26 (11.4 MB/s) - ‘/tmp/output.mp4’ saved [159119/159119]\n", "\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "video_url = \"https://upload.wikimedia.org/wikipedia/commons/3/3d/Branko_Paukovic%2C_javelin_throw.webm\"\n", "subprocess.check_call([\"wget\", \"-O\", \"/tmp/output.mp4\", video_url])\n", "video_bytes = open(\"/tmp/output.mp4\", \"rb\").read()\n", "assert video_bytes\n", "Video(video_url)" ] }, { "cell_type": "markdown", "id": "13ce12db", "metadata": {}, "source": [ "We sample 8 frames from the single video and create a tensor of shape `batch_size x num_frames x channel_dim x height x width`. The model has been trained on 8 frames sampled at 1-2FPS. For this example, we linearly sample frames from the entire ~2s clip." ] }, { "cell_type": "code", "execution_count": 5, "id": "b57ed50d-f11b-4100-9a7d-45edc27babf9", "metadata": {}, "outputs": [], "source": [ "with io.BytesIO(video_bytes) as fp:\n", " reader = decord.VideoReader(fp)\n", " frame_ids = np.linspace(0, len(reader)-1, 8, dtype=int).tolist()\n", " frames = reader.get_batch(frame_ids).asnumpy()\n", "batch = np.transpose(np.expand_dims(frames, 0), (0, 1, 4, 2, 3)) # BTCHW" ] }, { "cell_type": "markdown", "id": "8627495d", "metadata": {}, "source": [ "## Inference" ] }, { "cell_type": "markdown", "id": "4fccb879", "metadata": {}, "source": [ "We run inference on the video batch by preprocessing it, moving it to the GPU and calling the `get_video_embeddings` method.\n", "\n", "We run inference on text captions by preprocessing them into tokens and attention masks, moving to the GPU and calling the `get_text_embeddings` method. \n", "\n", "We can then calculate the similarity between the text and video embeddings using a dot-product, and rank the captions by highest similarity to the video. The model correctly ranks the most likely caption as being `a man wearing red spandex throwing a javelin`." ] }, { "cell_type": "code", "execution_count": 6, "id": "376a6e0a-1932-4309-aa6f-0be92f2e5846", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "a man wearing red spandex throwing a javelin\n" ] } ], "source": [ "video_inputs = preprocess(videos=batch).to(\"cuda\", dtype=torch.bfloat16)\n", "with torch.no_grad():\n", " video_out = model.get_video_embeddings(**video_inputs)\n", "\n", "captions = [\n", " \"a person riding a motorcycle in the night\",\n", " \"a car overtaking a white truck\",\n", " \"a video of a knight fighting with a sword\",\n", " \"a man wearing red spandex throwing a javelin\",\n", " \"a young man javelin throwing during the evening\", # distractor\n", " \"a man throwing a javelin with both hands\", # distractor\n", "]\n", "text_inputs = preprocess(text=captions).to(\"cuda\", dtype=torch.bfloat16)\n", "with torch.no_grad():\n", " text_out = model.get_text_embeddings(**text_inputs)\n", "\n", "probs = (torch.softmax(model.logit_scale.exp() * video_out.visual_proj @ text_out.text_proj.T, dim=-1))[0]\n", "print(captions[probs.argmax()])" ] }, { "cell_type": "markdown", "id": "077a35e7", "metadata": {}, "source": [ "## Intermediate feature maps" ] }, { "cell_type": "markdown", "id": "4dd9a66c", "metadata": {}, "source": [ "We can also display the intermediate per-frame dense feature maps, displaying temporal stability and separability." 
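] }, { "cell_type": "markdown", "id": "added-visual-embs-note", "metadata": {}, "source": [
"The dense per-frame features come from the `visual_embs` field returned by `get_video_embeddings`, laid out as `batch x frames x height/14 x width/14 x channels` given the patch size of 14. As a quick, optional sanity check (a minimal sketch that assumes `video_out` from the inference cell above is still in scope), we can inspect the shape before building the visualization:" ] }, {
"cell_type": "code", "execution_count": null, "id": "added-visual-embs-check", "metadata": {}, "outputs": [], "source": [
"# Optional sanity check: assumes `video_out` from the inference cell above is still in scope.\n",
"# `visual_embs` holds one dense feature map per frame; with the (588, 672) resolution override\n",
"# and patch size 14 this should be roughly torch.Size([1, 8, 42, 48, 1408]).\n",
"print(video_out.visual_embs.shape)"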
] }, { "cell_type": "code", "execution_count": 7, "id": "2957ff9a", "metadata": {}, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "def get_pca_map(\n", " feature_map: torch.Tensor,\n", " img_size,\n", " interpolation=\"bicubic\",\n", " return_pca_stats=False,\n", " pca_stats=None,\n", " skip_components: int = 0,\n", "):\n", " if feature_map.shape[0] != 1:\n", " feature_map = feature_map[None]\n", " if pca_stats is None:\n", " reduct_mat, color_min, color_max = get_robust_pca(\n", " feature_map.reshape(-1, feature_map.shape[-1]), skip=skip_components,\n", " )\n", " else:\n", " reduct_mat, color_min, color_max = pca_stats\n", " pca_color = feature_map @ reduct_mat\n", " pca_color = (pca_color - color_min) / (color_max - color_min)\n", " pca_color = pca_color.clamp(0, 1)\n", " pca_color = F.interpolate(\n", " pca_color.permute(0, 3, 1, 2),\n", " size=img_size,\n", " mode=interpolation,\n", " ).permute(0, 2, 3, 1)\n", " pca_color = pca_color.cpu().numpy().squeeze(0)\n", " if return_pca_stats:\n", " return pca_color, (reduct_mat, color_min, color_max)\n", " return pca_color\n", "\n", "\n", "def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False, skip: int = 0):\n", " assert len(features.shape) == 2, \"features should be (N, C)\"\n", " reduction_mat = torch.pca_lowrank(features, q=3 + skip, niter=20)[2]\n", " reduction_mat = reduction_mat[:, skip:]\n", " colors = features @ reduction_mat\n", " if remove_first_component:\n", " colors_min = colors.min(dim=0).values\n", " colors_max = colors.max(dim=0).values\n", " tmp_colors = (colors - colors_min) / (colors_max - colors_min)\n", " fg_mask = tmp_colors[..., 0] < 0.2\n", " reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2]\n", " colors = features @ reduction_mat\n", " else:\n", " fg_mask = torch.ones_like(colors[:, 0]).bool()\n", " d = torch.abs(colors[fg_mask] - torch.median(colors[fg_mask], dim=0).values)\n", " mdev = torch.median(d, dim=0).values\n", " s = d / mdev\n", " try:\n", " rins = colors[fg_mask][s[:, 0] < m, 0]\n", " gins = colors[fg_mask][s[:, 1] < m, 1]\n", " bins = colors[fg_mask][s[:, 2] < m, 2]\n", " rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])\n", " rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])\n", " except:\n", " rins = colors\n", " gins = colors\n", " bins = colors\n", " rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])\n", " rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])\n", "\n", " return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat)" ] }, { "cell_type": "code", "execution_count": 8, "id": "30783420", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([64, 3, 360, 640])\n" ] } ], "source": [ "video = \"/tmp/output.mp4\"\n", "input_frames = np.stack([x.asnumpy() for x in decord.VideoReader(video)])\n", "input_frames = torch.from_numpy(np.transpose(input_frames, (0, 3, 1, 2)))\n", "print(input_frames.shape)" ] }, { "cell_type": "code", "execution_count": 9, "id": "37439827", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([64, 42, 48, 1408])\n" ] } ], "source": [ "from einops import rearrange\n", "\n", "num_frames = 8\n", "num_batches = len(input_frames) // num_frames\n", "batches = preprocess(videos=rearrange(input_frames, \"(b t) c h w -> b t c h w\", b=num_batches, t=num_frames)).to(\"cuda\", dtype=torch.bfloat16)\n", "with torch.no_grad():\n", " dense_features = torch.stack([\n", " 
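# extract the dense (frames, h/14, w/14, channels) map for each 8-frame chunk; [0] drops the batch dim\n", "        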
model.get_video_embeddings(videos=inp).visual_embs[0]\n", " for inp in batches[\"videos\"]\n", " ])\n", "dense_features = rearrange(dense_features, \"b t h w c -> (b t) h w c\").to(\"cpu\", dtype=torch.float32)\n", "print(dense_features.shape)" ] }, { "cell_type": "markdown", "id": "cdeec855", "metadata": {}, "source": [ "Find PCA components of dense features for visualization purposes" ] }, { "cell_type": "code", "execution_count": 10, "id": "4779cb35", "metadata": {}, "outputs": [], "source": [ "num_keyframes = 30\n", "kf_stride = max(dense_features.shape[0] // num_keyframes, 1)\n", "sampled_features = dense_features[::kf_stride]\n", "pca_stats = get_robust_pca(sampled_features.flatten(0, 2))\n", "original_frames = input_frames.permute((0, 2, 3, 1)).cpu()\n", "\n", "output_frames = []\n", "for raw_frame, features in zip(original_frames, dense_features, strict=True):\n", " pca_features = get_pca_map(features, raw_frame.shape[0:2], pca_stats=pca_stats, interpolation=\"bilinear\")\n", " pca_features = np.floor(pca_features * 255.0).astype(np.uint8)\n", " pca_features = np.concatenate((raw_frame, pca_features), 1)\n", " output_frames.append(pca_features)\n", "output_frames = np.stack(output_frames)" ] }, { "cell_type": "code", "execution_count": 11, "id": "4ba37f71", "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import mediapy\n", "\n", "viz_file = \"/tmp/output-visualization.mp4\"\n", "mediapy.write_video(viz_file, output_frames, fps=30, codec=\"libx264\")\n", "Video(viz_file, embed=True)" ] } ], "metadata": { "kernelspec": { "display_name": "cosmos-embed1", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 5 }