# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Smoke tests for Cosmos-Embed1 including Transformer Engine support."""

import os

import pytest
import torch
from transformers import AutoConfig, AutoModel, AutoProcessor, AutoTokenizer

# Test model path: can be overridden via the environment variable below; defaults to the current directory.
MODEL_PATH = os.environ.get("COSMOS_EMBED1_MODEL_PATH", ".")
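
# Usage sketch (assumes pytest is installed and a local copy of the checkpoint; the path below is illustrative):
#   COSMOS_EMBED1_MODEL_PATH=/path/to/Cosmos-Embed1-336p pytest examples/test_smoke.py -v
# With the default ".", the tests expect to be run from the model repository root.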


def test_smoke() -> None:
    """Original smoke test for basic functionality."""
    preprocess = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
    model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).to("cuda", dtype=torch.bfloat16)

    with torch.no_grad():
        text_inputs = preprocess(text=["a cat", "a dog"]).to("cuda", dtype=torch.bfloat16)
        text_out = model.get_text_embeddings(**text_inputs)
        assert text_out.text_proj.shape == (2, 768)

        video_inputs = preprocess(videos=torch.randint(0, 255, size=(2, 8, 3, 224, 224))).to(
            "cuda", dtype=torch.bfloat16
        )
        video_out = model.get_video_embeddings(**video_inputs)
        assert video_out.visual_proj.shape == (2, 768)

        text_video_inputs = preprocess(
            text=["a cat", "a dog"],
            videos=torch.randint(0, 255, size=(2, 8, 3, 448, 448)),
        ).to("cuda", dtype=torch.bfloat16)
        text_video_out = model(**text_video_inputs)
        assert text_video_out.text_proj.shape == text_video_out.visual_proj.shape == (2, 768)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    assert len(tokenizer) == 30523

    # Clean up GPU memory after test
    del model
    torch.cuda.empty_cache()
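

# Illustrative sketch (not part of the tests above): both projections checked in test_smoke are
# (batch, 768) tensors, so text-to-video retrieval scores can be formed as a cosine-similarity
# matrix. The helper name below is hypothetical.
def _cosine_similarity_matrix(text_proj: torch.Tensor, visual_proj: torch.Tensor) -> torch.Tensor:
    """Return a (num_texts, num_videos) matrix of cosine similarities between projections."""
    text_n = torch.nn.functional.normalize(text_proj.float(), dim=-1)
    visual_n = torch.nn.functional.normalize(visual_proj.float(), dim=-1)
    return text_n @ visual_n.T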


def test_transformer_engine_available():
    """Test if Transformer Engine is available."""
    try:
        import transformer_engine.pytorch as te

        # If we get here, TE is available
        assert True
    except ImportError:
        pytest.skip("Transformer Engine not available, skipping TE tests")


def test_load_standard_model():
    """Test loading the standard (non-TE) model."""
    model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, torch_dtype=torch.bfloat16)
    assert model.transformer_engine == False
    assert hasattr(model, "visual_encoder")
    assert hasattr(model, "qformer")

    # Clean up
    del model
    torch.cuda.empty_cache()


def test_load_transformer_engine_model():
    """Test loading the model with Transformer Engine enabled."""
    try:
        import transformer_engine.pytorch as te
    except ImportError:
        pytest.skip("Transformer Engine not available, skipping TE tests")

    # Load config and enable transformer engine
    config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True)
    config.transformer_engine = True
    config.use_fp8 = False  # Start with FP8 disabled for basic test

    model = AutoModel.from_pretrained(MODEL_PATH, config=config, trust_remote_code=True, torch_dtype=torch.bfloat16)
    assert model.transformer_engine == True
    assert model.use_fp8 == False
    assert hasattr(model, "visual_encoder")
    assert hasattr(model, "qformer")

    # Clean up
    del model
    torch.cuda.empty_cache()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available, skipping GPU test")
def test_transformer_engine_inference():
    """Test inference with the Transformer Engine model."""
    try:
        import transformer_engine.pytorch as te
    except ImportError:
        pytest.skip("Transformer Engine not available, skipping TE tests")

    # Test text embeddings with the standard model first
    preprocess = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
    text_inputs = preprocess(text=["a cat"]).to("cuda", dtype=torch.bfloat16)

    # Load standard model, run inference, then clean up
    model_standard = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, torch_dtype=torch.bfloat16).to(
        "cuda"
    )
    with torch.no_grad():
        text_out_std = model_standard.get_text_embeddings(**text_inputs)

    # Clean up standard model before loading TE model
    del model_standard
    torch.cuda.empty_cache()

    # Now load TE model and run inference
    config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True)
    config.transformer_engine = True
    config.use_fp8 = False
    model_te = AutoModel.from_pretrained(
        MODEL_PATH, config=config, trust_remote_code=True, torch_dtype=torch.bfloat16
    ).to("cuda")
    with torch.no_grad():
        text_out_te = model_te.get_text_embeddings(**text_inputs)

    # Check shapes match
    assert text_out_std.text_proj.shape == text_out_te.text_proj.shape
    assert text_out_std.text_proj.shape == (1, 768)
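
    # Illustrative follow-up (an assumption, not part of the original check): numerical agreement
    # between the standard and TE paths could be verified with a loose tolerance, e.g.
    #   assert torch.allclose(text_out_std.text_proj.float(), text_out_te.text_proj.float(), atol=1e-1)
    # Exact tolerances are a guess; bf16 and Transformer Engine kernels are not bit-identical.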

    # Clean up GPU memory
    del model_te
    torch.cuda.empty_cache()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available, skipping GPU test")
def test_transformer_engine_fp8():
    """Test loading the model with Transformer Engine + FP8 (requires substantial GPU memory)."""
    try:
        import transformer_engine.pytorch as te
    except ImportError:
        pytest.skip("Transformer Engine not available, skipping FP8 tests")

    # Clear memory before this memory-intensive test
    torch.cuda.empty_cache()

    config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True)
    config.transformer_engine = True
    config.use_fp8 = True
    model = AutoModel.from_pretrained(MODEL_PATH, config=config, trust_remote_code=True, torch_dtype=torch.bfloat16)
    assert model.transformer_engine == True
    assert model.use_fp8 == True

    # Clean up
    del model
    torch.cuda.empty_cache()
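

# Illustrative note (an assumption about the FP8 runtime path, not asserted by the test above):
# with config.use_fp8 = True, forward passes would typically execute inside Transformer Engine's
# fp8_autocast context, along the lines of:
#   import transformer_engine.pytorch as te
#   with te.fp8_autocast(enabled=True), torch.no_grad():
#       out = model.get_text_embeddings(**text_inputs)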


def test_transformer_engine_config_validation():
    """Test configuration validation for Transformer Engine."""
    # use_fp8=True without transformer_engine=True should raise a ValueError
    config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True)
    config.transformer_engine = False
    config.use_fp8 = True

    with pytest.raises(ValueError, match="transformer_engine.*must be enabled.*use_fp8"):
        from modeling_vit import VisionTransformer

        VisionTransformer(transformer_engine=False, use_fp8=True)
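

# Minimal convenience entry point (a sketch, not required by pytest): allows running the smoke
# tests directly with `python test_smoke.py` instead of invoking pytest on the command line.
if __name__ == "__main__":
    import sys

    sys.exit(pytest.main([__file__, "-v"]))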