# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

import torch
from megatron.core import ModelParallelConfig, parallel_state
from safetensors.torch import load_file
from torch.nn.modules.module import _IncompatibleKeys

from cosmos_predict1.autoregressive.configs.base.model import ModelConfig
from cosmos_predict1.autoregressive.configs.base.tokenizer import TokenizerConfig
from cosmos_predict1.autoregressive.modules.mm_projector import MultimodalProjector
from cosmos_predict1.autoregressive.networks.transformer import Transformer
from cosmos_predict1.autoregressive.networks.vit import VisionTransformer, get_vit_config
from cosmos_predict1.autoregressive.tokenizer.tokenizer import DiscreteMultimodalTokenizer, update_vocab_size
from cosmos_predict1.autoregressive.utils.checkpoint import (
    get_partial_state_dict,
    obtain_tensor_parallel_state_dict,
    process_state_dict,
    substrings_to_ignore,
)
from cosmos_predict1.autoregressive.utils.sampling import decode_n_tokens, decode_one_token, prefill
from cosmos_predict1.utils import log, misc


def update_model_config(model_config, inference_tensor_parallel_size):
    if inference_tensor_parallel_size > 1:
        log.warning(f"Setting tensor parallel size to {inference_tensor_parallel_size}")
        setattr(
            model_config,
            "tensor_model_parallel_size",
            inference_tensor_parallel_size,
        )
    if "{rank}" in model_config.ckpt_path:
        tp_rank = parallel_state.get_tensor_model_parallel_rank()
        model_config.ckpt_path = model_config.ckpt_path.format(rank=tp_rank)
    return model_config
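
# A minimal illustration (not part of the original file) of how a rank-templated checkpoint
# path is resolved by `update_model_config` when tensor parallelism is enabled. The path
# below is a hypothetical placeholder, not a file shipped with the repository.
#
#   config.ckpt_path = "checkpoints/model_tp{rank}.pt"  # hypothetical path
#   config = update_model_config(config, inference_tensor_parallel_size=2)
#   # On tensor-parallel rank 0 this resolves to "checkpoints/model_tp0.pt",
#   # on rank 1 to "checkpoints/model_tp1.pt".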


class AutoRegressiveModel(torch.nn.Module):
    """
    A class to build and use an AutoRegressiveModel for text generation.

    Methods:
        build: Build an AutoRegressiveModel instance by initializing and loading a model checkpoint.
        generate: Generate text sequences based on provided prompts using the language generation model.
    """

    def __init__(
        self,
        model: Transformer = None,
        tokenizer: DiscreteMultimodalTokenizer = None,
        config: ModelConfig = None,
        model_parallel: ModelParallelConfig = None,
        vision_encoder: VisionTransformer = None,
        mm_projector: MultimodalProjector = None,
    ):
        """
        Initialize the AutoRegressiveModel instance with a model and tokenizer.

        Args:
            model (Transformer): The Transformer model for text generation.
            tokenizer (DiscreteMultimodalTokenizer): The tokenizer for encoding and decoding text.
            config (ModelConfig): The configuration for the AutoRegressiveModel.
            model_parallel (ModelParallelConfig): The model parallel configuration for the AutoRegressiveModel.
            vision_encoder (VisionTransformer): The vision encoder for the AutoRegressiveModel.
            mm_projector (MultimodalProjector): The multimodal projector for the AutoRegressiveModel.
        """
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.config = config

        self.vision_encoder = vision_encoder
        self.mm_projector = mm_projector
        self.model_parallel = model_parallel

    @property
    def precision(self):
        return self.model.precision

    def get_num_params(
        self,
    ) -> int:
        """
        Return the number of parameters in the model.
        """
        n_params = sum(p.numel() for p in self.parameters())
        return n_params

    def load_ar_model(
        self,
        shard_checkpoint,
        tokenizer_config,
    ):
        """
        Load the AR model.
        """
        model_config = self.config
        tensor_parallel_size = 1 if self.model_parallel is None else self.model_parallel.tensor_model_parallel_size
        assert tensor_parallel_size == model_config["tensor_model_parallel_size"]
        ckpt_path = model_config.ckpt_path
        with misc.timer(f"loading checkpoint from {ckpt_path}"):
            if ckpt_path.endswith("safetensors"):
                # Load with safetensors API
                checkpoint = load_file(ckpt_path, device="cpu")
            else:
                # The pytorch version
                checkpoint = torch.load(
                    ckpt_path,
                    map_location="cpu",
                    mmap=True,  # load the checkpoint in memory-mapped mode
                    weights_only=True,
                )
        llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint

        orig_precision = torch.get_default_dtype()
        precision = getattr(torch, model_config.precision)
        torch.set_default_dtype(precision)
        log.debug(f"Setting torch default dtype to {precision}")

        model = Transformer(
            params=model_config,
            model_parallel=self.model_parallel,
            tokenizer_config=tokenizer_config,
        )
        log.debug(
            f"tokenizer tokenizer_config.video_tokenizer.vocab_size {tokenizer_config.video_tokenizer.vocab_size}"
        )
        vocab_size = update_vocab_size(
            existing_vocab_size=0,
            to_be_added_vocab_size=tokenizer_config.video_tokenizer.vocab_size,
            training_type=tokenizer_config.training_type,
            add_special_tokens=False,
        )
        log.debug(
            f"tokenizer tokenizer_config.video_tokenizer.vocab_size {tokenizer_config.video_tokenizer.vocab_size} vocab_size {vocab_size}"
        )
        # Perform vocab expansion
        if vocab_size > model.vocab_size:
            log.debug(f"Expanding vocab size to {vocab_size}")
            # For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer.
            expand_output_layer = not (tokenizer_config.training_type == "text_to_video")
            model.expand_vocab(
                vocab_size,
                init_method="gaussian",
                expand_output_layer=expand_output_layer,
            )

        if shard_checkpoint:
            # Shard the checkpoint according to tensor parallelism.
            with misc.timer("sharding checkpoint according to tensor parallelism"):
                if self.model_parallel is not None:
                    assert self.model_parallel.tensor_model_parallel_size == model_config["tensor_model_parallel_size"]
                llm_checkpoint = obtain_tensor_parallel_state_dict(
                    llm_checkpoint,
                    tensor_parallel_size=tensor_parallel_size,
                    tensor_parallel_rank=parallel_state.get_tensor_model_parallel_rank(),
                    model_config=model_config,
                )

        # Remove the "model." prefix in the state_dict
        llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
        with misc.timer("loading state_dict into model"):
            missing_keys, _ = model.load_state_dict(llm_checkpoint, strict=True)
        # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage)
        missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")]
        assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
        self.model = model.to(precision).to("cuda")

        torch.set_default_dtype(orig_precision)  # Reset the default dtype to the original value

    def load_tokenizer(self, tokenizer_config):
        """
        Load the tokenizer.
        """
        self.tokenizer = DiscreteMultimodalTokenizer(tokenizer_config)
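
    # A minimal usage sketch (not part of the original file): `load_ar_model` and
    # `load_tokenizer` populate an already-constructed AutoRegressiveModel in place,
    # as an alternative to the one-shot staticmethod `build` below. The config objects
    # are assumed to come from the cosmos_predict1 config system.
    #
    #   ar_model = AutoRegressiveModel(config=model_config, model_parallel=None)
    #   ar_model.load_ar_model(shard_checkpoint=False, tokenizer_config=tokenizer_config)
    #   ar_model.load_tokenizer(tokenizer_config)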

    @staticmethod
    def build(
        model_config: ModelConfig = ModelConfig(),
        tokenizer_config: TokenizerConfig = None,
        model_parallel: ModelParallelConfig = None,
        shard_checkpoint: bool = False,
    ) -> "AutoRegressiveModel":
        """
        Build an AutoRegressiveModel instance by initializing and loading a model checkpoint.

        Args:
            model_config (ModelConfig, optional): The model configuration for the AutoRegressiveModel instance. Defaults to ModelConfig().
            tokenizer_config (TokenizerConfig, optional): The tokenizer configuration for the AutoRegressiveModel instance. Defaults to None.
            model_parallel (ModelParallelConfig, optional): The model parallel configuration. Defaults to None.
            shard_checkpoint (bool, optional): Whether to split the checkpoint by Tensor Parallelism before loading. Defaults to False.

        Returns:
            AutoRegressiveModel: An instance of the AutoRegressiveModel class with the loaded model and tokenizer.

        Raises:
            AssertionError: If there are no checkpoint files in the specified directory.

        Note:
            This method sets the device to CUDA and loads the pre-trained model and tokenizer.
        """
        tensor_parallel_size = 1 if model_parallel is None else model_parallel.tensor_model_parallel_size
        assert tensor_parallel_size == model_config["tensor_model_parallel_size"]
        # Initialize model configuration parameters
        config_params = {}

        # Load checkpoint and model parameters
        if model_config.ckpt_path is None:
            # If ckpt_path is not provided, we assume the model checkpoint is saved in the ckpt_dir
            ckpt_dir = model_config.ckpt_dir

            # We prioritize the safetensors version over the pytorch version, since the former is
            # much faster for checkpoint loading.
            checkpoints = sorted(Path(ckpt_dir).glob("*.safetensors"))
            if len(checkpoints) == 0:
                checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
            assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
            assert (
                len(checkpoints) == 1
            ), f"multiple checkpoint files found in {ckpt_dir} (currently only one is supported)"
            ckpt_path = str(checkpoints[0])  # Assuming single checkpoint for non-parallel case

            if os.path.exists(Path(ckpt_dir) / "config.json"):
                with open(Path(ckpt_dir) / "config.json", "r") as f:
                    config_params = json.loads(f.read())
            else:
                log.info(f"No config.json found in the checkpoint directory ({ckpt_dir}). Using default model config.")
        else:
            # If ckpt_path is provided, we load the model from the specified path,
            # and use the default model configuration
            ckpt_path = model_config.ckpt_path

        for key, value in config_params.items():
            if hasattr(model_config, key):
                # Override the default model configuration with the parameters from the checkpoint
                setattr(model_config, key, value)
with misc.timer(f"loading checkpoint from {ckpt_path}"): | |
if ckpt_path.endswith("safetensors"): | |
# Load with safetensors API | |
checkpoint = load_file(ckpt_path, device="cpu") | |
else: | |
# The pytorch version | |
checkpoint = torch.load( | |
ckpt_path, | |
map_location="cpu", | |
mmap=True, # load the checkpoint in memory-mapped mode | |
weights_only=True, | |
) | |
llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint | |
if model_config.vision_encoder is not None: | |
# Take the LLM weights (starting with "model.") from the VLM checkpoint | |
llm_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="model.") | |
if model_config.vision_encoder is not None: | |
# For vanilla VLM ckpt before fine-tuning, `checkpoint['model']` only contains LLM weights, and `checkpoint['vision_encoder']` | |
# and `checkpoint['mm_projector']` are both for those weights | |
# For fine-tuned VLM ckpt, `checkpoint['model']` contains all LLM, mm_projector and vision_encoder weights | |
if "vision_encoder" in checkpoint: | |
log.debug("Using pretrained vision_encoder") | |
vit_checkpoint = checkpoint["vision_encoder"] | |
else: | |
log.debug("Using fine-tuned vision_encoder") | |
vit_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="vision_encoder.") | |
vit_checkpoint = process_state_dict(vit_checkpoint, prefix_to_remove="vision_encoder.") | |
if "mm_projector" in checkpoint: | |
log.debug("Using pretrained mm_projector") | |
projector_checkpoint = checkpoint["mm_projector"] | |
else: | |
log.debug("Using fine-tuned mm_projector") | |
projector_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="mm_projector.") | |
projector_checkpoint = process_state_dict(projector_checkpoint, prefix_to_remove="mm_projector.") | |
assert ( | |
len(vit_checkpoint) > 0 and len(projector_checkpoint) > 0 | |
), "vit_checkpoint and projector_checkpoint cannot be empty. We do not support random initialization for vision_encoder and mm_projector." | |

        tokenizer = DiscreteMultimodalTokenizer(tokenizer_config)

        orig_precision = torch.get_default_dtype()
        precision = getattr(torch, model_config.precision)
        torch.set_default_dtype(precision)
        log.debug(f"Setting torch default dtype to {precision}")

        model = Transformer(
            params=model_config,
            model_parallel=model_parallel,
            tokenizer_config=tokenizer_config,
        )
        model_kwargs = {}
        if model_config.vision_encoder is not None:
            assert model_config.mm_projector is not None, "mm_projector must be provided if vision_encoder is provided."
            vit_config = get_vit_config(model_config.vision_encoder)
            vit_config["tensor_model_parallel_size"] = tensor_parallel_size
            vision_encoder = VisionTransformer.build(
                vit_config,
            )
            mm_projector = MultimodalProjector(
                mm_projector_type=model_config.mm_projector, in_dim=vit_config["dim"], out_dim=model_config["dim"]
            )
            model_kwargs.update({"vision_encoder": vision_encoder, "mm_projector": mm_projector})

        # Perform vocab expansion
        if tokenizer.vocab_size > model.vocab_size:
            log.debug(f"Expanding vocab size to {tokenizer.vocab_size}")
            # For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer.
            expand_output_layer = not (tokenizer.training_type == "text_to_video")
            model.expand_vocab(
                tokenizer.vocab_size,
                init_method="gaussian",
                expand_output_layer=expand_output_layer,
            )

        if shard_checkpoint:
            # Shard the checkpoint according to tensor parallelism.
            with misc.timer("sharding checkpoint according to tensor parallelism"):
                if model_parallel is not None:
                    assert model_parallel.tensor_model_parallel_size == model_config["tensor_model_parallel_size"]
                llm_checkpoint = obtain_tensor_parallel_state_dict(
                    llm_checkpoint,
                    tensor_parallel_size=tensor_parallel_size,
                    tensor_parallel_rank=parallel_state.get_tensor_model_parallel_rank(),
                    model_config=model_config,
                )
                if model_config.vision_encoder is not None:
                    # Shard vision encoder and multimodal projector weights
                    vit_checkpoint = obtain_tensor_parallel_state_dict(
                        vit_checkpoint,
                        tensor_parallel_size=tensor_parallel_size,
                        tensor_parallel_rank=parallel_state.get_tensor_model_parallel_rank(),
                        model_config=vit_config,
                    )

        # Remove the "model." prefix in the state_dict
        llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
        with misc.timer("loading state_dict into model"):
            missing_keys, unexpected_keys = model.load_state_dict(llm_checkpoint, strict=True)
        # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage)
        missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")]
        assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"

        if model_config.vision_encoder is not None:
            vision_encoder.load_state_dict(vit_checkpoint)
            mm_projector.load_state_dict(projector_checkpoint)
            if model_config.vision_encoder_in_channels != 3:
                vision_encoder.expand_in_channels(model_config.vision_encoder_in_channels)

        model = model.to(precision)  # ensure model parameters are in the correct precision
        log.debug(f"Model config: {model_config}")

        model_class = AutoRegressiveModel

        torch.set_default_dtype(orig_precision)  # Reset the default dtype to the original value

        return model_class(model, tokenizer, model_config, **model_kwargs)
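
    # A minimal usage sketch (not part of the original file) of how `build` might be invoked
    # for single-GPU inference. The checkpoint directory below is a hypothetical placeholder;
    # real configs come from the cosmos_predict1 config system.
    #
    #   model_config = ModelConfig(ckpt_dir="checkpoints/ar_model")  # hypothetical path
    #   tokenizer_config = TokenizerConfig(...)                      # project-specific
    #   ar_model = AutoRegressiveModel.build(
    #       model_config=model_config,
    #       tokenizer_config=tokenizer_config,
    #   )
    #   print(f"{ar_model.get_num_params() / 1e9:.2f}B parameters")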

    def generate(
        self,
        prompt_tokens: List[List[int]] | torch.Tensor,
        max_gen_len: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        num_gen_seq: int = 1,
        logprobs: bool = False,
        echo: bool = False,
        seed: Optional[int] = None,
        context: Optional[torch.Tensor] = None,
        context_mask: Optional[torch.Tensor] = None,
        compile_sampling: bool = True,
        compile_prefill: bool = False,
        verbose: bool = True,
        stop_tokens: Optional[Set[int]] = None,
        images: Optional[torch.Tensor] = None,
    ):
        """
        Autoregressive generation built upon the gpt-fast implementation (https://github.com/pytorch-labs/gpt-fast).

        Args:
            prompt_tokens (List[List[int]] | torch.Tensor): A single prompt of shape (1, seq_len).
            max_gen_len (int): Maximum length of the generated text sequence.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 1.0.
            top_k (int, optional): Top-k value for top-k sampling. Defaults to None.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to None.
            num_gen_seq (int, optional): Number of outputs to generate given the same prompt. Defaults to 1. When temperature == 0, num_gen_seq must be 1 because the generation is deterministic.
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
            seed (int, optional): Random seed for reproducibility. Defaults to None.
            compile_sampling (bool, optional): Flag indicating whether to compile the decoding function. Defaults to True.
            compile_prefill (bool, optional): Flag indicating whether to compile the prefill function. Defaults to False.
            verbose (bool, optional): Flag indicating whether to log prefill/decode timings. Defaults to True.
        """
        assert top_k is None or top_p is None, f"Only one of top_k ({top_k}) or top_p ({top_p}) should be specified."
        if temperature == 0:
            top_p, top_k = None, None
            log.debug("Setting top_p and top_k to None because temperature is 0")
        if top_p is not None:
            log.debug(f"Using top-p sampling with p={top_p} and temperature={temperature}")
        elif top_k is not None:
            log.debug(f"Using top-k sampling with k={top_k} and temperature={temperature}")
        else:
            log.debug("Not applying top-k or top-p sampling. Will use top-k sampling with k=None")

        orig_precision = torch.get_default_dtype()
        torch.set_default_dtype(self.precision)

        torch._inductor.config.coordinate_descent_tuning = True
        torch._inductor.config.triton.unique_kernel_names = True
        # Experimental features to reduce compilation times, will be on by default in future
        torch._inductor.config.fx_graph_cache = True

        if seed is not None:
            misc.set_random_seed(seed)

        assert not logprobs, "logprobs are not supported for fast_generate yet"
        # Check whether the prefill and decode_one_token functions are compiled yet. If not, compile them based on the flags.
        if compile_sampling and not getattr(self, "inference_decode_compiled", False):
            log.info("Compiling AR sampling function. Note: the first run will be slower due to compilation")
            self.decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
            self.inference_decode_compiled = True
            log.info("Compiled AR sampling function.")
        if compile_prefill and not getattr(self, "inference_prefill_compiled", False):
            log.info("Compiling prefill function. Note: the first run will be slower due to compilation")
            self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
            self.inference_prefill_compiled = True
            log.info("Compiled prefill function.")

        if not hasattr(self, "decode_one_token"):
            self.decode_one_token = decode_one_token
        if not hasattr(self, "prefill"):
            self.prefill = prefill

        # Initialization and Assertions
        if isinstance(self.model.params, list):
            # During training, model.params is a list
            log.debug(
                f"Find self.model.params is a list, use self.config instead. Get max_batch_size={self.config.max_batch_size}, max_seq_len={self.config.max_seq_len}"
            )
            params = self.config
        else:
            params = self.model.params
        if isinstance(prompt_tokens, list):
            prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cuda")
        if prompt_tokens.ndim == 1:
            prompt_tokens = prompt_tokens.view(1, -1)
        else:
            assert prompt_tokens.ndim == 2, f"prompt_tokens has shape {prompt_tokens.shape}"
        batch_size, prompt_len = prompt_tokens.shape
        total_len = min(params.max_seq_len, max_gen_len + prompt_len)
        if max_gen_len + prompt_len > params.max_seq_len:
            log.warning(
                f"max_gen_len + prompt_len={max_gen_len + prompt_len} exceeds max_seq_len={params.max_seq_len}, truncate max_gen_len to {params.max_seq_len - prompt_len}"
            )
            max_gen_len = params.max_seq_len - prompt_len

        if context_mask is not None:
            context_mask = context_mask.to(dtype=torch.bool)
            if context_mask.ndim == 2:
                assert (
                    context_mask.shape[0] == batch_size
                ), f"batch_size mismatch: {context_mask.shape[0]} != {batch_size}"
                # Unsqueeze it to make it of shape [batch_size, 1, 1, context_seq_len]
                context_mask = context_mask.view(batch_size, 1, 1, -1)

        if num_gen_seq > 1:
            assert (
                batch_size == 1
            ), f"num_gen_seq > 1 is only supported for a single prompt, got {len(prompt_tokens)} prompts"
            log.debug(f"Generating {num_gen_seq} sequences with the same prompt")
            assert (
                num_gen_seq <= params.max_batch_size
            ), f"num_gen_seq={num_gen_seq} exceeds max_batch_size={params.max_batch_size}"
            # repeat the prompt tokens for num_gen_seq times
            prompt_tokens = prompt_tokens.repeat(num_gen_seq, 1)
            assert prompt_tokens.shape == (
                num_gen_seq,
                prompt_len,
            ), f"prompt_tokens must be of shape (num_gen_seq, seq_len), got {prompt_tokens.shape}"
            batch_size = len(prompt_tokens)

        # create an empty tensor of the expected final shape and fill in the current tokens
        empty = torch.empty(batch_size, total_len, dtype=prompt_tokens.dtype, device=prompt_tokens.device)
        empty[:, :prompt_len] = prompt_tokens
        seq = empty
        input_pos = torch.arange(0, prompt_len, device="cuda")

        if verbose:
            prefill_start = time.time()

        if images is not None:
            images = images.to(device=prompt_tokens.device, dtype=torch.bfloat16)
            prompt_token_embeddings = self.embed_vision_language_features(prompt_tokens, images)
        else:
            prompt_token_embeddings = None

        if context is not None:
            context = context.to(device=prompt_tokens.device, dtype=self.precision)

        # Prefill stage
        next_token = self.prefill(
            self.model,
            input_pos=input_pos,
            tokens=prompt_tokens if prompt_token_embeddings is None else None,
            token_embeddings=prompt_token_embeddings,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            context=context,
            context_mask=context_mask,
        )
        if verbose:
            prefill_time = time.time() - prefill_start

        seq[:, [prompt_len]] = next_token.to(dtype=seq.dtype)
        input_pos = torch.tensor([prompt_len], dtype=torch.long, device="cuda")
        stop_tokens = self.tokenizer.stop_tokens if stop_tokens is None else stop_tokens
        stop_tokens = torch.tensor(list(stop_tokens), dtype=torch.long, device="cuda")

        if verbose:
            decode_start = time.time()
        # Decode stage
        generated_tokens = decode_n_tokens(
            self.model,
            next_token.view(batch_size, -1),
            input_pos,
            max_gen_len - 1,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            stop_tokens=stop_tokens,
            decode_one_token_function=self.decode_one_token,
            context=context,
            context_mask=context_mask,
        )
        gen_len = len(generated_tokens)
        if verbose:
            decode_time = time.time() - decode_start
            prefill_throughput = prompt_len / prefill_time
            decode_throughput = gen_len / decode_time
            log.debug(f"[Prefill] Time: {prefill_time:.2f}s; Throughput: {prefill_throughput:.2f} tokens/s")
            log.debug(f"[Decode] Time: {decode_time:.2f}s; Throughput: {decode_throughput:.2f} tokens/s")

        generated_tokens = torch.cat(generated_tokens, dim=1)

        log.debug(f"generated_tokens: {generated_tokens.shape}")
        seq = seq[:, : prompt_len + 1 + gen_len]
        seq[:, prompt_len + 1 :] = generated_tokens
        if not echo:
            seq = seq[:, prompt_len:]

        torch.set_default_dtype(orig_precision)  # Reset the default dtype to the original value

        return seq, None
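
    # A minimal usage sketch (not part of the original file) of how `generate` might be called
    # once a model has been built. The prompt tokens and lengths below are illustrative only.
    #
    #   prompt_tokens = torch.tensor([[1, 2, 3, 4]], dtype=torch.long, device="cuda")
    #   tokens, _ = ar_model.generate(
    #       prompt_tokens=prompt_tokens,
    #       max_gen_len=128,
    #       temperature=1.0,
    #       top_p=0.9,
    #       seed=0,
    #   )
    #   # `tokens` holds the generated token ids; pass `echo=True` to also include the
    #   # prompt tokens at the start of each returned sequence.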

    def embed_vision_language_features(self, input_ids: torch.Tensor, images: torch.Tensor) -> torch.Tensor:
        """
        Embed vision and language features into a combined representation.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            images (torch.Tensor): Input images.

        Returns:
            torch.Tensor: Combined vision-language features.

        Raises:
            AssertionError: If vision encoder or mm projector is not initialized,
                or if dimensions mismatch.
        """
        # Ensure vision encoder and mm projector are initialized
        assert self.vision_encoder is not None
        assert self.mm_projector is not None

        # Get image token ID and validate it
        image_token_id = self.vision_encoder.image_token_id
        assert isinstance(image_token_id, int) and image_token_id >= 0, f"Invalid image_token_id: {image_token_id}"

        # Identify text and image locations in the input
        text_locations = input_ids != image_token_id
        image_locations = input_ids == image_token_id

        # Process text features
        text_features = self.model.tok_embeddings(input_ids[text_locations])

        # Process image features
        images = images.to(device=text_features.device, dtype=text_features.dtype)
        vit_outputs = self.vision_encoder(images)
        image_features = self.mm_projector(vit_outputs)

        # Get dimensions
        B, seq_len = input_ids.shape
        N_total = B * seq_len
        N_txt, D_txt = text_features.shape
        N_img, N_patch, D_img = image_features.shape

        # Reshape image features
        image_features = image_features.reshape(N_img * N_patch, D_img)

        # Validate dimensions
        assert D_txt == D_img, f"Text features dim {D_txt} should be equal to image features dim {D_img}"
        assert (
            N_total == N_txt + N_img * N_patch
        ), f"N_total {N_total} should be equal to N_txt + N_img * N_patch, got {(N_txt, N_img * N_patch, image_locations.sum().item())}"

        # Combine text and image features
        combined_features = torch.empty(
            (B, seq_len, D_txt),
            dtype=text_features.dtype,
            device=text_features.device,
        )
        combined_features[text_locations, :] = text_features
        combined_features[image_locations, :] = image_features

        return combined_features
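
    # Shape sketch (illustrative numbers, not from the repository): with B=1, seq_len=600,
    # one image occupying N_patch=256 image-token slots, and 344 text tokens,
    # N_total = 1 * 600 = 600 = N_txt + N_img * N_patch = 344 + 1 * 256, so the text and
    # image embeddings exactly fill the (B, seq_len, D) buffer assembled above.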

    def state_dict(self, *args, **kwargs):
        """
        Process the state dict (e.g., remove "_extra_state" keys imposed by TransformerEngine for FP8).
        """
        state_dict = super().state_dict(*args, **kwargs)
        return process_state_dict(state_dict)

    def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False):
        """
        Ignore the missing keys with substrings matching `substrings_to_ignore` (e.g., "_extra_state" keys imposed by
        TransformerEngine for FP8).
        """
        state_dict = process_state_dict(state_dict)
        missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False, assign=assign)
        actual_missing_keys = []
        for key in missing_keys:
            if not any(substring in key for substring in substrings_to_ignore):
                actual_missing_keys.append(key)
        if strict:
            if len(actual_missing_keys) > 0 or len(unexpected_keys) > 0:
                raise ValueError(f"Missing keys: {actual_missing_keys}\n\nUnexpected keys: {unexpected_keys}")
        return _IncompatibleKeys(actual_missing_keys, unexpected_keys)