danieldk (HF Staff) committed
Commit 80b5db1 · 1 Parent(s): 491036b

Add top-level `layer_norm` wrapper with docs

build/torch-universal/triton_layer_norm/__init__.py CHANGED
@@ -1,5 +1,114 @@
- from .layer_norm import layer_norm_fn, layer_norm_linear_fn, rms_norm_fn
+ """Triton layer normalization kernels.
+
+ This kernel implements layer normalization using Triton. The kernel is from
+ the `flash-attention <https://github.com/Dao-AILab/flash-attention>`_ project.
+ """
+
+ from typing import Optional
+
+ import torch
 
  from . import layers
+ from .layer_norm import layer_norm_fn, layer_norm_linear_fn, rms_norm_fn
+
+
+ def layer_norm(
+     x: torch.Tensor,
+     weight: torch.Tensor,
+     bias: torch.Tensor,
+     residual: Optional[torch.Tensor] = None,
+     x1: Optional[torch.Tensor] = None,
+     weight1: Optional[torch.Tensor] = None,
+     bias1: Optional[torch.Tensor] = None,
+     eps: float = 1e-6,
+     dropout_p: float = 0.0,
+     rowscale=None,
+     prenorm: bool = False,
+     residual_in_fp32: bool = False,
+     is_rms_norm: bool = False,
+     return_dropout_mask: bool = False,
+     out: Optional[torch.Tensor] = None,
+     residual_out: Optional[torch.Tensor] = None,
+ ):
+     """
+     Apply layer normalization to the input tensor with Triton acceleration.
+
+     Args:
+         x (`torch.Tensor`):
+             Input tensor to normalize.
+         weight (`torch.Tensor`):
+             Scale parameter for normalization.
+         bias (`torch.Tensor`):
+             Shift parameter for normalization.
+         residual (`torch.Tensor`, *optional*):
+             Optional residual tensor to add to the input before normalization.
+         x1 (`torch.Tensor`, *optional*):
+             Optional second input tensor to combine with `x`. When provided, the function
+             first adds `x1` to `x` and then applies normalization.
+         weight1 (`torch.Tensor`, *optional*):
+             Scale parameter for the second normalization.
+         bias1 (`torch.Tensor`, *optional*):
+             Shift parameter for the second normalization.
+         eps (`float`, *optional*, defaults to 1e-6):
+             Small constant added for numerical stability in normalization.
+         dropout_p (`float`, *optional*, defaults to 0.0):
+             Dropout probability. If greater than 0, applies dropout to the input before
+             normalization and residual addition.
+         rowscale (`torch.Tensor`, *optional*):
+             Optional scaling factor applied to each row of the input tensor.
+             Not compatible with the use of `x1`.
+         prenorm (`bool`, *optional*, defaults to False):
+             If True, returns both the normalized output and the unnormalized input+residual.
+         residual_in_fp32 (`bool`, *optional*, defaults to False):
+             If True, performs the residual connection in FP32 precision.
+         is_rms_norm (`bool`, *optional*, defaults to False):
+             If True, uses RMS normalization instead of layer normalization.
+         return_dropout_mask (`bool`, *optional*, defaults to False):
+             If True, returns the dropout mask used for the computation.
+         out (`torch.Tensor`, *optional*):
+             Output tensor for the normalized result. If `None`, a new tensor is allocated.
+         residual_out (`torch.Tensor`, *optional*):
+             Output tensor for the residual result when using prenorm. If `None`, a new tensor
+             is allocated when needed.
+
+     Returns:
+         `torch.Tensor` or tuple of `torch.Tensor`:
+             - The normalized input.
+             - The second normalization of the input if `weight1` is provided.
+             - The residual tensor if `prenorm` is set.
+             - The dropout mask if `return_dropout_mask` is set.
+             - The dropout mask for `x1` if `x1` is provided and `return_dropout_mask` is set.
+     """
+     return layer_norm_fn(
+         x,
+         weight,
+         bias,
+         residual,
+         x1,
+         weight1,
+         bias1,
+         eps,
+         dropout_p,
+         rowscale,
+         prenorm,
+         residual_in_fp32,
+         is_rms_norm,
+         return_dropout_mask,
+         out=out,
+         residual_out=residual_out,
+     )
+
+
+ __kernel_metadata__ = {
+     "license": "bsd-3-clause",
+ }
+
 
- __all__ = ["layers", "layer_norm_fn", "layer_norm_linear_fn", "rms_norm_fn"]
+ __all__ = [
+     "__kernel_metadata__",
+     "layers",
+     "layer_norm",
+     "layer_norm_fn",
+     "layer_norm_linear_fn",
+     "rms_norm_fn",
+ ]
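
For reference, a minimal usage sketch of the new top-level wrapper, assuming the package is importable as `triton_layer_norm` and a CUDA device is available (shapes and dtypes below are illustrative, not taken from the commit):

import torch
import triton_layer_norm

# Illustrative shapes; the wrapper normalizes over the last dimension.
x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
weight = torch.ones(1024, device="cuda", dtype=torch.float16)
bias = torch.zeros(1024, device="cuda", dtype=torch.float16)

# Plain layer normalization: returns the normalized tensor.
y = triton_layer_norm.layer_norm(x, weight, bias, eps=1e-6)

# RMS-norm variant through the same wrapper.
y_rms = triton_layer_norm.layer_norm(x, weight, bias, is_rms_norm=True)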
torch-ext/triton_layer_norm/__init__.py CHANGED
@@ -1,5 +1,114 @@
- from .layer_norm import layer_norm_fn, layer_norm_linear_fn, rms_norm_fn
+ """Triton layer normalization kernels.
+
+ This kernel implements layer normalization using Triton. The kernel is from
+ the `flash-attention <https://github.com/Dao-AILab/flash-attention>`_ project.
+ """
+
+ from typing import Optional
+
+ import torch
 
  from . import layers
+ from .layer_norm import layer_norm_fn, layer_norm_linear_fn, rms_norm_fn
+
+
+ def layer_norm(
+     x: torch.Tensor,
+     weight: torch.Tensor,
+     bias: torch.Tensor,
+     residual: Optional[torch.Tensor] = None,
+     x1: Optional[torch.Tensor] = None,
+     weight1: Optional[torch.Tensor] = None,
+     bias1: Optional[torch.Tensor] = None,
+     eps: float = 1e-6,
+     dropout_p: float = 0.0,
+     rowscale=None,
+     prenorm: bool = False,
+     residual_in_fp32: bool = False,
+     is_rms_norm: bool = False,
+     return_dropout_mask: bool = False,
+     out: Optional[torch.Tensor] = None,
+     residual_out: Optional[torch.Tensor] = None,
+ ):
+     """
+     Apply layer normalization to the input tensor with Triton acceleration.
+
+     Args:
+         x (`torch.Tensor`):
+             Input tensor to normalize.
+         weight (`torch.Tensor`):
+             Scale parameter for normalization.
+         bias (`torch.Tensor`):
+             Shift parameter for normalization.
+         residual (`torch.Tensor`, *optional*):
+             Optional residual tensor to add to the input before normalization.
+         x1 (`torch.Tensor`, *optional*):
+             Optional second input tensor to combine with `x`. When provided, the function
+             first adds `x1` to `x` and then applies normalization.
+         weight1 (`torch.Tensor`, *optional*):
+             Scale parameter for the second normalization.
+         bias1 (`torch.Tensor`, *optional*):
+             Shift parameter for the second normalization.
+         eps (`float`, *optional*, defaults to 1e-6):
+             Small constant added for numerical stability in normalization.
+         dropout_p (`float`, *optional*, defaults to 0.0):
+             Dropout probability. If greater than 0, applies dropout to the input before
+             normalization and residual addition.
+         rowscale (`torch.Tensor`, *optional*):
+             Optional scaling factor applied to each row of the input tensor.
+             Not compatible with the use of `x1`.
+         prenorm (`bool`, *optional*, defaults to False):
+             If True, returns both the normalized output and the unnormalized input+residual.
+         residual_in_fp32 (`bool`, *optional*, defaults to False):
+             If True, performs the residual connection in FP32 precision.
+         is_rms_norm (`bool`, *optional*, defaults to False):
+             If True, uses RMS normalization instead of layer normalization.
+         return_dropout_mask (`bool`, *optional*, defaults to False):
+             If True, returns the dropout mask used for the computation.
+         out (`torch.Tensor`, *optional*):
+             Output tensor for the normalized result. If `None`, a new tensor is allocated.
+         residual_out (`torch.Tensor`, *optional*):
+             Output tensor for the residual result when using prenorm. If `None`, a new tensor
+             is allocated when needed.
+
+     Returns:
+         `torch.Tensor` or tuple of `torch.Tensor`:
+             - The normalized input.
+             - The second normalization of the input if `weight1` is provided.
+             - The residual tensor if `prenorm` is set.
+             - The dropout mask if `return_dropout_mask` is set.
+             - The dropout mask for `x1` if `x1` is provided and `return_dropout_mask` is set.
+     """
+     return layer_norm_fn(
+         x,
+         weight,
+         bias,
+         residual,
+         x1,
+         weight1,
+         bias1,
+         eps,
+         dropout_p,
+         rowscale,
+         prenorm,
+         residual_in_fp32,
+         is_rms_norm,
+         return_dropout_mask,
+         out=out,
+         residual_out=residual_out,
+     )
+
+
+ __kernel_metadata__ = {
+     "license": "bsd-3-clause",
+ }
+
 
- __all__ = ["layers", "layer_norm_fn", "layer_norm_linear_fn", "rms_norm_fn"]
+ __all__ = [
+     "__kernel_metadata__",
+     "layers",
+     "layer_norm",
+     "layer_norm_fn",
+     "layer_norm_linear_fn",
+     "rms_norm_fn",
+ ]
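
The `torch-ext` copy mirrors the built file above. As a further sketch of the documented `prenorm` path, assuming (per the docstring) that `prenorm=True` returns both the normalized output and the un-normalized input+residual:

import torch
import triton_layer_norm

x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)
weight = torch.ones(1024, device="cuda", dtype=torch.float16)
bias = torch.zeros(1024, device="cuda", dtype=torch.float16)

# With prenorm=True the call is expected to return a tuple:
# (normalized output, input + residual kept for the next block).
out, new_residual = triton_layer_norm.layer_norm(
    x,
    weight,
    bias,
    residual=residual,
    prenorm=True,
    residual_in_fp32=True,
)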
torch-ext/triton_layer_norm/layers.py CHANGED
@@ -5,10 +5,32 @@ from .layer_norm import rms_norm_fn
 
 
  class LlamaRMSNorm(nn.Module):
+     """
+     RMS Layer Norm for Llama models.
+
+     Triton-optimized RMS layer norm. The interface is compatible with `LlamaRMSNorm` in
+     `transformers`.
+
+     Attributes:
+         weight (`torch.Tensor`): The learnable scaling parameter.
+         variance_epsilon (`float`): The epsilon value for numerical stability.
+     """
      weight: torch.Tensor
      variance_epsilon: float
 
      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         """
+         Apply RMS normalization to the input hidden states.
+
+         Args:
+             hidden_states (`torch.Tensor`):
+                 Input tensor of shape `(batch_size, sequence_length, hidden_size)` or any shape
+                 where the last dimension is the feature dimension to be normalized.
+
+         Returns:
+             `torch.Tensor`:
+                 The normalized tensor with the same shape as the input `hidden_states`.
+         """
          return rms_norm_fn(
              hidden_states,
              self.weight,
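
A minimal sketch of using the documented `LlamaRMSNorm` layer. The diff does not show a constructor, so the attributes are set manually below, and the import path is assumed from the file location; the hidden size and epsilon are hypothetical.

import torch
from torch import nn
from triton_layer_norm.layers import LlamaRMSNorm  # path assumed from torch-ext/triton_layer_norm/layers.py

norm = LlamaRMSNorm()
# The class is documented as interface-compatible with `LlamaRMSNorm` in
# `transformers`; here the parameters are assigned directly for illustration.
norm.weight = nn.Parameter(torch.ones(4096, device="cuda", dtype=torch.float16))
norm.variance_epsilon = 1e-6

hidden_states = torch.randn(2, 16, 4096, device="cuda", dtype=torch.float16)
out = norm(hidden_states)  # same shape as the input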