# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json import os import time from pathlib import Path from typing import Any, Dict, List, Optional, Set import torch from megatron.core import ModelParallelConfig, parallel_state from safetensors.torch import load_file from torch.nn.modules.module import _IncompatibleKeys from cosmos_predict1.autoregressive.configs.base.model import ModelConfig from cosmos_predict1.autoregressive.configs.base.tokenizer import TokenizerConfig from cosmos_predict1.autoregressive.modules.mm_projector import MultimodalProjector from cosmos_predict1.autoregressive.networks.transformer import Transformer from cosmos_predict1.autoregressive.networks.vit import VisionTransformer, get_vit_config from cosmos_predict1.autoregressive.tokenizer.tokenizer import DiscreteMultimodalTokenizer, update_vocab_size from cosmos_predict1.autoregressive.utils.checkpoint import ( get_partial_state_dict, obtain_tensor_parallel_state_dict, process_state_dict, substrings_to_ignore, ) from cosmos_predict1.autoregressive.utils.sampling import decode_n_tokens, decode_one_token, prefill from cosmos_predict1.utils import log, misc def update_model_config(model_config, inference_tensor_parallel_size): if inference_tensor_parallel_size > 1: log.warning(f"Setting tensor parallel size to {inference_tensor_parallel_size}") setattr( model_config, "tensor_model_parallel_size", inference_tensor_parallel_size, ) if "{rank}" in model_config.ckpt_path: tp_rank = parallel_state.get_tensor_model_parallel_rank() model_config.ckpt_path = model_config.ckpt_path.format(rank=tp_rank) return model_config class AutoRegressiveModel(torch.nn.Module): """ A class to build and use a AutoRegressiveModel model for text generation. Methods: build: Build a AutoRegressiveModel instance by initializing and loading a model checkpoint. generate: Generate text sequences based on provided prompts using the language generation model. """ def __init__( self, model: Transformer = None, tokenizer: DiscreteMultimodalTokenizer = None, config: ModelConfig = None, model_parallel: ModelParallelConfig = None, vision_encoder: VisionTransformer = None, mm_projector: MultimodalProjector = None, ): """ Initialize the AutoRegressiveModel instance with a model and tokenizer. Args: model (Transformer): The Transformer model for text generation. tokenizer (Tokenizer): The tokenizer for encoding and decoding text. config (Config): The configuration for the AutoRegressiveModel model. model_parallel (ModelParallelConfig): The model parallel configuration for the AutoRegressiveModel model. vision_encoder (VisionTransformer): The vision encoder for the AutoRegressiveModel model. mm_projector (MultimodalProjector): The multi-modal projector for the AutoRegressiveModel model. """ super().__init__() self.model = model self.tokenizer = tokenizer self.config = config self.vision_encoder = vision_encoder self.mm_projector = mm_projector self.model_parallel = model_parallel @property def precision(self): return self.model.precision def get_num_params( self, ) -> int: """ Return the number of parameters in the model. """ n_params = sum(p.numel() for p in self.parameters()) return n_params def load_ar_model( self, shard_checkpoint, tokenizer_config, ): """ Load the AR model. """ model_config = self.config tensor_parallel_size = 1 if self.model_parallel is None else self.model_parallel.tensor_model_parallel_size assert tensor_parallel_size == model_config["tensor_model_parallel_size"] ckpt_path = model_config.ckpt_path with misc.timer(f"loading checkpoint from {ckpt_path}"): if ckpt_path.endswith("safetensors"): # Load with safetensors API checkpoint = load_file(ckpt_path, device="cpu") else: # The pytorch version checkpoint = torch.load( ckpt_path, map_location="cpu", mmap=True, # load the checkpoint in memory-mapped mode weights_only=True, ) llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint orig_precision = torch.get_default_dtype() precision = getattr(torch, model_config.precision) torch.set_default_dtype(precision) log.debug(f"Setting torch default dtype to {precision}") model = Transformer( params=model_config, model_parallel=self.model_parallel, tokenizer_config=tokenizer_config, ) log.debug( f"tokenizer tokenizer_config.video_tokenizer.vocab_size {tokenizer_config.video_tokenizer.vocab_size}" ) vocab_size = update_vocab_size( existing_vocab_size=0, to_be_added_vocab_size=tokenizer_config.video_tokenizer.vocab_size, training_type=tokenizer_config.training_type, add_special_tokens=False, ) log.debug( f"tokenizer tokenizer_config.video_tokenizer.vocab_size {tokenizer_config.video_tokenizer.vocab_size} vocab_size {vocab_size}" ) # Perform vocab expansion if vocab_size > model.vocab_size: log.debug(f"Expanding vocab size to {vocab_size}") # For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer, expand_output_layer = not (tokenizer_config.training_type == "text_to_video") model.expand_vocab( vocab_size, init_method="gaussian", expand_output_layer=expand_output_layer, ) if shard_checkpoint: # Shard the checkpoint according to tensor parallelism. with misc.timer("sharding checkpoint according to tensor parallelism"): if self.model_parallel is not None: assert self.model_parallel.tensor_model_parallel_size == model_config["tensor_model_parallel_size"] llm_checkpoint = obtain_tensor_parallel_state_dict( llm_checkpoint, tensor_parallel_size=tensor_parallel_size, tensor_parallel_rank=parallel_state.get_tensor_model_parallel_rank(), model_config=model_config, ) # Remove the "model." prefix in the state_dict llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.") with misc.timer("loading state_dict into model"): missing_keys, _ = model.load_state_dict(llm_checkpoint, strict=True) # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage) missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")] assert len(missing_keys) == 0, f"Missing keys: {missing_keys}" self.model = model.to(precision).to("cuda") torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value def load_tokenizer(self, tokenizer_config): """ Load the tokenizer. """ self.tokenizer = DiscreteMultimodalTokenizer(tokenizer_config) @staticmethod def build( model_config: ModelConfig = ModelConfig(), tokenizer_config: TokenizerConfig = None, model_parallel: ModelParallelConfig = None, shard_checkpoint: bool = False, ) -> "AutoRegressiveModel": """ Build a AutoRegressiveModel instance by initializing and loading a model checkpoint. Args: model_config (ModelConfig, optional): The model configuration for the AutoRegressiveModel instance. Defaults to ModelConfig(). tokenizer_config (TokenizerConfig, optional): The tokenizer configuration for the AutoRegressiveModel instance. Defaults to None. shard_checkpoint (bool, optional): Whether to split the checkpoint by Tensor Parallelism before loading. Defaults to False. download_rank_sync (bool, optional): Whether to download the checkpoint in a rank-synchronized manner. Defaults to True. Returns: AutoRegressiveModel: An instance of the AutoRegressiveModel class with the loaded model and tokenizer. Raises: AssertionError: If there are no checkpoint files in the specified directory. Note: This method sets the device to CUDA and loads the pre-trained model and tokenizer. """ tensor_parallel_size = 1 if model_parallel is None else model_parallel.tensor_model_parallel_size assert tensor_parallel_size == model_config["tensor_model_parallel_size"] # Initialize model configuration parameters config_params = {} # Load checkpoint and model parameters if model_config.ckpt_path is None: # If ckpt_path is not provided, we assume the model checkpoint is saved in the ckpt_dir ckpt_dir = model_config.ckpt_dir # We prioritize safetensors version over the pytorch version, since the former is # much faster for checkpoint loading. checkpoints = sorted(Path(ckpt_dir).glob("*.safetensors")) if len(checkpoints) == 0: checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" assert ( len(checkpoints) == 1 ), f"multiple checkpoint files found in {ckpt_dir} (currently only one is supported)" ckpt_path = str(checkpoints[0]) # Assuming single checkpoint for non-parallel case if os.path.exists(Path(ckpt_dir) / "config.json"): with open(Path(ckpt_dir) / "config.json", "r") as f: config_params = json.loads(f.read()) else: log.info( f"No params.json found in the checkpoint directory ({ckpt_dir}). " f"Using default model config." ) else: # If ckpt_path is provided, we load the model from the specified path, # and use the default model configuration ckpt_path = model_config.ckpt_path for key, value in config_params.items(): if hasattr(model_config, key): # Override the default model configuration with the parameters from the checkpoint setattr(model_config, key, value) with misc.timer(f"loading checkpoint from {ckpt_path}"): if ckpt_path.endswith("safetensors"): # Load with safetensors API checkpoint = load_file(ckpt_path, device="cpu") else: # The pytorch version checkpoint = torch.load( ckpt_path, map_location="cpu", mmap=True, # load the checkpoint in memory-mapped mode weights_only=True, ) llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint if model_config.vision_encoder is not None: # Take the LLM weights (starting with "model.") from the VLM checkpoint llm_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="model.") if model_config.vision_encoder is not None: # For vanilla VLM ckpt before fine-tuning, `checkpoint['model']` only contains LLM weights, and `checkpoint['vision_encoder']` # and `checkpoint['mm_projector']` are both for those weights # For fine-tuned VLM ckpt, `checkpoint['model']` contains all LLM, mm_projector and vision_encoder weights if "vision_encoder" in checkpoint: log.debug("Using pretrained vision_encoder") vit_checkpoint = checkpoint["vision_encoder"] else: log.debug("Using fine-tuned vision_encoder") vit_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="vision_encoder.") vit_checkpoint = process_state_dict(vit_checkpoint, prefix_to_remove="vision_encoder.") if "mm_projector" in checkpoint: log.debug("Using pretrained mm_projector") projector_checkpoint = checkpoint["mm_projector"] else: log.debug("Using fine-tuned mm_projector") projector_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="mm_projector.") projector_checkpoint = process_state_dict(projector_checkpoint, prefix_to_remove="mm_projector.") assert ( len(vit_checkpoint) > 0 and len(projector_checkpoint) > 0 ), "vit_checkpoint and projector_checkpoint cannot be empty. We do not support random initialization for vision_encoder and mm_projector." tokenizer = DiscreteMultimodalTokenizer(tokenizer_config) orig_precision = torch.get_default_dtype() precision = getattr(torch, model_config.precision) torch.set_default_dtype(precision) log.debug(f"Setting torch default dtype to {precision}") model = Transformer( params=model_config, model_parallel=model_parallel, tokenizer_config=tokenizer_config, ) model_kwargs = {} if model_config.vision_encoder is not None: assert model_config.mm_projector is not None, "mm_projector must be provided if vision_encoder is provided." vit_config = get_vit_config(model_config.vision_encoder) vit_config["tensor_model_parallel_size"] = tensor_parallel_size vision_encoder = VisionTransformer.build( vit_config, ) mm_projector = MultimodalProjector( mm_projector_type=model_config.mm_projector, in_dim=vit_config["dim"], out_dim=model_config["dim"] ) model_kwargs.update({"vision_encoder": vision_encoder, "mm_projector": mm_projector}) # Perform vocab expansion if tokenizer.vocab_size > model.vocab_size: log.debug(f"Expanding vocab size to {tokenizer.vocab_size}") # For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer, expand_output_layer = not (tokenizer.training_type == "text_to_video") model.expand_vocab( tokenizer.vocab_size, init_method="gaussian", expand_output_layer=expand_output_layer, ) if shard_checkpoint: # Shard the checkpoint according to tensor parallelism. with misc.timer("sharding checkpoint according to tensor parallelism"): if model_parallel is not None: assert model_parallel.tensor_model_parallel_size == model_config["tensor_model_parallel_size"] llm_checkpoint = obtain_tensor_parallel_state_dict( llm_checkpoint, tensor_parallel_size=tensor_parallel_size, tensor_parallel_rank=parallel_state.get_tensor_model_parallel_rank(), model_config=model_config, ) if model_config.vision_encoder is not None: # Shard vision encoder and multimodal projector weights vit_checkpoint = obtain_tensor_parallel_state_dict( vit_checkpoint, tensor_parallel_size=tensor_parallel_size, tensor_parallel_rank=parallel_state.get_tensor_model_parallel_rank(), model_config=vit_config, ) # Remove the "model." prefix in the state_dict llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.") with misc.timer("loading state_dict into model"): missing_keys, unexpected_keys = model.load_state_dict(llm_checkpoint, strict=True) # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage) missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")] assert len(missing_keys) == 0, f"Missing keys: {missing_keys}" if model_config.vision_encoder is not None: vision_encoder.load_state_dict(vit_checkpoint) mm_projector.load_state_dict(projector_checkpoint) if model_config.vision_encoder_in_channels != 3: vision_encoder.expand_in_channels(model_config.vision_encoder_in_channels) model = model.to(precision) # ensure model parameters are in the correct precision log.debug(f"Model config: {model_config}") model_class = AutoRegressiveModel torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value return model_class(model, tokenizer, model_config, **model_kwargs) @torch.no_grad() def generate( self, prompt_tokens: List[List[int]] | torch.Tensor, max_gen_len: int, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None, num_gen_seq: int = 1, logprobs: bool = False, echo: bool = False, seed: int = None, context: Optional[torch.Tensor] = None, context_mask: Optional[torch.Tensor] = None, compile_sampling: bool = True, compile_prefill: bool = False, verbose: bool = True, stop_tokens: Optional[Set[int]] = None, images: Optional[torch.Tensor] = None, ): """ Autoregressive generation built upon the gpt-fast implementation (https://github.com/pytorch-labs/gpt-fast). Args: prompt_tokens (List[List[int]] | torch.Tensor): A single prompt of shape (1, seq_len). max_gen_len (int): Maximum length of the generated text sequence. temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. top_k (int, optional): Top-k value for top-k sampling. Defaults to None. top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to None. num_gen_seq (int, optional): Number of outputs to generate given the same prompt. Defaults to 1. When temperature == 0, num_gen_seq must be 1 because the generation is deterministic. echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. logit_clipping_range (list, optional): Range of logits to clip. Defaults to []. seed (int, optional): Random seed for reproducibility. Defaults to None. compile_sampling (bool, optional): Flag indicating whether to compile the decoding function. Defaults to True. compile_prefill (bool, optional): Flag indicating whether to compile the prefill function. Defaults to False. verbose (bool, optional): Flag indicating whether to print the the time. Defaults to False. """ assert top_k is None or top_p is None, f"Only one of top_k ({top_k} or top_p ({top_p} should be specified." if temperature == 0: top_p, top_k = None, None log.debug("Setting top_p and top_k to None because temperature is 0") if top_p is not None: log.debug(f"Using top-p sampling with p={top_p} and temperature={temperature}") elif top_k is not None: log.debug(f"Using top-k sampling with k={top_k} and temperature={temperature}") else: log.debug("Not applying top-k or top-p sampling. Will use top-k sampling with k=None") orig_precision = torch.get_default_dtype() torch.set_default_dtype(self.precision) torch._inductor.config.coordinate_descent_tuning = True torch._inductor.config.triton.unique_kernel_names = True # Experimental features to reduce compilation times, will be on by default in future torch._inductor.config.fx_graph_cache = True if seed is not None: misc.set_random_seed(seed) assert not logprobs, "logprobs are not supported for fast_generate yet" # Examine if the function prefil and decode_one_token functions are compiled yet. If not, compile them based on the flags if compile_sampling and not getattr(self, "inference_decode_compiled", False): log.info("Compiling AR sampling function. Note: the first run will be slower due to compilation") self.decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) self.inference_decode_compiled = True log.info("Compiled AR sampling function.") if compile_prefill and not getattr(self, "inference_prefill_compiled", False): log.info("Compiling prefill function. Note: the first run will be slower due to compilation") self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True) self.inference_prefill_compiled = True log.info("Compiled prefill function.") if not hasattr(self, "decode_one_token"): self.decode_one_token = decode_one_token if not hasattr(self, "prefill"): self.prefill = prefill # Initialization and Assertions if isinstance(self.model.params, list): # During training, model.params is a list log.debug( f"Find self.model.params is a list, use self.config instead. Get max_batch_size={self.config.max_batch_size}, max_seq_len={self.config.max_seq_len}" ) params = self.config else: params = self.model.params if isinstance(prompt_tokens, list): prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cuda") if prompt_tokens.ndim == 1: prompt_tokens = prompt_tokens.view(1, -1) else: assert prompt_tokens.ndim == 2, f"prompt_tokens has shape {prompt_tokens.shape}" batch_size, prompt_len = prompt_tokens.shape total_len = min(params.max_seq_len, max_gen_len + prompt_len) if max_gen_len + prompt_len > params.max_seq_len: log.warning( f"max_gen_len + prompt_len={max_gen_len + prompt_len} exceeds max_seq_len={params.max_seq_len}, truncate max_gen_len to {params.max_seq_len - prompt_len}" ) max_gen_len = params.max_seq_len - prompt_len if context_mask is not None: context_mask = context_mask.to(dtype=torch.bool) if context_mask.ndim == 2: assert ( context_mask.shape[0] == batch_size ), f"batch_size mismatch: {context_mask.shape[0]} != {batch_size}" # Unsqueeze it to make it of shape [batch_size, 1, 1, context_seq_len] context_mask = context_mask.view(batch_size, 1, 1, -1) if num_gen_seq > 1: assert ( batch_size == 1 ), f"num_gen_seq > 1 is only supported for a single prompt, got {len(prompt_tokens)} prompts" log.debug(f"Generating {num_gen_seq} sequences with the same prompt") assert ( num_gen_seq <= params.max_batch_size ), f"num_gen_seq={num_gen_seq} exceeds max_batch_size={params.max_batch_size}" # repeat the prompt tokens for num_gen_seq times prompt_tokens = prompt_tokens.repeat(num_gen_seq, 1) assert prompt_tokens.shape == ( num_gen_seq, prompt_len, ), f"prompt_tokens must be of shape (num_gen_seq, seq_len), got {prompt_tokens.shape}" batch_size = len(prompt_tokens) # create an empty tensor of the expected final shape and fill in the current tokens empty = torch.empty(batch_size, total_len, dtype=prompt_tokens.dtype, device=prompt_tokens.device) empty[:, :prompt_len] = prompt_tokens seq = empty input_pos = torch.arange(0, prompt_len, device="cuda") if verbose: prefill_start = time.time() if images is not None: images = images.to(device=prompt_tokens.device, dtype=torch.bfloat16) prompt_token_embeddings = self.embed_vision_language_features(prompt_tokens, images) else: prompt_token_embeddings = None if context is not None: context = context.to(device=prompt_tokens.device, dtype=self.precision) # Prefill stage next_token = self.prefill( self.model, input_pos=input_pos, tokens=prompt_tokens if prompt_token_embeddings is None else None, token_embeddings=prompt_token_embeddings, temperature=temperature, top_k=top_k, top_p=top_p, context=context, context_mask=context_mask, ) if verbose: prefill_time = time.time() - prefill_start seq[:, [prompt_len]] = next_token.to(dtype=seq.dtype) input_pos = torch.tensor([prompt_len], dtype=torch.long, device="cuda") stop_tokens = self.tokenizer.stop_tokens if stop_tokens is None else stop_tokens stop_tokens = torch.tensor(list(stop_tokens), dtype=torch.long, device="cuda") if verbose: decode_start = time.time() # Decode stage generated_tokens = decode_n_tokens( self.model, next_token.view(batch_size, -1), input_pos, max_gen_len - 1, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=stop_tokens, decode_one_token_function=self.decode_one_token, context=context, context_mask=context_mask, ) gen_len = len(generated_tokens) if verbose: decode_time = time.time() - decode_start prefill_throughput = prompt_len / prefill_time decode_throughput = gen_len / decode_time log.debug(f"[Prefill] Time: {prefill_time:.2f}s; Throughput: {prefill_throughput:.2f} tokens/s") log.debug(f"[Decode] Time: {decode_time:.2f}s; Throughput: {decode_throughput:.2f} tokens/s") generated_tokens = torch.cat(generated_tokens, dim=1) log.debug(f"generated_tokens: {generated_tokens.shape}") seq = seq[:, : prompt_len + 1 + gen_len] seq[:, prompt_len + 1 :] = generated_tokens if not echo: seq = seq[:, prompt_len:] torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value return seq, None def embed_vision_language_features(self, input_ids: torch.Tensor, images: torch.tensor) -> torch.Tensor: """ Embed vision and language features into a combined representation. Args: input_ids (torch.Tensor): Input token IDs. images (torch.tensor): Input images. Returns: torch.Tensor: Combined vision-language features. Raises: AssertionError: If vision encoder or mm projector is not initialized, or if dimensions mismatch. """ # Ensure vision encoder and mm projector are initialized assert self.vision_encoder is not None assert self.mm_projector is not None # Get image token ID and validate it image_token_id = self.vision_encoder.image_token_id assert isinstance(image_token_id, int) and image_token_id >= 0, f"Invalid image_token_id: {image_token_id}" # Identify text and image locations in the input text_locations = input_ids != image_token_id image_locations = input_ids == image_token_id # Process text features text_features = self.model.tok_embeddings(input_ids[text_locations]) # Process image features images = images.to(device=text_features.device, dtype=text_features.dtype) vit_outputs = self.vision_encoder(images) image_features = self.mm_projector(vit_outputs) # Get dimensions B, seq_len = input_ids.shape N_total = B * seq_len N_txt, D_txt = text_features.shape N_img, N_patch, D_img = image_features.shape # Reshape image features image_features = image_features.reshape(N_img * N_patch, D_img) # Validate dimensions assert D_txt == D_img, f"Text features dim {D_txt} should be equal to image features dim {D_img}" assert ( N_total == N_txt + N_img * N_patch ), f"seq_len {seq_len} should be equal to N_txt + N_img*N_Patch {(N_txt, N_img * N_patch, image_locations.sum().item())}" # Combine text and image features combined_features = torch.empty( (B, seq_len, D_txt), dtype=text_features.dtype, device=text_features.device, ) combined_features[text_locations, :] = text_features combined_features[image_locations, :] = image_features return combined_features def state_dict(self, *args, **kwargs): """ Process the state dict (e.g., remove "_extra_state" keys imposed by TransformerEngine for FP8). """ state_dict = super().state_dict(*args, **kwargs) return process_state_dict(state_dict) def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False): """ Ignore the missing keys with substrings matching `substring_to_ignore` (e.g., "_extra_state" keys imposed by TransformerEngine for FP8). """ state_dict = process_state_dict(state_dict) missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False, assign=assign) actual_missing_keys = [] for key in missing_keys: if not any(substring in key for substring in substrings_to_ignore): actual_missing_keys.append(key) if strict: if len(actual_missing_keys) > 0 or len(unexpected_keys) > 0: raise ValueError(f"Missing keys: {actual_missing_keys}\n\nUnexpected keys: {unexpected_keys}") return _IncompatibleKeys(actual_missing_keys, unexpected_keys)