pr-include-rev-in-flake

#1 opened by drbh (HF Staff)
README.md CHANGED
@@ -1,80 +1,9 @@
- ---
- license: bsd-3-clause
- tags:
- - kernel
- ---
- # Triton layer normalization kernels.
-
- This kernel implements layers normalization using Triton. This kernel is from
- the [flash-attention](https://github.com/Dao-AILab/flash-attention) project.
-
- ## Functions
-
- ### Function `layer_norm`
-
- `(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, residual: Optional[torch.Tensor] = None, x1: Optional[torch.Tensor] = None, weight1: Optional[torch.Tensor] = None, bias1: Optional[torch.Tensor] = None, eps: float = 1e-06, dropout_p: float = 0.0, rowscale=None, prenorm: bool = False, residual_in_fp32: bool = False, is_rms_norm: bool = False, return_dropout_mask: bool = False, out: Optional[torch.Tensor] = None, residual_out: Optional[torch.Tensor] = None)`
-
- Apply layer normalization to the input tensor with Triton acceleration.
-
- ### Parameters
-
- - **x** (*torch.Tensor*) --
-   Input tensor to normalize.
- - **weight** (*torch.Tensor*) --
-   Scale parameter for normalization.
- - **bias** (*torch.Tensor*) --
-   Shift parameter for normalization.
- - **residual** (*torch.Tensor*, *optional*) --
-   Optional residual tensor to add to the input before normalization.
- - **x1** (*torch.Tensor*, *optional*) --
-   Optional second input tensor to combine with *x*. When provided, the function
-   first adds *x1* to *x* and then applies normalization.
- - **weight1** (*torch.Tensor*, *optional*) --
-   Scale parameter for the second normalization.
- - **bias1** (*torch.Tensor*, *optional*) --
-   Shift parameter for the second normalization.
- - **eps** (*float*, *optional*, defaults to 1e-6) --
-   Small constant added for numerical stability in normalization.
- - **dropout_p** (*float*, *optional*, defaults to 0.0) --
-   Dropout probability. If greater than 0, applies dropout to the input before
-   normalization and residual addition.
- - **rowscale** (*torch.Tensor*, *optional*) --
-   Optional scaling factor applied to each row of the input tensor.
-   Not compatible with the use of *x1*.
- - **prenorm** (*bool*, *optional*, defaults to False) --
-   If True, returns both the normalized output and the unnormalized input+residual.
- - **residual_in_fp32** (*bool*, *optional*, defaults to False) --
-   If True, performs the residual connection in FP32 precision.
- - **is_rms_norm** (*bool*, *optional*, defaults to False) --
-   If True, uses RMS normalization instead of layer normalization.
- - **return_dropout_mask** (*bool*, *optional*, defaults to False) --
-   If True, returns the dropout mask used for the computation.
- - **out** (*torch.Tensor*, *optional*) --
-   Output tensor for the normalized result. If *None*, a new tensor is allocated.
- - **residual_out** (*torch.Tensor*, *optional*) --
-   Output tensor for the residual result when using prenorm. If *None*, a new tensor
-   is allocated when needed.
-
- ### Returns
-
- **Type**: *torch.Tensor* or tuple of *torch.Tensor*
-
- - The normalized input.
- - The second normalization of the input if *weight1* is provided.
- - The residual tensor if *prenorm* is set.
- - The dropout mask if *return_dropout_mask* is set.
- - The dropout mask for *x1* if *x1* is provided and *return_dropout_mask* is set.
-
- ## Layers
-
- ### Class `LlamaRMSNorm`
-
- No documentation available.
-
- #### Methods
-
- ##### Method `forward`
-
- `(self, hidden_states: torch.Tensor) -> torch.Tensor`
-
- No documentation available.
+ ---
+ license: bsd-3-clause
+ tags:
+ - kernel
+ ---
+
+ ## triton-layer-norm
+
+ Triton layer norm [from flash-attention](https://github.com/Dao-AILab/flash-attention).
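The new README drops the API reference that the removed text above provided. For reviewers, here is a minimal usage sketch of the exported `layer_norm_fn`; the import path, device, shapes, and dtypes are illustrative assumptions rather than something this diff specifies.

```python
# Illustrative sketch only: exercises the API documented in the removed README.
# Assumes the built package is importable as `triton_layer_norm` and that a CUDA
# device with Triton support is available.
import torch
from triton_layer_norm import layer_norm_fn

hidden_size = 1024
x = torch.randn(8, hidden_size, device="cuda", dtype=torch.float16)
weight = torch.ones(hidden_size, device="cuda", dtype=torch.float16)
bias = torch.zeros(hidden_size, device="cuda", dtype=torch.float16)

# Plain layer normalization over the last dimension (eps documented as 1e-6 by default).
out = layer_norm_fn(x, weight, bias, eps=1e-6)
print(out.shape)  # torch.Size([8, 1024])
```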
build.toml CHANGED
@@ -1,3 +1,5 @@
  [general]
  name = "triton_layer_norm"
+
+ [torch]
  universal = true
build/torch-universal/triton_layer_norm/__init__.py CHANGED
@@ -1,114 +1,5 @@
- """Triton layer normalization kernels
-
- This kernel implements layers normalization using Triton. This kernel is from
- the `flash-attention <https://github.com/Dao-AILab/flash-attention>`_ project.
- """
-
- from typing import Optional
-
- import torch
-
- from . import layers
  from .layer_norm import layer_norm_fn, layer_norm_linear_fn, rms_norm_fn

-
- def layer_norm(
-     x: torch.Tensor,
-     weight: torch.Tensor,
-     bias: torch.Tensor,
-     residual: Optional[torch.Tensor] = None,
-     x1: Optional[torch.Tensor] = None,
-     weight1: Optional[torch.Tensor] = None,
-     bias1: Optional[torch.Tensor] = None,
-     eps: float = 1e-6,
-     dropout_p: float = 0.0,
-     rowscale=None,
-     prenorm: bool = False,
-     residual_in_fp32: bool = False,
-     is_rms_norm: bool = False,
-     return_dropout_mask: bool = False,
-     out: Optional[torch.Tensor] = None,
-     residual_out: Optional[torch.Tensor] = None,
- ):
-     """
-     Apply layer normalization to the input tensor with Triton acceleration.
-
-     Args:
-         x (`torch.Tensor`):
-             Input tensor to normalize.
-         weight (`torch.Tensor`):
-             Scale parameter for normalization.
-         bias (`torch.Tensor`):
-             Shift parameter for normalization.
-         residual (`torch.Tensor`, *optional*):
-             Optional residual tensor to add to the input before normalization.
-         x1 (`torch.Tensor`, *optional*):
-             Optional second input tensor to combine with `x`. When provided, the function
-             first adds `x1` to `x` and then applies normalization.
-         weight1 (`torch.Tensor`, *optional*):
-             Scale parameter for the second normalization.
-         bias1 (`torch.Tensor`, *optional*):
-             Shift parameter for the second normalization.
-         eps (`float`, *optional*, defaults to 1e-6):
-             Small constant added for numerical stability in normalization.
-         dropout_p (`float`, *optional*, defaults to 0.0):
-             Dropout probability. If greater than 0, applies dropout to the input before
-             normalization and residual addition.
-         rowscale (`torch.Tensor`, *optional*):
-             Optional scaling factor applied to each row of the input tensor.
-             Not compatible with the use of `x1`.
-         prenorm (`bool`, *optional*, defaults to False):
-             If True, returns both the normalized output and the unnormalized input+residual.
-         residual_in_fp32 (`bool`, *optional*, defaults to False):
-             If True, performs the residual connection in FP32 precision.
-         is_rms_norm (`bool`, *optional*, defaults to False):
-             If True, uses RMS normalization instead of layer normalization.
-         return_dropout_mask (`bool`, *optional*, defaults to False):
-             If True, returns the dropout mask used for the computation.
-         out (`torch.Tensor`, *optional*):
-             Output tensor for the normalized result. If `None`, a new tensor is allocated.
-         residual_out (`torch.Tensor`, *optional*):
-             Output tensor for the residual result when using prenorm. If `None`, a new tensor
-             is allocated when needed.
-
-     Returns:
-         `torch.Tensor` or tuple of `torch.Tensor`:
-         - The normalized input.
-         - The second normalization of the input if `weight1` is provided.
-         - The residual tensor if `prenorm` is set.
-         - The dropout mask if `return_dropout_mask` is set.
-         - The dropout mask for `x1` if `x1` is provided and `return_dropout_mask` is set.
-     """
-     return layer_norm_fn(
-         x,
-         weight,
-         bias,
-         residual,
-         x1,
-         weight1,
-         bias1,
-         eps,
-         dropout_p,
-         rowscale,
-         prenorm,
-         residual_in_fp32,
-         is_rms_norm,
-         return_dropout_mask,
-         out=out,
-         residual_out=residual_out,
-     )
-
-
- __kernel_metadata__ = {
-     "license": "bsd-3-clause",
- }
-
-
- __all__ = [
-     "__kernel_metadata__",
-     "layers",
-     "layer_norm",
-     "layer_norm_fn",
-     "layer_norm_linear_fn",
-     "rms_norm_fn",
- ]
+ from . import layers
+
+ __all__ = ["layers", "layer_norm_fn", "layer_norm_linear_fn", "rms_norm_fn"]
build/torch-universal/triton_layer_norm/_ops.py DELETED
@@ -1,8 +0,0 @@
- import torch
- ops = torch.ops._triton_layer_norm_4dc3a9b_dirty
-
- def add_op_namespace_prefix(op_name: str):
-     """
-     Prefix op by namespace.
-     """
-     return f"_triton_layer_norm_4dc3a9b_dirty::{op_name}"
build/torch-universal/triton_layer_norm/layers.py CHANGED
@@ -1,46 +1,4 @@
- import torch
- from torch import nn
-
- from .layer_norm import rms_norm_fn
-
-
- class LlamaRMSNorm(nn.Module):
-     """
-     RMS Layer Norm for Llama models.
-
-     Triton-optimized RMS layer norm. The interface is compatible with `LLamaRMSNorm` in
-     `transformers`.
-
-     Attributes:
-         weight (`torch.Tensor`): The learnable scaling parameter.
-         variance_epsilon (`float`): The epsilon value for numerical stability.
-     """
-     weight: torch.Tensor
-     variance_epsilon: float
-
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         """
-         Apply RMS normalization to the input hidden states.
-
-         Args:
-             hidden_states (`torch.Tensor`):
-                 Input tensor of shape `(batch_size, sequence_length, hidden_size)` or any shape
-                 where the last dimension is the feature dimension to be normalized.
-
-         Returns:
-             `torch.Tensor`:
-                 The normalized tensor with the same shape as the input `hidden_states`.
-         """
-         return rms_norm_fn(
-             hidden_states,
-             self.weight,
-             bias=None,
-             residual=None,
-             eps=self.variance_epsilon,
-             dropout_p=0.0,
-             prenorm=False,
-             residual_in_fp32=False,
-         )
-
-
- __all__ = ["LlamaRMSNorm"]
+ from .layer_norm import RMSNorm
+
+
+ __all__ = ["RMSNorm"]
flake.lock DELETED
@@ -1,168 +0,0 @@
- {
-   "nodes": {
-     "flake-compat": {
-       "locked": {
-         "lastModified": 1747046372,
-         "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
-         "owner": "edolstra",
-         "repo": "flake-compat",
-         "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
-         "type": "github"
-       },
-       "original": {
-         "owner": "edolstra",
-         "repo": "flake-compat",
-         "type": "github"
-       }
-     },
-     "flake-compat_2": {
-       "locked": {
-         "lastModified": 1733328505,
-         "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
-         "owner": "edolstra",
-         "repo": "flake-compat",
-         "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
-         "type": "github"
-       },
-       "original": {
-         "owner": "edolstra",
-         "repo": "flake-compat",
-         "type": "github"
-       }
-     },
-     "flake-utils": {
-       "inputs": {
-         "systems": "systems"
-       },
-       "locked": {
-         "lastModified": 1731533236,
-         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
-         "owner": "numtide",
-         "repo": "flake-utils",
-         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
-         "type": "github"
-       },
-       "original": {
-         "owner": "numtide",
-         "repo": "flake-utils",
-         "type": "github"
-       }
-     },
-     "flake-utils_2": {
-       "inputs": {
-         "systems": "systems_2"
-       },
-       "locked": {
-         "lastModified": 1731533236,
-         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
-         "owner": "numtide",
-         "repo": "flake-utils",
-         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
-         "type": "github"
-       },
-       "original": {
-         "owner": "numtide",
-         "repo": "flake-utils",
-         "type": "github"
-       }
-     },
-     "hf-nix": {
-       "inputs": {
-         "flake-compat": "flake-compat_2",
-         "flake-utils": "flake-utils_2",
-         "nixpkgs": "nixpkgs"
-       },
-       "locked": {
-         "lastModified": 1750234878,
-         "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=",
-         "owner": "huggingface",
-         "repo": "hf-nix",
-         "rev": "c7132f90763d756da3e77da62e01be0a4546dc57",
-         "type": "github"
-       },
-       "original": {
-         "owner": "huggingface",
-         "repo": "hf-nix",
-         "type": "github"
-       }
-     },
-     "kernel-builder": {
-       "inputs": {
-         "flake-compat": "flake-compat",
-         "flake-utils": "flake-utils",
-         "hf-nix": "hf-nix",
-         "nixpkgs": [
-           "kernel-builder",
-           "hf-nix",
-           "nixpkgs"
-         ]
-       },
-       "locked": {
-         "lastModified": 1750409351,
-         "narHash": "sha256-xkzrwee77LrBDtwNNihBkYbY7yUwdOv0/4+J3B5xCZE=",
-         "owner": "huggingface",
-         "repo": "kernel-builder",
-         "rev": "9e61fba877153bffa6eaff023243fd81220c0eea",
-         "type": "github"
-       },
-       "original": {
-         "owner": "huggingface",
-         "repo": "kernel-builder",
-         "type": "github"
-       }
-     },
-     "nixpkgs": {
-       "locked": {
-         "lastModified": 1747820358,
-         "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
-         "owner": "danieldk",
-         "repo": "nixpkgs",
-         "rev": "d3c1681180717528068082103bf323147de6ab0b",
-         "type": "github"
-       },
-       "original": {
-         "owner": "danieldk",
-         "ref": "cudatoolkit-12.9-kernel-builder",
-         "repo": "nixpkgs",
-         "type": "github"
-       }
-     },
-     "root": {
-       "inputs": {
-         "kernel-builder": "kernel-builder"
-       }
-     },
-     "systems": {
-       "locked": {
-         "lastModified": 1681028828,
-         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
-         "owner": "nix-systems",
-         "repo": "default",
-         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
-         "type": "github"
-       },
-       "original": {
-         "owner": "nix-systems",
-         "repo": "default",
-         "type": "github"
-       }
-     },
-     "systems_2": {
-       "locked": {
-         "lastModified": 1681028828,
-         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
-         "owner": "nix-systems",
-         "repo": "default",
-         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
-         "type": "github"
-       },
-       "original": {
-         "owner": "nix-systems",
-         "repo": "default",
-         "type": "github"
-       }
-     }
-   },
-   "root": "root",
-   "version": 7
- }
flake.nix CHANGED
@@ -10,8 +10,5 @@
  self,
  kernel-builder,
  }:
- kernel-builder.lib.genFlakeOutputs {
-   path = ./.;
-   rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
- };
+ kernel-builder.lib.genFlakeOutputs ./.;
  }
torch-ext/triton_layer_norm/__init__.py CHANGED
@@ -1,114 +1,5 @@
- """Triton layer normalization kernels
-
- This kernel implements layers normalization using Triton. This kernel is from
- the `flash-attention <https://github.com/Dao-AILab/flash-attention>`_ project.
- """
-
- from typing import Optional
-
- import torch
-
- from . import layers
  from .layer_norm import layer_norm_fn, layer_norm_linear_fn, rms_norm_fn

-
- def layer_norm(
-     x: torch.Tensor,
-     weight: torch.Tensor,
-     bias: torch.Tensor,
-     residual: Optional[torch.Tensor] = None,
-     x1: Optional[torch.Tensor] = None,
-     weight1: Optional[torch.Tensor] = None,
-     bias1: Optional[torch.Tensor] = None,
-     eps: float = 1e-6,
-     dropout_p: float = 0.0,
-     rowscale=None,
-     prenorm: bool = False,
-     residual_in_fp32: bool = False,
-     is_rms_norm: bool = False,
-     return_dropout_mask: bool = False,
-     out: Optional[torch.Tensor] = None,
-     residual_out: Optional[torch.Tensor] = None,
- ):
-     """
-     Apply layer normalization to the input tensor with Triton acceleration.
-
-     Args:
-         x (`torch.Tensor`):
-             Input tensor to normalize.
-         weight (`torch.Tensor`):
-             Scale parameter for normalization.
-         bias (`torch.Tensor`):
-             Shift parameter for normalization.
-         residual (`torch.Tensor`, *optional*):
-             Optional residual tensor to add to the input before normalization.
-         x1 (`torch.Tensor`, *optional*):
-             Optional second input tensor to combine with `x`. When provided, the function
-             first adds `x1` to `x` and then applies normalization.
-         weight1 (`torch.Tensor`, *optional*):
-             Scale parameter for the second normalization.
-         bias1 (`torch.Tensor`, *optional*):
-             Shift parameter for the second normalization.
-         eps (`float`, *optional*, defaults to 1e-6):
-             Small constant added for numerical stability in normalization.
-         dropout_p (`float`, *optional*, defaults to 0.0):
-             Dropout probability. If greater than 0, applies dropout to the input before
-             normalization and residual addition.
-         rowscale (`torch.Tensor`, *optional*):
-             Optional scaling factor applied to each row of the input tensor.
-             Not compatible with the use of `x1`.
-         prenorm (`bool`, *optional*, defaults to False):
-             If True, returns both the normalized output and the unnormalized input+residual.
-         residual_in_fp32 (`bool`, *optional*, defaults to False):
-             If True, performs the residual connection in FP32 precision.
-         is_rms_norm (`bool`, *optional*, defaults to False):
-             If True, uses RMS normalization instead of layer normalization.
-         return_dropout_mask (`bool`, *optional*, defaults to False):
-             If True, returns the dropout mask used for the computation.
-         out (`torch.Tensor`, *optional*):
-             Output tensor for the normalized result. If `None`, a new tensor is allocated.
-         residual_out (`torch.Tensor`, *optional*):
-             Output tensor for the residual result when using prenorm. If `None`, a new tensor
-             is allocated when needed.
-
-     Returns:
-         `torch.Tensor` or tuple of `torch.Tensor`:
-         - The normalized input.
-         - The second normalization of the input if `weight1` is provided.
-         - The residual tensor if `prenorm` is set.
-         - The dropout mask if `return_dropout_mask` is set.
-         - The dropout mask for `x1` if `x1` is provided and `return_dropout_mask` is set.
-     """
-     return layer_norm_fn(
-         x,
-         weight,
-         bias,
-         residual,
-         x1,
-         weight1,
-         bias1,
-         eps,
-         dropout_p,
-         rowscale,
-         prenorm,
-         residual_in_fp32,
-         is_rms_norm,
-         return_dropout_mask,
-         out=out,
-         residual_out=residual_out,
-     )
-
-
- __kernel_metadata__ = {
-     "license": "bsd-3-clause",
- }
-
-
- __all__ = [
-     "__kernel_metadata__",
-     "layers",
-     "layer_norm",
-     "layer_norm_fn",
-     "layer_norm_linear_fn",
-     "rms_norm_fn",
- ]
+ from . import layers
+
+ __all__ = ["layers", "layer_norm_fn", "layer_norm_linear_fn", "rms_norm_fn"]
torch-ext/triton_layer_norm/layers.py CHANGED
@@ -1,46 +1,4 @@
- import torch
- from torch import nn
-
- from .layer_norm import rms_norm_fn
-
-
- class LlamaRMSNorm(nn.Module):
-     """
-     RMS Layer Norm for Llama models.
-
-     Triton-optimized RMS layer norm. The interface is compatible with `LLamaRMSNorm` in
-     `transformers`.
-
-     Attributes:
-         weight (`torch.Tensor`): The learnable scaling parameter.
-         variance_epsilon (`float`): The epsilon value for numerical stability.
-     """
-     weight: torch.Tensor
-     variance_epsilon: float
-
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         """
-         Apply RMS normalization to the input hidden states.
-
-         Args:
-             hidden_states (`torch.Tensor`):
-                 Input tensor of shape `(batch_size, sequence_length, hidden_size)` or any shape
-                 where the last dimension is the feature dimension to be normalized.
-
-         Returns:
-             `torch.Tensor`:
-                 The normalized tensor with the same shape as the input `hidden_states`.
-         """
-         return rms_norm_fn(
-             hidden_states,
-             self.weight,
-             bias=None,
-             residual=None,
-             eps=self.variance_epsilon,
-             dropout_p=0.0,
-             prenorm=False,
-             residual_in_fp32=False,
-         )
-
-
- __all__ = ["LlamaRMSNorm"]
+ from .layer_norm import RMSNorm
+
+
+ __all__ = ["RMSNorm"]