# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration class for Cosmos-Embed1."""

from typing import Any, Literal, Tuple, Union

from transformers import AutoConfig, PretrainedConfig


class CosmosEmbed1Config(PretrainedConfig):
    model_type = "cosmos-embed1"

    def __init__(
        self,
        embed_dim: int = 768,
        num_query_tokens: int = 32,
        max_txt_len: int = 128,
        num_video_frames: int = 8,
        temporal_encoding_type: Literal[
            "neighboring_token_propagation", "temporal_parameter"
        ] = "neighboring_token_propagation",
        resolution: Union[int, Tuple[int, int]] = 224,
        vocab_size: int = 30523,
        transformer_engine: bool = False,
        use_fp8: bool = False,
        **kwargs: Any,
    ) -> None:
        """Configuration for Cosmos-Embed1 models.

        Args:
            embed_dim (int): dimension of the extracted text-visual embeddings.
            num_query_tokens (int): number of learnable query tokens.
            max_txt_len (int): max length of text token sequences before truncation.
            num_video_frames (int): number of input video frames.
            temporal_encoding_type (str): temporal encoding module type.
            resolution (Union[int, Tuple[int, int]]): input video frame resolution.
                Can be an integer for square frames (height == width) or a
                (height, width) tuple for non-square frames.
            vocab_size (int): vocab size for the text tokenizer. The default is from
                `bert-base-uncased` with an extra [DEC] token.
            transformer_engine (bool): whether to use TransformerEngine for acceleration.
            use_fp8 (bool): whether to use FP8 precision (requires transformer_engine=True).
        """
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_query_tokens = num_query_tokens
        self.max_txt_len = max_txt_len
        self.num_video_frames = num_video_frames
        self.temporal_encoding_type = temporal_encoding_type
        self.resolution = resolution
        self.vocab_size = vocab_size
        self.transformer_engine = transformer_engine
        self.use_fp8 = use_fp8


AutoConfig.register("cosmos-embed1", CosmosEmbed1Config)
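

# A minimal usage sketch, not part of the library's public surface: it exercises
# only the class defined above plus standard `transformers` APIs
# (`PretrainedConfig.save_pretrained` / `AutoConfig.from_pretrained`). The
# overridden field values and the temporary directory below are illustrative.
if __name__ == "__main__":
    import tempfile

    config = CosmosEmbed1Config(num_video_frames=16, resolution=(448, 448))
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Round-trip the config through disk; AutoConfig resolves the
        # registered "cosmos-embed1" model_type back to CosmosEmbed1Config.
        config.save_pretrained(tmp_dir)
        reloaded = AutoConfig.from_pretrained(tmp_dir)
    assert isinstance(reloaded, CosmosEmbed1Config)
    assert reloaded.num_video_frames == 16
    # Note: tuple-valued fields such as `resolution` come back as lists after
    # JSON serialization, so compare against [448, 448] rather than a tuple.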