{ "cells": [ { "cell_type": "markdown", "id": "136b43b6", "metadata": {}, "source": [ "## Setup\n", "\n", "We need `transformers`, `torchvision` and `einops` as basic dependencies for the model. \n", "For this example, we also use `wget` for fetching data remotely, `decord` for decoding video frames, and `mediapy` for saving videos." ] }, { "cell_type": "code", "execution_count": 1, "id": "4363e953", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: transformers in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (4.51.3)\n", "Requirement already satisfied: torchvision in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (0.22.0)\n", "Requirement already satisfied: einops in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (0.8.1)\n", "Requirement already satisfied: decord in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (0.6.0)\n", "Requirement already satisfied: mediapy in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (1.2.4)\n" ] } ], "source": [ "!pip install transformers torchvision einops decord mediapy" ] }, { "cell_type": "code", "execution_count": 1, "id": "54c2ac81-3389-4c8d-bc08-4834eb88fa73", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import decord\n", "import numpy as np\n", "import torch\n", "from transformers import AutoConfig, AutoModel, AutoProcessor\n", "from IPython.display import Video\n", "import subprocess\n", "import io" ] }, { "cell_type": "markdown", "id": "fa84e4fa", "metadata": {}, "source": [ "## Instantiate model\n", "\n", "We use `AutoModel` and `AutoProcessor` to load the weights and inference code for Cosmos-Embed1. The model was trained in bfloat16, so we cast it to that dtype when the GPU supports it. The preprocessor tokenizes text and resizes/rescales batched video frames. The config is loaded separately with `AutoConfig` so that defaults can be overridden before the model is instantiated; this example passes it through unchanged." ] }, { "cell_type": "code", "execution_count": 2, "id": "7438262f-f1dc-4f33-a941-a40d4e43cda6", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00, 15.64it/s]\n" ] } ], "source": [ "path = \"../\"  # relative path to the model directory\n", "\n", "config = AutoConfig.from_pretrained(path, trust_remote_code=True)\n", "\n", "model = AutoModel.from_pretrained(path, trust_remote_code=True, config=config).to(\"cuda\", dtype=torch.bfloat16)\n", "model.eval()\n", "preprocess = AutoProcessor.from_pretrained(path, trust_remote_code=True)" ] }, { "cell_type": "markdown", "id": "bb9065d6", "metadata": {}, "source": [ "## Fetch data" ] }, { "cell_type": "code", "execution_count": 3, "id": "6d2287cf-badb-4608-9b4c-701c08e8217f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "--2025-06-03 16:11:10-- https://upload.wikimedia.org/wikipedia/commons/3/3d/Branko_Paukovic%2C_javelin_throw.webm\n", "Resolving upload.wikimedia.org (upload.wikimedia.org)... 198.35.26.112, 2620:0:863:ed1a::2:b\n", "Connecting to upload.wikimedia.org (upload.wikimedia.org)|198.35.26.112|:443... connected.\n", "HTTP request sent, awaiting response... 
200 OK\n", "Length: 159119 (155K) [video/webm]\n", "Saving to: ‘/tmp/output.mp4’\n", "\n", " 0K .......... .......... .......... .......... .......... 32% 1.36M 0s\n", " 50K .......... .......... .......... .......... .......... 64% 14.6M 0s\n", " 100K .......... .......... .......... .......... .......... 96% 1.31M 0s\n", " 150K ..... 100% 10.0T=0.08s\n", "\n", "2025-06-03 16:11:10 (1.98 MB/s) - ‘/tmp/output.mp4’ saved [159119/159119]\n", "\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "video_url = \"https://upload.wikimedia.org/wikipedia/commons/3/3d/Branko_Paukovic%2C_javelin_throw.webm\"\n", "# Note: the payload is WebM even though the local file name ends in .mp4.\n", "subprocess.check_call([\"wget\", \"-O\", \"/tmp/output.mp4\", video_url])\n", "with open(\"/tmp/output.mp4\", \"rb\") as f:\n", "    video_bytes = f.read()\n", "assert video_bytes, \"downloaded file is empty\"\n", "Video(video_url)" ] }, { "cell_type": "markdown", "id": "13ce12db", "metadata": {}, "source": [ "We sample 8 frames from the single video and build an array of shape `batch_size x num_frames x channel_dim x height x width`. The model was trained on 8 frames sampled at 1-2 FPS. For this example, we instead take 8 evenly spaced frames across the entire ~2 s clip."
] }, { "cell_type": "code", "execution_count": 4, "id": "b57ed50d-f11b-4100-9a7d-45edc27babf9", "metadata": {}, "outputs": [], "source": [ "with io.BytesIO(video_bytes) as fp:\n", "    reader = decord.VideoReader(fp)\n", "    # 8 evenly spaced frame indices covering the whole clip\n", "    frame_ids = np.linspace(0, len(reader) - 1, 8, dtype=int).tolist()\n", "    frames = reader.get_batch(frame_ids).asnumpy()  # (T, H, W, C)\n", "batch = np.transpose(np.expand_dims(frames, 0), (0, 1, 4, 2, 3))  # BTCHW" ] }, { "cell_type": "markdown", "id": "8627495d", "metadata": {}, "source": [ "## Inference" ] }, { "cell_type": "markdown", "id": "4fccb879", "metadata": {}, "source": [ "We run inference on the video batch by preprocessing it, moving it to the GPU and calling the `get_video_embeddings` method.\n", "\n", "We run inference on the text captions by preprocessing them into tokens and attention masks, moving them to the GPU and calling the `get_text_embeddings` method.\n", "\n", "We then compute the similarity between the video and text embeddings as a dot product, scale it by `logit_scale.exp()`, and apply a softmax over the captions to rank them by similarity to the video. The model ranks `a man wearing red spandex throwing a javelin` as the most likely caption, ahead of the two javelin distractors."
] }, { "cell_type": "code", "execution_count": 5, "id": "376a6e0a-1932-4309-aa6f-0be92f2e5846", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "a man wearing red spandex throwing a javelin\n" ] } ], "source": [ "video_inputs = preprocess(videos=batch).to(\"cuda\", dtype=torch.bfloat16)\n", "with torch.no_grad():\n", "    video_out = model.get_video_embeddings(**video_inputs)\n", "\n", "captions = [\n", "    \"a person riding a motorcycle in the night\",\n", "    \"a car overtaking a white truck\",\n", "    \"a video of a knight fighting with a sword\",\n", "    \"a man wearing red spandex throwing a javelin\",\n", "    \"a young man javelin throwing during the evening\",  # distractor\n", "    \"a man throwing a javelin with both hands\",  # distractor\n", "]\n", "text_inputs = preprocess(text=captions).to(\"cuda\", dtype=torch.bfloat16)\n", "with torch.no_grad():\n", "    text_out = model.get_text_embeddings(**text_inputs)\n", "\n", "# Scaled video-text similarities, shape (num_videos, num_captions)\n", "logits = model.logit_scale.exp() * video_out.visual_proj @ text_out.text_proj.T\n", "probs = torch.softmax(logits, dim=-1)[0]  # caption probabilities for the single video\n", "print(captions[probs.argmax()])" ] } ], "metadata": { "kernelspec": { "display_name": "cosmos-embed1", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 5 }