Spaces:
Build error
Build error
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# SPDX-License-Identifier: Apache-2.0 | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import torch | |
import torch.nn as nn | |
def create_norm(norm_type: str, dim: int, eps: float = 1e-6): | |
""" | |
Creates the specified normalization layer based on the norm_type. | |
Adopted from TorchTriton: https://github.com/pytorch/torchtitan/blob/main/torchtitan/cosmos_predict1/norms.py | |
Args: | |
norm_type (str): The type of normalization layer to create. | |
Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm | |
dim (int): The dimension of the normalization layer. | |
eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. | |
Returns: | |
The created normalization layer. | |
Raises: | |
NotImplementedError: If an unknown norm_type is provided. | |
""" | |
norm_type = norm_type.lower() # Normalize to lowercase | |
if norm_type == "layernorm": | |
return nn.LayerNorm(dim, eps=eps, bias=False) | |
elif norm_type == "np_layernorm": | |
return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False) | |
elif norm_type == "rmsnorm": | |
return RMSNorm(dim, eps=eps, compile=False) | |
elif norm_type == "compiled_rmsnorm": | |
return RMSNorm(dim, eps=eps, compile=True) | |
elif norm_type == "fused_rmsnorm": | |
raise NotImplementedError("Fused RMSNorm is not supported yet.") | |
else: | |
raise NotImplementedError(f"Unknown norm_type: '{norm_type}'") | |
class RMSNorm(nn.Module):
    """
    Root-mean-square normalization with a learnable per-feature scale.

    Reference implementation: https://github.com/pytorch/torchtitan/blob/main/torchtitan/cosmos_predict1/norms.py

    The output is ``x * rsqrt(mean(x**2, dim=-1, keepdim=True) + eps) * weight``.
    The normalization step is computed in float32 and cast back to the input
    dtype before being multiplied by ``weight``. The final product therefore
    follows PyTorch type promotion between the input dtype and ``weight``'s dtype.

    Args:
        dim (int): Length of the learnable ``weight`` vector. It broadcasts
            against the last dimension of the input.
        eps (float, optional): Value added inside the square root for numerical stability. Default is 1e-6.
        compile (bool, optional): If True, the computation is wrapped with
            ``torch.compile(..., fullgraph=True)``. Default is False.

    Attributes:
        eps (float): The epsilon used in the normalization.
        weight (nn.Parameter): Learnable scale of shape ``(dim,)``, initialized to ones.
        rmsnorm_fn (Callable): The (optionally compiled) function that ``forward`` calls.
    """

    def __init__(self, dim: int, eps: float = 1e-6, compile: bool = False):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        # compute_rmsnorm is a staticmethod, so self.compute_rmsnorm is the plain
        # function taking (x, weight, eps). Without @staticmethod it would be a
        # bound method, and forward() would pass one positional argument too many.
        self.rmsnorm_fn = torch.compile(self.compute_rmsnorm, fullgraph=True) if compile else self.compute_rmsnorm

    @staticmethod
    def compute_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        """
        Apply RMS normalization over the last dimension of ``x``, then scale by ``weight``.

        Args:
            x (torch.Tensor): Input tensor; normalization is over its last dimension.
            weight (torch.Tensor): Scale multiplied into the normalized result.
            eps (float): Stability term added to the mean of squares.

        Returns:
            torch.Tensor: The normalized, scaled tensor.
        """

        def _norm(t: torch.Tensor, eps_: float) -> torch.Tensor:
            # Divide by the root-mean-square of the last dimension.
            return t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + eps_)

        # Upcast to float32 for the reduction to limit precision loss with
        # low-precision inputs, then cast back to the input's dtype.
        output = _norm(x.float(), eps).type_as(x)
        return output * weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` using this module's ``weight`` and ``eps``."""
        return self.rmsnorm_fn(x, self.weight, self.eps)

    def reset_parameters(self):
        """Reset the learnable scale to ones."""
        torch.nn.init.ones_(self.weight)