# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional

import torch

from cosmos_predict1.utils import log

# Substrings to ignore when processing state dicts
substrings_to_ignore = [
    "_extra_state",  # Extra states (BytesIO type) added by TransformerEngine for FP8 handling
]
def identify_checkpoint_backend(state_dict: dict[str, torch.Tensor]) -> str:
    """
    Identify the backend of the checkpoint (PyTorch or TransformerEngine).

    Args:
        state_dict (dict[str, torch.Tensor]): The state dict to check

    Returns:
        str: The backend of the checkpoint
    """
    for key in state_dict.keys():
        if "self_attention.layernorm_qkv.query_weight" in key:
            return "transformer_engine"
        elif "attention.wq.weight" in key:
            return "pytorch"
    raise ValueError("Could not identify the backend of the checkpoint")
def get_partial_state_dict(
    state_dict: dict[str, torch.Tensor],
    prefix: str,
) -> dict[str, torch.Tensor]:
    """
    Get a partial state dict with keys starting with the given prefix.
    """
    return {k: v for k, v in state_dict.items() if k.startswith(prefix)}
def process_state_dict(
    state_dict: dict[str, torch.Tensor],
    device: Optional[str] = None,
    dtype: Optional[torch.dtype] = None,
    prefix_to_remove: Optional[str] = None,
) -> dict[str, torch.Tensor]:
    """
    - Remove items whose keys contain the substring "_extra_state" (added by TransformerEngine for FP8)
    - Move tensors to the specified device and dtype if provided

    Args:
        state_dict (dict[str, torch.Tensor]): The state dict to process
        device (str, optional): The device to move tensors to. Defaults to None.
        dtype (torch.dtype, optional): The dtype to move tensors to. Defaults to None.
        prefix_to_remove (str, optional): The prefix to remove from the keys of the state dict. Defaults to None.

    Returns:
        dict[str, torch.Tensor]: The processed state dict
    """
    new_state_dict = {}
    tensor_kwargs = {}
    if device is not None:
        tensor_kwargs["device"] = device
    if dtype is not None:
        tensor_kwargs["dtype"] = dtype
    for key, value in state_dict.items():
        # Check if any of the substrings to ignore are in the key
        skip = False
        for substr in substrings_to_ignore:
            if substr in key:
                skip = True
                break
        if skip:
            continue
        if len(tensor_kwargs) > 0:
            value = value.to(**tensor_kwargs)
        if prefix_to_remove is not None and key.startswith(prefix_to_remove):
            key = key[len(prefix_to_remove) :]
        new_state_dict[key] = value
    return new_state_dict
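# A minimal usage sketch for process_state_dict above (the checkpoint path "model.pt" and the
# "model." key prefix are illustrative assumptions, not part of this module):
#
#     raw_state_dict = torch.load("model.pt", map_location="cpu")
#     clean_state_dict = process_state_dict(
#         raw_state_dict, device="cuda", dtype=torch.bfloat16, prefix_to_remove="model."
#     )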
def obtain_tensor_parallel_state_dict(
    whole_model_state_dict: dict[str, torch.Tensor],
    tensor_parallel_size: int,
    tensor_parallel_rank: int,
    model_config,
    target_backend: Optional[str] = None,
) -> dict[str, torch.Tensor]:
    """
    Obtain the tensor parallel state dict shard for the current rank.

    Args:
        whole_model_state_dict (dict[str, torch.Tensor]): The complete model state dict.
        tensor_parallel_size (int): The number of tensor parallel devices.
        tensor_parallel_rank (int): The rank of the current tensor parallel device.
        model_config: The model configuration.
        target_backend (str, optional): The target backend format ('pytorch', 'transformer_engine', or 'huggingface').
            If not specified, the source backend will be used.

    Returns:
        dict[str, torch.Tensor]: The state dict shard for the current tensor parallel rank.
    """
    new_state_dict_shard = {}
    whole_model_state_dict = process_state_dict(whole_model_state_dict)
    source_backend = identify_checkpoint_backend(whole_model_state_dict)
    if source_backend != "pytorch":
        # Convert the checkpoint to the PyTorch backend for checkpoint sharding
        whole_model_state_dict = maybe_convert_checkpoint_to_backend(
            whole_model_state_dict, target_backend="pytorch", model_config=model_config, source_backend=source_backend
        )
    n_heads = model_config["n_heads"]
    n_kv_heads = model_config["n_kv_heads"]
    dim = model_config["dim"]
    context_dim = model_config["context_dim"]
    for key, value in whole_model_state_dict.items():
        prefix = "model." if key.startswith("model.") else ""  # LLM's model prefix
        prefix = "transformer." if key.startswith("transformer.") else prefix  # ViT's model prefix
        key = key.replace(prefix, "")
        if key.startswith("layers."):
            layer_index = int(key.split("layers.")[1].split(".")[0])
            if layer_index >= model_config["n_layers"]:
                log.warning(
                    f"Layer index {layer_index} is greater than or equal to the number of layers {model_config['n_layers']}. Skipping this layer."
                )
                continue
            if ".attention.wq.weight" in key or "cross_attention.wq.weight" in key:
                value = torch.chunk(value.view(n_heads, -1, dim), tensor_parallel_size, dim=0)[tensor_parallel_rank]
                value = value.reshape(-1, dim)
            elif ".attention.wk.weight" in key or ".attention.wv.weight" in key:
                value = torch.chunk(value.view(n_kv_heads, -1, dim), tensor_parallel_size, dim=0)[tensor_parallel_rank]
                value = value.reshape(-1, dim)
            elif "cross_attention.wk.weight" in key or "cross_attention.wv.weight" in key:
                assert context_dim is not None
                value = torch.chunk(value.view(n_kv_heads, -1, context_dim), tensor_parallel_size, dim=0)[
                    tensor_parallel_rank
                ]
                value = value.reshape(-1, context_dim)
            elif "feed_forward.w1.weight" in key or "feed_forward.w3.weight" in key or "medusa_head" in key:
                value = torch.chunk(value, tensor_parallel_size, dim=0)[tensor_parallel_rank]
            elif "feed_forward.w2.weight" in key or ".attention.wo.weight" in key or "cross_attention.wo.weight" in key:
                value = torch.chunk(value, tensor_parallel_size, dim=1)[tensor_parallel_rank]
        else:
            # Handle non-layer weights
            if key == "tok_embeddings.weight" or key == "output.weight" or "medusa_head" in key:
                value = torch.chunk(value, tensor_parallel_size, dim=0)[tensor_parallel_rank]
        new_state_dict_shard[prefix + key] = value
    if target_backend is None:
        target_backend = source_backend
    new_state_dict_shard = maybe_convert_checkpoint_to_backend(
        new_state_dict_shard,
        target_backend=target_backend,
        model_config=model_config,
        is_tensor_parallel_shard=True,
        tensor_parallel_size=tensor_parallel_size,
    )
    return new_state_dict_shard
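# A minimal sharding sketch for obtain_tensor_parallel_state_dict above (the checkpoint path
# "consolidated.pth", the 2-way tensor parallel size, and the config "cfg" are illustrative
# assumptions):
#
#     full_sd = torch.load("consolidated.pth", map_location="cpu")
#     shards = [
#         obtain_tensor_parallel_state_dict(
#             full_sd, tensor_parallel_size=2, tensor_parallel_rank=rank, model_config=cfg
#         )
#         for rank in range(2)
#     ]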
def merge_tensor_parallel_state_dicts(
    state_dict_shards: list[dict[str, torch.Tensor]],
    model_config,
    target_backend: Optional[str] = None,
) -> dict[str, torch.Tensor]:
    """
    Merge tensor parallel state dict shards into a whole model state dict.

    Args:
        state_dict_shards (list[dict[str, torch.Tensor]]): The list of state dict shards to merge.
        model_config: The model configuration.
        target_backend (str, optional): The target backend format ('pytorch', 'transformer_engine', or 'huggingface').
            If not specified, the source backend will be used.

    Returns:
        dict[str, torch.Tensor]: The merged state dict.
    """
    state_dict_shards = [process_state_dict(shard, device="cpu") for shard in state_dict_shards]
    tensor_parallel_size = len(state_dict_shards)
    source_backend = identify_checkpoint_backend(state_dict_shards[0])
    if source_backend != "pytorch":
        log.critical(f"Converting from {source_backend} to PyTorch backend for tensor parallel checkpoint merging.")
        state_dict_shards = [
            maybe_convert_checkpoint_to_backend(
                shard,
                target_backend="pytorch",
                model_config=model_config,
                source_backend=source_backend,
                is_tensor_parallel_shard=True,
                tensor_parallel_size=tensor_parallel_size,
            )
            for shard in state_dict_shards
        ]
    n_heads = model_config["n_heads"]
    n_kv_heads = model_config["n_kv_heads"]
    n_local_heads = n_heads // tensor_parallel_size
    n_local_kv_heads = n_kv_heads // tensor_parallel_size
    dim = model_config["dim"]
    context_dim = model_config["context_dim"]
    head_dim = model_config["head_dim"]
    if head_dim is None:
        head_dim = model_config["dim"] // model_config["n_heads"]
    query_dim = head_dim * n_heads
    key_value_dim = head_dim * n_kv_heads
    merged_state_dict = {}
    for key in state_dict_shards[0].keys():
        prefix = "model." if key.startswith("model.") else ""
        key_without_prefix = key[len(prefix) :]
        if key_without_prefix.startswith("layers."):
            layer_index = int(key_without_prefix.split("layers.")[1].split(".")[0])
            if layer_index >= model_config["n_layers"]:
                log.warning(
                    f"Layer index {layer_index} is greater than or equal to the number of layers {model_config['n_layers']}. Skipping this layer."
                )
                continue
        if key_without_prefix == "tok_embeddings.weight" or key_without_prefix == "output.weight":
            merged_state_dict[key] = torch.cat([shard[key] for shard in state_dict_shards], dim=0)
        elif ".attention.wq.weight" in key or "cross_attention.wq.weight" in key:
            chunks = [shard[key].view(n_local_heads, head_dim, dim) for shard in state_dict_shards]
            merged_state_dict[key] = torch.cat(chunks, dim=0).reshape(query_dim, dim)
        elif ".attention.wk.weight" in key or ".attention.wv.weight" in key:
            chunks = [shard[key].view(n_local_kv_heads, head_dim, dim) for shard in state_dict_shards]
            merged_state_dict[key] = torch.cat(chunks, dim=0).reshape(key_value_dim, dim)
        elif "cross_attention.wk.weight" in key or "cross_attention.wv.weight" in key:
            chunks = [shard[key].view(n_local_kv_heads, head_dim, context_dim) for shard in state_dict_shards]
            merged_state_dict[key] = torch.cat(chunks, dim=0).reshape(key_value_dim, context_dim)
        elif "feed_forward.w1.weight" in key or "feed_forward.w3.weight" in key or "medusa_head" in key:
            merged_state_dict[key] = torch.cat([shard[key] for shard in state_dict_shards], dim=0)
        elif "feed_forward.w2.weight" in key or ".attention.wo.weight" in key or "cross_attention.wo.weight" in key:
            merged_state_dict[key] = torch.cat([shard[key] for shard in state_dict_shards], dim=1)
        else:
            avg_tensor = torch.stack([shard[key] for shard in state_dict_shards]).mean(dim=0)
            # Make sure shard-0 is close to the average tensor
            assert torch.allclose(state_dict_shards[0][key], avg_tensor, atol=5e-2, rtol=0.1), (
                f"Shard-0 tensor {key} is not close to the average tensor. "
                f"Max diff: {torch.max(torch.abs(state_dict_shards[0][key] - avg_tensor))}, "
            )
            merged_state_dict[key] = avg_tensor
            assert "norm" in key, f"Assumed the key {key} is a norm layer, which should be the same across shards."
    if target_backend is None:
        target_backend = source_backend
    return maybe_convert_checkpoint_to_backend(
        merged_state_dict, target_backend=target_backend, model_config=model_config
    )
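# A minimal merging sketch for merge_tensor_parallel_state_dicts above (the shard filenames and
# the config "cfg" are illustrative assumptions): merging is the inverse of the sharding helper,
# so merging the shards produced by obtain_tensor_parallel_state_dict should recover the original
# consolidated checkpoint.
#
#     shards = [torch.load(f"model_tp_rank{r}.pt", map_location="cpu") for r in range(2)]
#     merged_sd = merge_tensor_parallel_state_dicts(shards, model_config=cfg, target_backend="pytorch")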
def te_to_pytorch_state_dict(
    te_state_dict: Dict[str, torch.Tensor], model_config, tensor_parallel_size: int = 1
) -> Dict[str, torch.Tensor]:
    """
    Convert a TransformerEngine state dict to a PyTorch state dict.

    Args:
        te_state_dict (Dict[str, torch.Tensor]): The TransformerEngine state dict
        model_config: The model configuration
        tensor_parallel_size (int): The tensor parallel size. Defaults to 1 (i.e., not a tensor parallel shard).

    Returns:
        Dict[str, torch.Tensor]: The PyTorch state dict
    """
    if hasattr(model_config, "asdict"):
        model_config = model_config.asdict()
    pytorch_state_dict = {}
    replacement_rules = [
        # Self-attention modules
        (".self_attention.layernorm_qkv.layer_norm_weight", ".attention_norm.weight"),
        (".self_attention.layernorm_qkv.query_weight", ".attention.wq.weight"),
        (".self_attention.layernorm_qkv.key_weight", ".attention.wk.weight"),
        (".self_attention.layernorm_qkv.value_weight", ".attention.wv.weight"),
        (".self_attention.proj.weight", ".attention.wo.weight"),
        (".self_attention.", ".attention."),  # Handle the remaining modules such as q_norm and k_norm
        # MLP modules
        (".layernorm_mlp.layer_norm_weight", ".ffn_norm.weight"),
        (".layernorm_mlp.fc2_weight", ".feed_forward.w2.weight"),
        # Cross-attention modules
        (".inter_attention.layernorm_query.query_weight", ".cross_attention.wq.weight"),
        (".inter_attention.key_value.key_weight", ".cross_attention.wk.weight"),
        (".inter_attention.key_value.value_weight", ".cross_attention.wv.weight"),
        (".inter_attention.proj.weight", ".cross_attention.wo.weight"),
        (".inter_attention.layernorm_query.layer_norm_weight", ".cross_attention_norm.weight"),
        (".inter_attention.", ".cross_attention."),  # Handle the remaining modules such as q_norm and k_norm
    ]
    head_dim = model_config["head_dim"]
    if head_dim is None:
        head_dim = model_config["dim"] // model_config["n_heads"]
    for old_key, value in te_state_dict.items():
        new_key = old_key
        for old_substr, new_substr in replacement_rules:
            if old_substr in new_key:
                new_key = new_key.replace(old_substr, new_substr)
                break
        # Handle the fused w1 and w3 case
        if "layernorm_mlp.fc1_weight" in old_key:
            fused_weight = value
            split_point = fused_weight.shape[0] // 2
            w1_weight = fused_weight[:split_point]
            w3_weight = fused_weight[split_point:]
            w1_key = new_key.replace("layernorm_mlp.fc1_weight", "feed_forward.w1.weight")
            w3_key = new_key.replace("layernorm_mlp.fc1_weight", "feed_forward.w3.weight")
            pytorch_state_dict[w1_key] = w1_weight
            pytorch_state_dict[w3_key] = w3_weight
        else:
            if model_config["pytorch_rope_version"] == "v1":
                # If the model uses qk normalization, the PyTorch RoPE operations match the TE version,
                # so no permutation is needed; only RoPE v1 checkpoints require the inverse permutation.
                if "query_weight" in old_key:
                    value = inverse_permute_weight(
                        value,
                        n_heads=model_config["n_heads"] // tensor_parallel_size,
                        dim1=head_dim * model_config["n_heads"] // tensor_parallel_size,
                        dim2=model_config["dim"],
                    )
                elif "key_weight" in old_key:
                    value = inverse_permute_weight(
                        value,
                        n_heads=model_config["n_kv_heads"] // tensor_parallel_size,
                        dim1=head_dim * model_config["n_kv_heads"] // tensor_parallel_size,
                        dim2=model_config["context_dim"] if "inter_attention" in old_key else model_config["dim"],
                    )
            pytorch_state_dict[new_key] = value
    return pytorch_state_dict
def pytorch_to_te_state_dict(
    pytorch_state_dict: Dict[str, torch.Tensor], model_config, tensor_parallel_size: int = 1
) -> Dict[str, torch.Tensor]:
    """
    Convert a PyTorch state dict to a TransformerEngine state dict.

    Args:
        pytorch_state_dict (Dict[str, torch.Tensor]): The PyTorch state dict
        model_config: The model configuration
        tensor_parallel_size (int): The tensor parallel size. Defaults to 1 (i.e., not a tensor parallel shard).

    Returns:
        Dict[str, torch.Tensor]: The TransformerEngine state dict
    """
    if hasattr(model_config, "asdict"):
        model_config = model_config.asdict()
    te_state_dict = {}
    replacement_rules = [
        # Self-attention modules
        (".attention_norm.weight", ".self_attention.layernorm_qkv.layer_norm_weight"),
        (".attention.wq.weight", ".self_attention.layernorm_qkv.query_weight"),
        (".attention.wk.weight", ".self_attention.layernorm_qkv.key_weight"),
        (".attention.wv.weight", ".self_attention.layernorm_qkv.value_weight"),
        (".attention.wo.weight", ".self_attention.proj.weight"),
        (".attention.", ".self_attention."),
        # MLP modules
        (".ffn_norm.weight", ".layernorm_mlp.layer_norm_weight"),
        (".feed_forward.w2.weight", ".layernorm_mlp.fc2_weight"),
        # Cross-attention modules
        (".cross_attention_norm.weight", ".inter_attention.layernorm_query.layer_norm_weight"),
        (".cross_attention.wq.weight", ".inter_attention.layernorm_query.query_weight"),
        (".cross_attention.wk.weight", ".inter_attention.key_value.key_weight"),
        (".cross_attention.wv.weight", ".inter_attention.key_value.value_weight"),
        (".cross_attention.wo.weight", ".inter_attention.proj.weight"),
        (".cross_attention.", ".inter_attention."),
    ]
    head_dim = model_config["head_dim"]
    if head_dim is None:
        head_dim = model_config["dim"] // model_config["n_heads"]
    for old_key, value in pytorch_state_dict.items():
        new_key = old_key
        for new_substr, old_substr in replacement_rules:
            if new_substr in new_key:
                new_key = new_key.replace(new_substr, old_substr)
                break
        # Handle the split w1 and w3 case
        if "feed_forward.w1.weight" in old_key:
            w1_weight = value
            w3_key = old_key.replace("feed_forward.w1.weight", "feed_forward.w3.weight")
            if w3_key in pytorch_state_dict:
                w3_weight = pytorch_state_dict[w3_key]
                fused_weight = torch.cat([w1_weight, w3_weight], dim=0)
                new_key = new_key.replace("feed_forward.w1.weight", "layernorm_mlp.fc1_weight")
                te_state_dict[new_key] = fused_weight
            else:
                te_state_dict[new_key] = value
        elif "feed_forward.w3.weight" in old_key:
            # Skip w3 weights as they're handled with w1
            continue
        else:
            if model_config["pytorch_rope_version"] == "v1":
                # If the model uses qk normalization, the PyTorch RoPE operations match the TE version,
                # so no permutation is needed; only RoPE v1 checkpoints require the permutation.
                if "attention.wq" in old_key:
                    value = permute_weight(
                        value,
                        n_heads=model_config["n_heads"] // tensor_parallel_size,
                        dim1=head_dim * model_config["n_heads"] // tensor_parallel_size,
                        dim2=model_config["dim"],
                    )
                elif "attention.wk" in old_key:
                    value = permute_weight(
                        value,
                        n_heads=model_config["n_kv_heads"] // tensor_parallel_size,
                        dim1=head_dim * model_config["n_kv_heads"] // tensor_parallel_size,
                        dim2=model_config["context_dim"] if "cross_attention" in old_key else model_config["dim"],
                    )
            te_state_dict[new_key] = value
    return te_state_dict
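# A minimal round-trip sketch for the two converters above (the loaded "pytorch_sd" and the config
# "cfg" are illustrative assumptions): converting PyTorch -> TransformerEngine -> PyTorch should
# reproduce the original tensors, with w1/w3 fused into fc1 in the intermediate TE form.
#
#     te_sd = pytorch_to_te_state_dict(pytorch_sd, cfg)
#     recovered_sd = te_to_pytorch_state_dict(te_sd, cfg)
#     for k, v in pytorch_sd.items():
#         assert torch.equal(recovered_sd[k], v)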
def permute_weight(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    """
    Helper function for converting checkpoints from PyTorch to TransformerEngine.
    Permute the query weight or key weight of each attention layer.
    Source: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py

    Args:
        w (torch.Tensor): The weight tensor to permute
        n_heads (int): The number of attention heads
        dim1 (int): The first dimension of the weight tensor
        dim2 (int): The second dimension of the weight tensor

    Returns:
        torch.Tensor: The permuted weight tensor
    """
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
def inverse_permute_weight(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    """
    Helper function for converting checkpoints from TransformerEngine to PyTorch.
    Permute the query weight or key weight of each attention layer.

    Args:
        w (torch.Tensor): The weight tensor to permute
        n_heads (int): The number of attention heads
        dim1 (int): The first dimension of the weight tensor
        dim2 (int): The second dimension of the weight tensor

    Returns:
        torch.Tensor: The permuted weight tensor
    """
    return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
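# A tiny sanity-check sketch for the two permutation helpers above (the concrete shapes are
# illustrative, not taken from any real model): with n_heads=1, dim1=4, dim2=1, permute_weight
# interleaves the two halves of the head dimension, mapping rows [0, 1, 2, 3] to [0, 2, 1, 3],
# and inverse_permute_weight undoes it exactly.
#
#     w = torch.arange(4.0).reshape(4, 1)
#     permuted = permute_weight(w, n_heads=1, dim1=4, dim2=1)  # rows become [0, 2, 1, 3]
#     restored = inverse_permute_weight(permuted, n_heads=1, dim1=4, dim2=1)
#     assert torch.equal(restored, w)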
def pytorch_to_hf_state_dict(
    state_dict: Dict[str, torch.Tensor], model_config: Dict[str, Any], tensor_parallel_size: int = 1
) -> Dict[str, torch.Tensor]:
    """
    Convert a PyTorch state dict to HuggingFace format for LLM models.

    Args:
        state_dict (Dict[str, torch.Tensor]):
            The original PyTorch model's state dictionary.
            This is a mapping where keys are layer names and values are the corresponding PyTorch tensors
            containing the model weights.
        model_config (Dict[str, Any]):
            The configuration of the model. This dictionary contains parameters such as:
            - n_layers: (int) The number of transformer layers.
            - n_heads: (int) The number of attention heads.
            - dim: (int) The hidden size of the model.
            - n_kv_heads: (int, optional) The number of key-value heads for multi-query attention.

    Returns:
        Dict[str, torch.Tensor]:
            The converted HuggingFace state dictionary. This dictionary maps HuggingFace transformer-compatible
            layer names to the corresponding model weights.
    """
    not_supported_key_substrings = ["cross_attention", "q_norm", "k_norm"]
    for key in state_dict.keys():
        if any(substr in key for substr in not_supported_key_substrings):
            raise ValueError(f"Key {key} is not supported in HuggingFace format.")
    assert tensor_parallel_size == 1, "Tensor parallel size > 1 is not supported for HuggingFace model export."
    hf_state_dict = {}
    n_layers = model_config["n_layers"]
    n_heads = model_config["n_heads"]
    dim = model_config["dim"]
    head_dim = model_config["head_dim"]
    if head_dim is None:
        head_dim = model_config["dim"] // model_config["n_heads"]
    num_key_value_heads = model_config.get("n_kv_heads", n_heads)
    key_value_dim = head_dim * num_key_value_heads
    for layer_i in range(n_layers):
        pt_prefix = f"layers.{layer_i}."
        hf_prefix = f"model.layers.{layer_i}."
        wq = state_dict[f"{pt_prefix}attention.wq.weight"]
        wk = state_dict[f"{pt_prefix}attention.wk.weight"]
        if model_config["pytorch_rope_version"] == "v1":
            wq = permute_weight(
                wq,
                n_heads=n_heads,
                dim1=dim,
                dim2=dim,
            )
            wk = permute_weight(
                wk,
                n_heads=num_key_value_heads,
                dim1=key_value_dim,
                dim2=dim,
            )
        hf_state_dict[f"{hf_prefix}self_attn.q_proj.weight"] = wq
        hf_state_dict[f"{hf_prefix}self_attn.k_proj.weight"] = wk
        hf_state_dict[f"{hf_prefix}self_attn.v_proj.weight"] = state_dict[f"{pt_prefix}attention.wv.weight"]
        hf_state_dict[f"{hf_prefix}self_attn.o_proj.weight"] = state_dict[f"{pt_prefix}attention.wo.weight"]
        hf_state_dict[f"{hf_prefix}mlp.gate_proj.weight"] = state_dict[f"{pt_prefix}feed_forward.w1.weight"]
        hf_state_dict[f"{hf_prefix}mlp.down_proj.weight"] = state_dict[f"{pt_prefix}feed_forward.w2.weight"]
        hf_state_dict[f"{hf_prefix}mlp.up_proj.weight"] = state_dict[f"{pt_prefix}feed_forward.w3.weight"]
        hf_state_dict[f"{hf_prefix}input_layernorm.weight"] = state_dict[f"{pt_prefix}attention_norm.weight"]
        hf_state_dict[f"{hf_prefix}post_attention_layernorm.weight"] = state_dict[f"{pt_prefix}ffn_norm.weight"]
    # Add non-layer weights
    hf_state_dict["model.embed_tokens.weight"] = state_dict["tok_embeddings.weight"]
    hf_state_dict["model.norm.weight"] = state_dict["norm.weight"]
    hf_state_dict["lm_head.weight"] = state_dict["output.weight"]
    return hf_state_dict
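# A minimal export sketch for pytorch_to_hf_state_dict above (the checkpoint path and the config
# "cfg" are illustrative assumptions; the resulting dict uses HF Llama-style parameter names):
#
#     pytorch_sd = process_state_dict(torch.load("model.pt", map_location="cpu"))
#     hf_sd = pytorch_to_hf_state_dict(pytorch_sd, cfg)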
def maybe_convert_checkpoint_to_backend(
    state_dict: Dict[str, torch.Tensor],
    target_backend: str,
    model_config,
    source_backend: Optional[str] = None,
    is_tensor_parallel_shard: bool = False,
    tensor_parallel_size: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
    """
    Identify the backend of the checkpoint and convert to the target backend if necessary.

    This function checks the current backend of the state_dict and converts it to the target backend
    if they don't match. It supports conversions between the PyTorch, TransformerEngine, and HuggingFace backends.

    Args:
        state_dict (Dict[str, torch.Tensor]): The model state dictionary to convert.
        target_backend (str): The desired backend format ('pytorch', 'transformer_engine', or 'huggingface').
        model_config: Configuration of the model, used in the conversion process.
        source_backend (str, optional): The current backend of the state_dict. If not specified, the function will identify the backend.
        is_tensor_parallel_shard (bool, optional): Whether the state_dict is a tensor parallel shard. Defaults to False.
        tensor_parallel_size (int, optional): The tensor parallel size. If not specified, it is read from
            model_config["tensor_parallel_size"] when is_tensor_parallel_shard is True, and defaults to 1 otherwise.

    Returns:
        Dict[str, torch.Tensor]: The converted state dictionary in the target backend format.

    Raises:
        ValueError: If the conversion between the identified backend and target backend is not supported.
    """
    # Identify the current backend of the checkpoint
    state_dict = process_state_dict(state_dict)  # Remove unnecessary keys
    if source_backend is None:
        source_backend = identify_checkpoint_backend(state_dict)
    if source_backend == target_backend:
        return state_dict
    else:
        if tensor_parallel_size is None:
            tensor_parallel_size = model_config["tensor_parallel_size"] if is_tensor_parallel_shard else 1
        # Convert to target backend
        if source_backend == "pytorch" and target_backend == "transformer_engine":
            return pytorch_to_te_state_dict(state_dict, model_config, tensor_parallel_size=tensor_parallel_size)
        elif source_backend == "transformer_engine" and target_backend == "pytorch":
            return te_to_pytorch_state_dict(state_dict, model_config, tensor_parallel_size=tensor_parallel_size)
        elif source_backend == "pytorch" and target_backend == "huggingface":
            return pytorch_to_hf_state_dict(state_dict, model_config, tensor_parallel_size=tensor_parallel_size)
        else:
            raise ValueError(f"Conversion from {source_backend} to {target_backend} is not supported.")