# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
import torch
from cosmos_predict1.utils import log
# Substrings to ignore when processing state dicts
substrings_to_ignore = [
"_extra_state", # Extra states (BytesIO type) added by TransformerEngine for FP8 handling
]
def identify_checkpoint_backend(state_dict: dict[str, torch.Tensor]) -> str:
"""
Identify the backend of the checkpoint (PyTorch or TransformerEngine)
Args:
state_dict (dict[str, torch.Tensor]): The state dict to check
Returns:
str: The backend of the checkpoint
"""
for key in state_dict.keys():
if "self_attention.layernorm_qkv.query_weight" in key:
return "transformer_engine"
elif "attention.wq.weight" in key:
return "pytorch"
raise ValueError("Could not identify the backend of the checkpoint")
def get_partial_state_dict(
state_dict: dict[str, torch.Tensor],
prefix: str,
) -> dict[str, torch.Tensor]:
"""
Get a partial state dict with keys starting with the given prefix
"""
return {k: v for k, v in state_dict.items() if k.startswith(prefix)}
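# Example (a sketch; `full_sd` is a hypothetical variable, and "transformer." is the ViT prefix used later in this file):
#   vit_sd = get_partial_state_dict(full_sd, prefix="transformer.")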
def process_state_dict(
state_dict: dict[str, torch.Tensor],
    device: Optional[str] = None,
    dtype: Optional[torch.dtype] = None,
prefix_to_remove: Optional[str] = None,
) -> dict[str, torch.Tensor]:
"""
- Remove items with substring "_extra_state" in keys (TransformerEngine adds these for FP8)
- Move tensors to specified device and dtype if provided
Args:
state_dict (dict[str, torch.Tensor]): The state dict to process
device (str, optional): The device to move tensors to. Defaults to None.
dtype (torch.dtype, optional): The dtype to move tensors to. Defaults to None.
prefix_to_remove (str, optional): The prefix to remove from the keys of the state dict. Defaults to None.
Returns:
dict[str, torch.Tensor]: The processed state dict
"""
new_state_dict = {}
tensor_kwargs = {}
if device is not None:
tensor_kwargs["device"] = device
if dtype is not None:
tensor_kwargs["dtype"] = dtype
for key, value in state_dict.items():
# Check if any of the substrings to ignore are in the key
skip = False
for substr in substrings_to_ignore:
if substr in key:
skip = True
break
if skip:
continue
if len(tensor_kwargs) > 0:
value = value.to(**tensor_kwargs)
if prefix_to_remove is not None and key.startswith(prefix_to_remove):
key = key[len(prefix_to_remove) :]
new_state_dict[key] = value
return new_state_dict
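# Example usage (illustrative only; the checkpoint path and prefix are assumptions, not fixed by this module):
#   raw_sd = torch.load("checkpoint.pt", map_location="cpu")
#   clean_sd = process_state_dict(raw_sd, device="cpu", dtype=torch.bfloat16, prefix_to_remove="model.")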
def obtain_tensor_parallel_state_dict(
whole_model_state_dict: dict[str, torch.Tensor],
tensor_parallel_size: int,
tensor_parallel_rank: int,
model_config,
    target_backend: Optional[str] = None,
) -> dict[str, torch.Tensor]:
"""
Obtain the tensor parallel state dict shard for the current rank.
Args:
whole_model_state_dict (dict[str, torch.Tensor]): The complete model state dict.
tensor_parallel_size (int): The number of tensor parallel devices.
tensor_parallel_rank (int): The rank of the current tensor parallel device.
model_config: The model configuration.
target_backend (str, optional): The target backend format ('pytorch', 'transformer_engine', or 'huggingface'). If not specified, the source backend will be used.
Returns:
dict[str, torch.Tensor]: The updated state dict shard for the current tensor parallel rank.
"""
new_state_dict_shard = {}
whole_model_state_dict = process_state_dict(whole_model_state_dict)
source_backend = identify_checkpoint_backend(whole_model_state_dict)
if source_backend != "pytorch":
# Convert the checkpoint to PyTorch backend for checkpoint sharding
whole_model_state_dict = maybe_convert_checkpoint_to_backend(
whole_model_state_dict, target_backend="pytorch", model_config=model_config, source_backend=source_backend
)
n_heads = model_config["n_heads"]
n_kv_heads = model_config["n_kv_heads"]
dim = model_config["dim"]
context_dim = model_config["context_dim"]
for key, value in whole_model_state_dict.items():
prefix = "model." if key.startswith("model.") else "" # LLM's model prefix
prefix = "transformer." if key.startswith("transformer.") else prefix # VIT's model prefix
key = key.replace(prefix, "")
if key.startswith("layers."):
layer_index = int(key.split("layers.")[1].split(".")[0])
if layer_index >= model_config["n_layers"]:
log.warning(
f"Layer index {layer_index} is greater than the number of layers {model_config['n_layers']}. Skipping this layer."
)
continue
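            # Sharding scheme used below:
            # - wq/wk/wv, feed_forward.w1/w3, and medusa heads are column-parallel, so each rank keeps a
            #   contiguous slice of heads / hidden units and the weights are chunked along dim 0.
            # - wo and feed_forward.w2 consume those column-parallel outputs, so they are row-parallel
            #   and chunked along dim 1 (the input dimension).
            # - Everything else (e.g. norm weights) is replicated unchanged on every rank.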
if ".attention.wq.weight" in key or "cross_attention.wq.weight" in key:
value = torch.chunk(value.view(n_heads, -1, dim), tensor_parallel_size, dim=0)[tensor_parallel_rank]
value = value.reshape(-1, dim)
elif ".attention.wk.weight" in key or ".attention.wv.weight" in key:
value = torch.chunk(value.view(n_kv_heads, -1, dim), tensor_parallel_size, dim=0)[tensor_parallel_rank]
value = value.reshape(-1, dim)
elif "cross_attention.wk.weight" in key or "cross_attention.wv.weight" in key:
assert context_dim is not None
value = torch.chunk(value.view(n_kv_heads, -1, context_dim), tensor_parallel_size, dim=0)[
tensor_parallel_rank
]
value = value.reshape(-1, context_dim)
elif "feed_forward.w1.weight" in key or "feed_forward.w3.weight" in key or "medusa_head" in key:
value = torch.chunk(value, tensor_parallel_size, dim=0)[tensor_parallel_rank]
elif "feed_forward.w2.weight" in key or ".attention.wo.weight" in key or "cross_attention.wo.weight" in key:
value = torch.chunk(value, tensor_parallel_size, dim=1)[tensor_parallel_rank]
else:
# Handle non-layer weights
if key == "tok_embeddings.weight" or key == "output.weight" or "medusa_head" in key:
value = torch.chunk(value, tensor_parallel_size, dim=0)[tensor_parallel_rank]
new_state_dict_shard[prefix + key] = value
if target_backend is None:
target_backend = source_backend
new_state_dict_shard = maybe_convert_checkpoint_to_backend(
new_state_dict_shard,
target_backend=target_backend,
model_config=model_config,
is_tensor_parallel_shard=True,
tensor_parallel_size=tensor_parallel_size,
)
return new_state_dict_shard
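# Example usage (a sketch with assumed values; `full_sd` and `config` are hypothetical):
#   shard = obtain_tensor_parallel_state_dict(
#       whole_model_state_dict=full_sd,
#       tensor_parallel_size=4,
#       tensor_parallel_rank=2,
#       model_config=config,
#       target_backend="transformer_engine",
#   )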
def merge_tensor_parallel_state_dicts(
state_dict_shards: list[dict[str, torch.Tensor]],
model_config,
    target_backend: Optional[str] = None,
) -> dict[str, torch.Tensor]:
"""
Merge tensor parallel state dict shards into a whole model state dict.
Args:
state_dict_shards (List[Dict[str, torch.Tensor]]): The list of state dict shards to merge.
model_config: The model configuration.
target_backend (str, optional): The target backend format ('pytorch', 'transformer_engine', or 'huggingface'). If not specified, the source backend will be used.
Returns:
Dict[str, torch.Tensor]: The merged state dict.
"""
state_dict_shards = [process_state_dict(shard, device="cpu") for shard in state_dict_shards]
tensor_parallel_size = len(state_dict_shards)
source_backend = identify_checkpoint_backend(state_dict_shards[0])
if source_backend != "pytorch":
log.critical(f"Converting from {source_backend} to PyTorch backend for tensor parallel checkpoint merging.")
state_dict_shards = [
maybe_convert_checkpoint_to_backend(
shard,
target_backend="pytorch",
model_config=model_config,
source_backend=source_backend,
is_tensor_parallel_shard=True,
tensor_parallel_size=tensor_parallel_size,
)
for shard in state_dict_shards
]
n_heads = model_config["n_heads"]
n_kv_heads = model_config["n_kv_heads"]
n_local_heads = n_heads // tensor_parallel_size
n_local_kv_heads = n_kv_heads // tensor_parallel_size
dim = model_config["dim"]
context_dim = model_config["context_dim"]
head_dim = model_config["head_dim"]
if head_dim is None:
head_dim = model_config["dim"] // model_config["n_heads"]
query_dim = head_dim * n_heads
key_value_dim = head_dim * n_kv_heads
merged_state_dict = {}
for key in state_dict_shards[0].keys():
prefix = "model." if key.startswith("model.") else ""
key_without_prefix = key[len(prefix) :]
if key_without_prefix.startswith("layers."):
layer_index = int(key_without_prefix.split("layers.")[1].split(".")[0])
if layer_index >= model_config["n_layers"]:
log.warning(
f"Layer index {layer_index} is greater than the number of layers {model_config['n_layers']}. Skipping this layer."
)
continue
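        # Merging mirrors the sharding in obtain_tensor_parallel_state_dict: column-parallel weights
        # (embeddings, output head, wq/wk/wv, w1/w3, medusa heads) are concatenated along dim 0,
        # row-parallel weights (wo, w2) along dim 1, and replicated parameters (norms) are averaged.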
if key_without_prefix == "tok_embeddings.weight" or key_without_prefix == "output.weight":
merged_state_dict[key] = torch.cat([shard[key] for shard in state_dict_shards], dim=0)
elif ".attention.wq.weight" in key or "cross_attention.wq.weight" in key:
chunks = [shard[key].view(n_local_heads, head_dim, dim) for shard in state_dict_shards]
merged_state_dict[key] = torch.cat(chunks, dim=0).reshape(query_dim, dim)
elif ".attention.wk.weight" in key or ".attention.wv.weight" in key:
chunks = [shard[key].view(n_local_kv_heads, head_dim, dim) for shard in state_dict_shards]
merged_state_dict[key] = torch.cat(chunks, dim=0).reshape(key_value_dim, dim)
elif "cross_attention.wk.weight" in key or "cross_attention.wv.weight" in key:
chunks = [shard[key].view(n_local_kv_heads, head_dim, context_dim) for shard in state_dict_shards]
merged_state_dict[key] = torch.cat(chunks, dim=0).reshape(key_value_dim, context_dim)
elif "feed_forward.w1.weight" in key or "feed_forward.w3.weight" in key or "medusa_head" in key:
merged_state_dict[key] = torch.cat([shard[key] for shard in state_dict_shards], dim=0)
elif "feed_forward.w2.weight" in key or ".attention.wo.weight" in key or "cross_attention.wo.weight" in key:
merged_state_dict[key] = torch.cat([shard[key] for shard in state_dict_shards], dim=1)
        else:
            # Remaining parameters are expected to be norm weights, replicated across ranks; average them
            # and verify that shard-0 stays close to the average.
            assert "norm" in key, f"Expected {key} to be a norm parameter, which should be identical across shards."
            avg_tensor = torch.stack([shard[key] for shard in state_dict_shards]).mean(dim=0)
            assert torch.allclose(state_dict_shards[0][key], avg_tensor, atol=5e-2, rtol=0.1), (
                f"Shard-0 tensor {key} is not close to the average tensor. "
                f"Max diff: {torch.max(torch.abs(state_dict_shards[0][key] - avg_tensor))}"
            )
            merged_state_dict[key] = avg_tensor
if target_backend is None:
target_backend = source_backend
return maybe_convert_checkpoint_to_backend(
merged_state_dict, target_backend=target_backend, model_config=model_config
)
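# Example usage (a sketch; the shard filenames and `config` are assumptions for illustration):
#   shards = [torch.load(f"model_mp_{rank}.pt", map_location="cpu") for rank in range(4)]
#   merged = merge_tensor_parallel_state_dicts(shards, model_config=config, target_backend="pytorch")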
def te_to_pytorch_state_dict(
te_state_dict: Dict[str, torch.Tensor], model_config, tensor_parallel_size: int = 1
) -> Dict[str, torch.Tensor]:
"""
Convert a TransformerEngine state dict to PyTorch state dict
Args:
te_state_dict (Mapping[str, torch.Tensor]): The TransformerEngine state dict
model_config: The model configuration
tensor_parallel_size (int): The tensor parallel size. Defaults to 1 (i.e., not a tensor parallel shard).
Returns:
Mapping[str, torch.Tensor]: The PyTorch state dict
"""
if hasattr(model_config, "asdict"):
model_config = model_config.asdict()
pytorch_state_dict = {}
replacement_rules = [
# Self-attention modules
(".self_attention.layernorm_qkv.layer_norm_weight", ".attention_norm.weight"),
(".self_attention.layernorm_qkv.query_weight", ".attention.wq.weight"),
(".self_attention.layernorm_qkv.key_weight", ".attention.wk.weight"),
(".self_attention.layernorm_qkv.value_weight", ".attention.wv.weight"),
(".self_attention.proj.weight", ".attention.wo.weight"),
(".self_attention.", ".attention."), # Handle the rest modules such as q_norm and k_norm
# MLP modules
(".layernorm_mlp.layer_norm_weight", ".ffn_norm.weight"),
(".layernorm_mlp.fc2_weight", ".feed_forward.w2.weight"),
# Cross-attention modules
(".inter_attention.layernorm_query.query_weight", ".cross_attention.wq.weight"),
(".inter_attention.key_value.key_weight", ".cross_attention.wk.weight"),
(".inter_attention.key_value.value_weight", ".cross_attention.wv.weight"),
(".inter_attention.proj.weight", ".cross_attention.wo.weight"),
(".inter_attention.layernorm_query.layer_norm_weight", ".cross_attention_norm.weight"),
(".inter_attention.", ".cross_attention."), # Handle the rest modules such as q_norm and k_norm
]
head_dim = model_config["head_dim"]
if head_dim is None:
head_dim = model_config["dim"] // model_config["n_heads"]
for old_key, value in te_state_dict.items():
new_key = old_key
for old_substr, new_substr in replacement_rules:
if old_substr in new_key:
new_key = new_key.replace(old_substr, new_substr)
break
# Handle the fused w1 and w3 case
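        # The TE checkpoint stores the gate (w1) and up (w3) projections fused into a single fc1 weight
        # stacked along dim 0 (w1 first), so splitting fc1 at its row midpoint recovers both.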
if "layernorm_mlp.fc1_weight" in old_key:
fused_weight = value
split_point = fused_weight.shape[0] // 2
w1_weight = fused_weight[:split_point]
w3_weight = fused_weight[split_point:]
w1_key = new_key.replace("layernorm_mlp.fc1_weight", "feed_forward.w1.weight")
w3_key = new_key.replace("layernorm_mlp.fc1_weight", "feed_forward.w3.weight")
pytorch_state_dict[w1_key] = w1_weight
pytorch_state_dict[w3_key] = w3_weight
else:
if model_config["pytorch_rope_version"] == "v1":
                # Rope version "v1" stores the rotary pairs interleaved, so the TE query/key weights must
                # be un-permuted here. Models with qk normalization use the same PyTorch RoPE operations
                # as the TE version and therefore skip this step.
if "query_weight" in old_key:
value = inverse_permute_weight(
value,
n_heads=model_config["n_heads"] // tensor_parallel_size,
dim1=head_dim * model_config["n_heads"] // tensor_parallel_size,
dim2=model_config["dim"],
)
elif "key_weight" in old_key:
value = inverse_permute_weight(
value,
n_heads=model_config["n_kv_heads"] // tensor_parallel_size,
dim1=head_dim * model_config["n_kv_heads"] // tensor_parallel_size,
dim2=model_config["context_dim"] if "inter_attention" in old_key else model_config["dim"],
)
pytorch_state_dict[new_key] = value
return pytorch_state_dict
def pytorch_to_te_state_dict(
pytorch_state_dict: Dict[str, torch.Tensor], model_config, tensor_parallel_size: int = 1
) -> Dict[str, torch.Tensor]:
"""
Convert a PyTorch state dict to TransformerEngine state dict
Args:
pytorch_state_dict (Mapping[str, torch.Tensor]): The PyTorch state dict
model_config: The model configuration
tensor_parallel_size (int): The tensor parallel size. Defaults to 1 (i.e., not a tensor parallel shard).
Returns:
        Mapping[str, torch.Tensor]: The TransformerEngine state dict
"""
if hasattr(model_config, "asdict"):
model_config = model_config.asdict()
te_state_dict = {}
replacement_rules = [
# Self-attention modules
(".attention_norm.weight", ".self_attention.layernorm_qkv.layer_norm_weight"),
(".attention.wq.weight", ".self_attention.layernorm_qkv.query_weight"),
(".attention.wk.weight", ".self_attention.layernorm_qkv.key_weight"),
(".attention.wv.weight", ".self_attention.layernorm_qkv.value_weight"),
(".attention.wo.weight", ".self_attention.proj.weight"),
(".attention.", ".self_attention."),
# MLP modules
(".ffn_norm.weight", ".layernorm_mlp.layer_norm_weight"),
(".feed_forward.w2.weight", ".layernorm_mlp.fc2_weight"),
# Cross-attention modules
(".cross_attention_norm.weight", ".inter_attention.layernorm_query.layer_norm_weight"),
(".cross_attention.wq.weight", ".inter_attention.layernorm_query.query_weight"),
(".cross_attention.wk.weight", ".inter_attention.key_value.key_weight"),
(".cross_attention.wv.weight", ".inter_attention.key_value.value_weight"),
(".cross_attention.wo.weight", ".inter_attention.proj.weight"),
(".cross_attention.", ".inter_attention."),
]
head_dim = model_config["head_dim"]
if head_dim is None:
head_dim = model_config["dim"] // model_config["n_heads"]
for old_key, value in pytorch_state_dict.items():
new_key = old_key
        for pytorch_substr, te_substr in replacement_rules:
            if pytorch_substr in new_key:
                new_key = new_key.replace(pytorch_substr, te_substr)
                break
# Handle the split w1 and w3 case
if "feed_forward.w1.weight" in old_key:
w1_weight = value
w3_key = old_key.replace("feed_forward.w1.weight", "feed_forward.w3.weight")
if w3_key in pytorch_state_dict:
w3_weight = pytorch_state_dict[w3_key]
fused_weight = torch.cat([w1_weight, w3_weight], dim=0)
new_key = new_key.replace("feed_forward.w1.weight", "layernorm_mlp.fc1_weight")
te_state_dict[new_key] = fused_weight
else:
te_state_dict[new_key] = value
elif "feed_forward.w3.weight" in old_key:
# Skip w3 weights as they're handled with w1
continue
else:
if model_config["pytorch_rope_version"] == "v1":
                # Rope version "v1" stores the rotary pairs interleaved, so the wq/wk weights must be
                # permuted into the TE layout here. Models with qk normalization already share the TE
                # RoPE layout and skip this step.
if "attention.wq" in old_key:
value = permute_weight(
value,
n_heads=model_config["n_heads"] // tensor_parallel_size,
dim1=head_dim * model_config["n_heads"] // tensor_parallel_size,
dim2=model_config["dim"],
)
elif "attention.wk" in old_key:
value = permute_weight(
value,
n_heads=model_config["n_kv_heads"] // tensor_parallel_size,
dim1=head_dim * model_config["n_kv_heads"] // tensor_parallel_size,
dim2=model_config["context_dim"] if "cross_attention" in old_key else model_config["dim"],
)
te_state_dict[new_key] = value
return te_state_dict
def permute_weight(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
"""
Helper function for converting checkpoints from PyTorch to TransformerEngine
Permute the query weight or key weight of each attention layer
Source: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
Args:
w (torch.Tensor): The weight tensor to permute
n_heads (int): The number of attention heads
dim1 (int): The first dimension of the weight tensor
dim2 (int): The second dimension of the weight tensor
Returns:
torch.Tensor: The permuted weight tensor
"""
return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
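# Illustration (assumed toy shapes, not from the original code): with n_heads=1, dim1=4, dim2=1, the row
# order [r0, r1, r2, r3] (rotary pairs interleaved as (r0, r1), (r2, r3)) becomes [r0, r2, r1, r3], i.e.
# the first element of every rotary pair is grouped ahead of the second, matching the half-rotation RoPE
# layout. inverse_permute_weight below applies the exact inverse mapping.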
def inverse_permute_weight(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
"""
    Helper function for converting checkpoints from TransformerEngine to PyTorch
    Inverse of `permute_weight`: un-permute the query weight or key weight of each attention layer
Args:
w (torch.Tensor): The weight tensor to permute
n_heads (int): The number of attention heads
dim1 (int): The first dimension of the weight tensor
dim2 (int): The second dimension of the weight tensor
Returns:
torch.Tensor: The permuted weight tensor
"""
return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
def pytorch_to_hf_state_dict(
state_dict: Dict[str, torch.Tensor], model_config: Dict[str, Any], tensor_parallel_size: int = 1
) -> Dict[str, torch.Tensor]:
"""
Convert a PyTorch state dict to HuggingFace format for LLM models.
Args:
state_dict (Mapping[str, torch.Tensor]):
The original PyTorch model's state dictionary.
This is a mapping where keys are layer names and values are the corresponding PyTorch tensors
containing the model weights.
model_config (Mapping[str, Any]):
The configuration of the model. This dictionary contains parameters such as:
- n_layers: (int) The number of transformer layers.
- n_heads: (int) The number of attention heads.
- dim: (int) The hidden size of the model.
- n_kv_heads: (int, optional) The number of key-value heads for multi-query attention.
Returns:
Mapping[str, torch.Tensor]:
The converted HuggingFace state dictionary. This dictionary maps HuggingFace transformer-compatible
layer names to the corresponding model weights.
"""
not_supported_key_substrings = ["cross_attention", "q_norm", "k_norm"]
for key in state_dict.keys():
if any(substr in key for substr in not_supported_key_substrings):
raise ValueError(f"Key {key} is not supported in HuggingFace format.")
assert tensor_parallel_size == 1, "Tensor parallel size > 1 is not supported for HuggingFace model export."
hf_state_dict = {}
n_layers = model_config["n_layers"]
n_heads = model_config["n_heads"]
dim = model_config["dim"]
head_dim = model_config["head_dim"]
if head_dim is None:
head_dim = model_config["dim"] // model_config["n_heads"]
num_key_value_heads = model_config.get("n_kv_heads", n_heads)
key_value_dim = head_dim * num_key_value_heads
for layer_i in range(n_layers):
pt_prefix = f"layers.{layer_i}."
hf_prefix = f"model.layers.{layer_i}."
wq = state_dict[f"{pt_prefix}attention.wq.weight"]
wk = state_dict[f"{pt_prefix}attention.wk.weight"]
if model_config["pytorch_rope_version"] == "v1":
wq = permute_weight(
wq,
n_heads=n_heads,
dim1=dim,
dim2=dim,
)
wk = permute_weight(
wk,
n_heads=num_key_value_heads,
dim1=key_value_dim,
dim2=dim,
)
hf_state_dict[f"{hf_prefix}self_attn.q_proj.weight"] = wq
hf_state_dict[f"{hf_prefix}self_attn.k_proj.weight"] = wk
hf_state_dict[f"{hf_prefix}self_attn.v_proj.weight"] = state_dict[f"{pt_prefix}attention.wv.weight"]
hf_state_dict[f"{hf_prefix}self_attn.o_proj.weight"] = state_dict[f"{pt_prefix}attention.wo.weight"]
hf_state_dict[f"{hf_prefix}mlp.gate_proj.weight"] = state_dict[f"{pt_prefix}feed_forward.w1.weight"]
hf_state_dict[f"{hf_prefix}mlp.down_proj.weight"] = state_dict[f"{pt_prefix}feed_forward.w2.weight"]
hf_state_dict[f"{hf_prefix}mlp.up_proj.weight"] = state_dict[f"{pt_prefix}feed_forward.w3.weight"]
hf_state_dict[f"{hf_prefix}input_layernorm.weight"] = state_dict[f"{pt_prefix}attention_norm.weight"]
hf_state_dict[f"{hf_prefix}post_attention_layernorm.weight"] = state_dict[f"{pt_prefix}ffn_norm.weight"]
# Add non-layer weights
hf_state_dict["model.embed_tokens.weight"] = state_dict["tok_embeddings.weight"]
hf_state_dict["model.norm.weight"] = state_dict["norm.weight"]
hf_state_dict["lm_head.weight"] = state_dict["output.weight"]
return hf_state_dict
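# Example usage (a sketch; the output filename is an assumption, not prescribed by this module):
#   hf_sd = pytorch_to_hf_state_dict(pt_sd, model_config=config_dict)
#   torch.save(hf_sd, "pytorch_model.bin")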
def maybe_convert_checkpoint_to_backend(
state_dict: Dict[str, torch.Tensor],
target_backend: str,
model_config,
    source_backend: Optional[str] = None,
is_tensor_parallel_shard: bool = False,
    tensor_parallel_size: Optional[int] = None,
):
"""
Identify the backend of the checkpoint and convert to the target backend if necessary.
This function checks the current backend of the state_dict and converts it to the target backend
if they don't match. It supports conversions between PyTorch, TransformerEngine, and HuggingFace backends.
Args:
state_dict (Dict[str, torch.Tensor]): The model state dictionary to convert.
target_backend (str): The desired backend format ('pytorch', 'transformer_engine', or 'huggingface').
model_config: Configuration of the model, used in conversion process.
source_backend (str, optional): The current backend of the state_dict. If not specified, the function will identify the backend.
is_tensor_parallel_shard (bool, optional): Whether the state_dict is a tensor parallel shard. Defaults to False.
        tensor_parallel_size (int, optional): The tensor parallel size. If not specified, it is read from model_config when is_tensor_parallel_shard is True, and defaults to 1 otherwise.
Returns:
Dict[str, torch.Tensor]: The converted state dictionary in the target backend format.
Raises:
ValueError: If the conversion between the identified backend and target backend is not supported.
"""
# Identify the current backend of the checkpoint
state_dict = process_state_dict(state_dict) # Remove unnecessary keys
if source_backend is None:
source_backend = identify_checkpoint_backend(state_dict)
if source_backend == target_backend:
return state_dict
else:
if tensor_parallel_size is None:
tensor_parallel_size = model_config["tensor_parallel_size"] if is_tensor_parallel_shard else 1
# Convert to target backend
if source_backend == "pytorch" and target_backend == "transformer_engine":
return pytorch_to_te_state_dict(state_dict, model_config, tensor_parallel_size=tensor_parallel_size)
elif source_backend == "transformer_engine" and target_backend == "pytorch":
return te_to_pytorch_state_dict(state_dict, model_config, tensor_parallel_size=tensor_parallel_size)
elif source_backend == "pytorch" and target_backend == "huggingface":
return pytorch_to_hf_state_dict(state_dict, model_config, tensor_parallel_size=tensor_parallel_size)
else:
raise ValueError(f"Conversion from {source_backend} to {target_backend} is not supported.")