{
"cells": [
{
"cell_type": "markdown",
"id": "136b43b6",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"We need `transformers`, `torchvision` and `einops` as basic dependencies for the model. \n",
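"For this example, we also use the `wget` command-line tool (it is invoked via `subprocess`, so it must already be installed on the system) for fetching data remotely, `decord` for decoding video frames, and `mediapy` for saving videos."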
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "4363e953",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: transformers in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (4.51.3)\n",
"Requirement already satisfied: torchvision in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (0.22.0)\n",
"Requirement already satisfied: einops in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (0.8.1)\n",
"Requirement already satisfied: decord in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (0.6.0)\n",
"Requirement already satisfied: mediapy in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (1.2.4)\n",
"Requirement already satisfied: filelock in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (3.18.0)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.30.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (0.30.2)\n",
"Requirement already satisfied: numpy>=1.17 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (2.2.5)\n",
"Requirement already satisfied: packaging>=20.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (25.0)\n",
"Requirement already satisfied: pyyaml>=5.1 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (6.0.2)\n",
"Requirement already satisfied: regex!=2019.12.17 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (2024.11.6)\n",
"Requirement already satisfied: requests in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (2.32.3)\n",
"Requirement already satisfied: tokenizers<0.22,>=0.21 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (0.21.1)\n",
"Requirement already satisfied: safetensors>=0.4.3 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (0.5.3)\n",
"Requirement already satisfied: tqdm>=4.27 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from transformers) (4.67.1)\n",
"Requirement already satisfied: torch==2.7.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torchvision) (2.7.0)\n",
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torchvision) (11.2.1)\n",
"Requirement already satisfied: typing-extensions>=4.10.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (4.13.2)\n",
"Requirement already satisfied: sympy>=1.13.3 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (1.14.0)\n",
"Requirement already satisfied: networkx in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (3.4.2)\n",
"Requirement already satisfied: jinja2 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (3.1.6)\n",
"Requirement already satisfied: fsspec in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (2025.3.2)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (12.6.77)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (12.6.77)\n",
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (12.6.80)\n",
"Requirement already satisfied: nvidia-cudnn-cu12==9.5.1.17 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (9.5.1.17)\n",
"Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (12.6.4.1)\n",
"Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (11.3.0.4)\n",
"Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (10.3.7.77)\n",
"Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (11.7.1.2)\n",
"Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (12.5.4.2)\n",
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.3 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (0.6.3)\n",
"Requirement already satisfied: nvidia-nccl-cu12==2.26.2 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (2.26.2)\n",
"Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (12.6.77)\n",
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (12.6.85)\n",
"Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (1.11.1.6)\n",
"Requirement already satisfied: triton==3.3.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from torch==2.7.0->torchvision) (3.3.0)\n",
"Requirement already satisfied: setuptools>=40.8.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from triton==3.3.0->torch==2.7.0->torchvision) (75.8.0)\n",
"Requirement already satisfied: ipython in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from mediapy) (8.36.0)\n",
"Requirement already satisfied: matplotlib in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from mediapy) (3.10.3)\n",
"Requirement already satisfied: decorator in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from ipython->mediapy) (5.2.1)\n",
"Requirement already satisfied: exceptiongroup in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from ipython->mediapy) (1.2.2)\n",
"Requirement already satisfied: jedi>=0.16 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from ipython->mediapy) (0.19.2)\n",
"Requirement already satisfied: matplotlib-inline in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from ipython->mediapy) (0.1.7)\n",
"Requirement already satisfied: pexpect>4.3 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from ipython->mediapy) (4.9.0)\n",
"Requirement already satisfied: prompt_toolkit<3.1.0,>=3.0.41 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from ipython->mediapy) (3.0.51)\n",
"Requirement already satisfied: pygments>=2.4.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from ipython->mediapy) (2.19.1)\n",
"Requirement already satisfied: stack_data in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from ipython->mediapy) (0.6.3)\n",
"Requirement already satisfied: traitlets>=5.13.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from ipython->mediapy) (5.14.3)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from matplotlib->mediapy) (1.3.2)\n",
"Requirement already satisfied: cycler>=0.10 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from matplotlib->mediapy) (0.12.1)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from matplotlib->mediapy) (4.58.0)\n",
"Requirement already satisfied: kiwisolver>=1.3.1 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from matplotlib->mediapy) (1.4.8)\n",
"Requirement already satisfied: pyparsing>=2.3.1 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from matplotlib->mediapy) (3.2.3)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from matplotlib->mediapy) (2.9.0.post0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from requests->transformers) (3.4.1)\n",
"Requirement already satisfied: idna<4,>=2.5 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from requests->transformers) (3.10)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from requests->transformers) (2.4.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from requests->transformers) (2025.4.26)\n",
"Requirement already satisfied: parso<0.9.0,>=0.8.4 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from jedi>=0.16->ipython->mediapy) (0.8.4)\n",
"Requirement already satisfied: ptyprocess>=0.5 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from pexpect>4.3->ipython->mediapy) (0.7.0)\n",
"Requirement already satisfied: wcwidth in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from prompt_toolkit<3.1.0,>=3.0.41->ipython->mediapy) (0.2.13)\n",
"Requirement already satisfied: six>=1.5 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib->mediapy) (1.17.0)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from sympy>=1.13.3->torch==2.7.0->torchvision) (1.3.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from jinja2->torch==2.7.0->torchvision) (3.0.2)\n",
"Requirement already satisfied: executing>=1.2.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from stack_data->ipython->mediapy) (2.2.0)\n",
"Requirement already satisfied: asttokens>=2.1.0 in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from stack_data->ipython->mediapy) (3.0.0)\n",
"Requirement already satisfied: pure_eval in /data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages (from stack_data->ipython->mediapy) (0.2.3)\n"
]
}
],
"source": [
"%pip install transformers torchvision einops decord mediapy"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "54c2ac81-3389-4c8d-bc08-4834eb88fa73",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/data/miniconda3/envs/cosmos-embed1/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import decord\n",
"import numpy as np\n",
"import torch\n",
"from transformers import AutoConfig, AutoModel, AutoProcessor\n",
"from IPython.display import Video\n",
"import subprocess\n",
"import io"
]
},
{
"cell_type": "markdown",
"id": "fa84e4fa",
"metadata": {},
"source": [
"## Instantiate model\n",
"\n",
"We use `AutoModel` and `AutoProcessor` to download the weights and inference code for Cosmos-Embed1. The model has been trained with bfloat16, so we should cast if the GPU supports it. The preprocessor tokenizes text and resizes/rescales batched video frames. We also override the default resolution to a non-square example."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7438262f-f1dc-4f33-a941-a40d4e43cda6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00, 15.64it/s]\n"
]
}
],
"source": [
"# Directory handed to every `from_pretrained` call below.\n",
"path = \"../\"\n",
"\n",
"# `trust_remote_code=True` is passed to each loader so custom code found under `path` may be executed.\n",
"config = AutoConfig.from_pretrained(path, trust_remote_code=True)\n",
"\n",
"# Load the weights with the explicit config, then move them to the GPU as bfloat16.\n",
"model = AutoModel.from_pretrained(path, trust_remote_code=True, config=config)\n",
"model = model.to(\"cuda\", dtype=torch.bfloat16)\n",
"model.eval()  # inference only: switch the module out of training mode\n",
"\n",
"# The processor prepares both text and video inputs for the model.\n",
"preprocess = AutoProcessor.from_pretrained(path, trust_remote_code=True)"
]
},
{
"cell_type": "markdown",
"id": "bb9065d6",
"metadata": {},
"source": [
"## Fetch data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6d2287cf-badb-4608-9b4c-701c08e8217f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"--2025-06-03 16:11:10-- https://upload.wikimedia.org/wikipedia/commons/3/3d/Branko_Paukovic%2C_javelin_throw.webm\n",
"Resolving upload.wikimedia.org (upload.wikimedia.org)... 198.35.26.112, 2620:0:863:ed1a::2:b\n",
"Connecting to upload.wikimedia.org (upload.wikimedia.org)|198.35.26.112|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 159119 (155K) [video/webm]\n",
"Saving to: ‘/tmp/output.mp4’\n",
"\n",
" 0K .......... .......... .......... .......... .......... 32% 1.36M 0s\n",
" 50K .......... .......... .......... .......... .......... 64% 14.6M 0s\n",
" 100K .......... .......... .......... .......... .......... 96% 1.31M 0s\n",
" 150K ..... 100% 10.0T=0.08s\n",
"\n",
"2025-06-03 16:11:10 (1.98 MB/s) - ‘/tmp/output.mp4’ saved [159119/159119]\n",
"\n"
]
},
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"video_url = \"https://upload.wikimedia.org/wikipedia/commons/3/3d/Branko_Paukovic%2C_javelin_throw.webm\"\n",
"# NOTE: the remote clip is a .webm file; only the local filename carries an .mp4 suffix.\n",
"video_path = \"/tmp/output.mp4\"\n",
"\n",
"# Download with the system `wget` binary; `check_call` raises CalledProcessError on a non-zero exit status.\n",
"subprocess.check_call([\"wget\", \"-O\", video_path, video_url])\n",
"\n",
"# Read through a context manager so the file handle is closed deterministically.\n",
"with open(video_path, \"rb\") as video_file:\n",
"    video_bytes = video_file.read()\n",
"assert video_bytes, \"downloaded video is empty\"\n",
"\n",
"# Last expression of the cell: render an inline player for the remote URL.\n",
"Video(video_url)"
]
},
{
"cell_type": "markdown",
"id": "13ce12db",
"metadata": {},
"source": [
"We sample 8 frames from the single video and create an array of shape `batch_size x num_frames x channel_dim x height x width`. The model has been trained on 8 frames sampled at 1-2 FPS. For this example, we sample frames evenly spaced across the entire ~2s clip."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b57ed50d-f11b-4100-9a7d-45edc27babf9",
"metadata": {},
"outputs": [],
"source": [
"# Decode straight from the in-memory bytes; the buffer is closed when the block exits.\n",
"with io.BytesIO(video_bytes) as fp:\n",
"    reader = decord.VideoReader(fp)\n",
"    # 8 indices evenly spaced from the first to the last frame (inclusive), truncated to ints.\n",
"    last_frame = len(reader) - 1\n",
"    frame_ids = np.linspace(0, last_frame, num=8, dtype=int).tolist()\n",
"    frames = reader.get_batch(frame_ids).asnumpy()\n",
"\n",
"# Prepend a batch axis, then move the trailing channel axis ahead of height/width -> BTCHW.\n",
"batch = frames[np.newaxis].transpose(0, 1, 4, 2, 3)"
]
},
{
"cell_type": "markdown",
"id": "8627495d",
"metadata": {},
"source": [
"## Inference"
]
},
{
"cell_type": "markdown",
"id": "4fccb879",
"metadata": {},
"source": [
"We run inference on the video batch by preprocessing it, moving it to the GPU and calling the `get_video_embeddings` method.\n",
"\n",
"We run inference on text captions by preprocessing them into tokens and attention masks, moving to the GPU and calling the `get_text_embeddings` method. \n",
"\n",
"We can then calculate the similarity between the text and video embeddings using a dot product (scaled by the model's `logit_scale` and normalized with a softmax), and rank the captions by highest similarity to the video. The model correctly ranks the most likely caption as being `a man wearing red spandex throwing a javelin`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "376a6e0a-1932-4309-aa6f-0be92f2e5846",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"a man wearing red spandex throwing a javelin\n"
]
}
],
"source": [
"# Video branch: preprocess the batch, move it to the GPU as bfloat16, and embed without autograd.\n",
"video_inputs = preprocess(videos=batch).to(\"cuda\", dtype=torch.bfloat16)\n",
"with torch.no_grad():\n",
"    video_out = model.get_video_embeddings(**video_inputs)\n",
"\n",
"# Candidate captions: three unrelated scenes, the expected match, and two close distractors.\n",
"captions = [\n",
"    \"a person riding a motorcycle in the night\",\n",
"    \"a car overtaking a white truck\",\n",
"    \"a video of a knight fighting with a sword\",\n",
"    \"a man wearing red spandex throwing a javelin\",\n",
"    \"a young man javelin throwing during the evening\", # distractor\n",
"    \"a man throwing a javelin with both hands\", # distractor\n",
"]\n",
"\n",
"# Text branch: same device/dtype handling as the video inputs.\n",
"text_inputs = preprocess(text=captions).to(\"cuda\", dtype=torch.bfloat16)\n",
"with torch.no_grad():\n",
"    text_out = model.get_text_embeddings(**text_inputs)\n",
"\n",
"# Scaled video-to-text similarities, softmaxed over the captions; [0] selects the only video in the batch.\n",
"scale = model.logit_scale.exp()\n",
"logits = scale * video_out.visual_proj @ text_out.text_proj.T\n",
"probs = torch.softmax(logits, dim=-1)[0]\n",
"\n",
"# Report the caption with the highest probability.\n",
"best_caption = captions[probs.argmax()]\n",
"print(best_caption)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "cosmos-embed1",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}