using apex rmsnorm (#57)
* using apex rmsnorm
* added message for missing apex
* black
* missed a print
---------
Co-authored-by: Srini Iyer <[email protected]>
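
For readers skimming the diffs below: each touched module now prefers Apex's fused RMSNorm kernel and falls back to PyTorch's built-in nn.RMSNorm when Apex is not installed. A minimal, self-contained sketch of that pattern (the per-file hunks assume torch.nn is already imported as nn in each module; the dimension 512 below is just an illustrative value):

import torch.nn as nn

try:
    # Apex ships a fused CUDA kernel for RMSNorm.
    from apex.normalization.fused_layer_norm import FusedRMSNorm

    RMSNorm = FusedRMSNorm
except (ImportError, ModuleNotFoundError):
    # Pure-PyTorch fallback (nn.RMSNorm is available in torch >= 2.4).
    print("Apex not found. Using nn.RMSNorm")
    RMSNorm = nn.RMSNorm

# Call sites construct the norm the same way under either implementation:
norm = RMSNorm(512, eps=1e-6)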
bytelatent/base_transformer.py
CHANGED
@@ -17,6 +17,14 @@ from xformers.ops import AttentionBias, fmha
 from bytelatent import probe
 from bytelatent.tokenizers.constants import EOS_ID
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
     flex_attention_comp = torch.compile(flex_attention)
 else:
@@ -294,37 +302,6 @@ class RotaryEmbedding(torch.nn.Module):
         return self.freqs_cis[0:seqlen]
 
 
-class RMSNorm(nn.Module):
-    """
-    Initialize the RMSNorm normalization layer.
-
-    Args:
-        dim (int): The dimension of the input tensor.
-        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
-
-    Attributes:
-        eps (float): A small value added to the denominator for numerical stability.
-        weight (nn.Parameter): Learnable scaling parameter.
-
-    """
-
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x: torch.Tensor):
-        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x: torch.Tensor):
-        x = probe.log_stats(x, "resid")
-        output = self._norm(x.float())
-        return (output * self.weight.float()).type_as(x)
-
-    def reset_parameters(self):
-        torch.nn.init.ones_(self.weight)  # type: ignore
-
-
 def _reshape_for_attn_bias(
     attn_bias: AttentionBias | None,
     *tensors: torch.Tensor,
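
A quick numerical sanity check on the class removed above (my sketch, not part of this PR; it assumes PyTorch >= 2.4 for nn.RMSNorm and ignores the probe.log_stats call the old forward performed) showing that the nn.RMSNorm fallback reproduces the deleted implementation:

import torch
import torch.nn as nn


def old_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Mirrors the deleted class: normalize in float32, scale by the learned weight, cast back.
    xf = x.float()
    out = xf * torch.rsqrt((xf * xf).mean(-1, keepdim=True) + eps)
    return (out * weight.float()).type_as(x)


dim = 512
x = torch.randn(4, 16, dim)
new_norm = nn.RMSNorm(dim, eps=1e-6)  # weight starts at ones, like the deleted class

assert torch.allclose(new_norm(x), old_rmsnorm(x, new_norm.weight), atol=1e-5)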
bytelatent/model/latent_transformer.py
CHANGED
@@ -12,12 +12,19 @@ from xformers.ops import AttentionBias
 from bytelatent.base_transformer import (
     BaseTransformer,
     BaseTransformerArgs,
-    RMSNorm,
     flex_attention_comp,
     repeat_kv,
 )
 from bytelatent.model.utils import create_causal_mask
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 logger = logging.getLogger()
 
 
@@ -44,7 +51,7 @@ class CrossAttention(nn.Module):
         self.n_kv_heads = n_kv_heads
         self.heads_per_group = self.n_heads // self.n_kv_heads
 
-        self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps)
+        self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps)
         self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
 
         self.wq = nn.Linear(
bytelatent/model/local_models.py
CHANGED
@@ -14,7 +14,6 @@ from xformers.ops import AttentionBias
 from bytelatent.base_transformer import (
     BaseTransformerArgs,
     InitStdFactor,
-    RMSNorm,
     RotaryEmbedding,
     TransformerBlock,
 )
@@ -22,6 +21,14 @@ from bytelatent.model.latent_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask, downsample
 from bytelatent.tokenizers.blt_tokenizer import BOE_ID
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 logger = logging.getLogger()
 
 
bytelatent/transformer.py
CHANGED
@@ -19,11 +19,18 @@ from xformers.ops import AttentionBias, fmha
 from bytelatent.base_transformer import (
     BaseTransformer,
     BaseTransformerArgs,
-    RMSNorm,
     cross_entropy,
 )
 from bytelatent.model.utils import create_causal_mask
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 
 def attention_flops_per_token(n_layers, seq_len, dim, causal):
     # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30