medmekk HF Staff committed on
Commit
3de0c84
·
verified ·
1 Parent(s): e1a49f7

Upload custom kernels

Browse files
build/torch-universal/liger_kernels/__init__.py CHANGED
@@ -1,30 +1,3 @@
1
- from .cross_entropy import LigerCrossEntropyFunction
2
- from .fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
3
- from .dyt import LigerDyTFunction
4
- from .geglu import LigerGELUMulFunction
5
- from .group_norm import LigerGroupNormFunction
6
- from .kl_div import LigerKLDivLossFunction
7
- from .layer_norm import LigerLayerNormFunction
8
- from .qwen2vl_mrope import LigerQwen2VLMRopeFunction
9
- from .rms_norm import LigerRMSNormFunction, LigerRMSNorm
10
- from .jsd import LigerJSDFunction
11
- from .rope import LigerRopeFunction
12
- from .swiglu import LigerSiLUMulFunction
13
- from .tvd import LigerTVDLossFunction
14
 
15
- __all__ = [
16
- "LigerCrossEntropyFunction",
17
- "LigerFusedLinearCrossEntropyFunction",
18
- "LigerDyTFunction",
19
- "LigerGELUMulFunction",
20
- "LigerGroupNormFunction",
21
- "LigerKLDivLossFunction",
22
- "LigerLayerNormFunction",
23
- "LigerQwen2VLMRopeFunction",
24
- "LigerRMSNormFunction",
25
- "LigerRMSNorm",
26
- "LigerJSDFunction",
27
- "LigerRopeFunction",
28
- "LigerSiLUMulFunction",
29
- "LigerTVDLossFunction",
30
- ]
 
1
+ from . import layers
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ __all__ = ["layers"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-universal/liger_kernels/_ops.py CHANGED
@@ -1,8 +1,8 @@
1
  import torch
2
- ops = torch.ops._liger_kernels_20250507091026
3
 
4
  def add_op_namespace_prefix(op_name: str):
5
  """
6
  Prefix op by namespace.
7
  """
8
- return f"_liger_kernels_20250507091026::{op_name}"
 
1
  import torch
2
+ ops = torch.ops._liger_kernels_20250507091553
3
 
4
  def add_op_namespace_prefix(op_name: str):
5
  """
6
  Prefix op by namespace.
7
  """
8
+ return f"_liger_kernels_20250507091553::{op_name}"
build/torch-universal/liger_kernels/layers.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .rms_norm import LigerRMSNormFunction
3
+
4
+ class LigerRMSNorm(torch.nn.Module):
5
+ """
6
+ RMSNorm module that uses the optimized LigerRMSNormFunction.
7
+
8
+ Args:
9
+ hidden_size (int): The size of the hidden dimension.
10
+ eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
11
+ offset (float, optional): Offset value to shift the weight tensor. Defaults to 0.0.
12
+ casting_mode (str, optional): The casting mode to use. Defaults to "llama".
13
+ in_place (bool, optional): Whether to modify dY in-place to store dX during backward. Defaults to True.
14
+ """
15
+
16
+
17
+ weight: torch.Tensor
18
+ variance_epsilon: float
19
+ offset: float = 0
20
+ casting_mode: str = "llama"
21
+ in_place: bool = True
22
+
23
+ def forward(self, hidden_states):
24
+ """
25
+ Apply RMS normalization to the input tensor.
26
+
27
+ Args:
28
+ hidden_states (torch.Tensor): Input tensor of shape (B, T, H) or (BxT, H)
29
+
30
+ Returns:
31
+ torch.Tensor: Normalized tensor of the same shape as input
32
+ """
33
+ return LigerRMSNormFunction.apply(
34
+ hidden_states,
35
+ self.weight,
36
+ self.variance_epsilon,
37
+ self.offset,
38
+ self.casting_mode,
39
+ self.in_place
40
+ )
41
+
42
+ __all__ = ["LigerRMSNorm"]
build/torch-universal/liger_kernels/rms_norm.py CHANGED
@@ -362,43 +362,4 @@ class LigerRMSNormFunction(torch.autograd.Function):
362
  ctx.num_warps,
363
  ctx.in_place,
364
  )
365
- return dX, dW, None, None, None, None
366
-
367
-
368
- class LigerRMSNorm(torch.nn.Module):
369
- """
370
- RMSNorm module that uses the optimized LigerRMSNormFunction.
371
-
372
- Args:
373
- hidden_size (int): The size of the hidden dimension.
374
- eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
375
- offset (float, optional): Offset value to shift the weight tensor. Defaults to 0.0.
376
- casting_mode (str, optional): The casting mode to use. Defaults to "llama".
377
- in_place (bool, optional): Whether to modify dY in-place to store dX during backward. Defaults to True.
378
- """
379
-
380
-
381
- weight: torch.Tensor
382
- variance_epsilon: float
383
- offset: float = 0
384
- casting_mode: str = "llama"
385
- in_place: bool = True
386
-
387
- def forward(self, hidden_states):
388
- """
389
- Apply RMS normalization to the input tensor.
390
-
391
- Args:
392
- hidden_states (torch.Tensor): Input tensor of shape (B, T, H) or (BxT, H)
393
-
394
- Returns:
395
- torch.Tensor: Normalized tensor of the same shape as input
396
- """
397
- return LigerRMSNormFunction.apply(
398
- hidden_states,
399
- self.weight,
400
- self.variance_epsilon,
401
- self.offset,
402
- self.casting_mode,
403
- self.in_place
404
- )
 
362
  ctx.num_warps,
363
  ctx.in_place,
364
  )
365
+ return dX, dW, None, None, None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch-ext/liger_kernels/__init__.py CHANGED
@@ -1,30 +1,3 @@
1
- from .cross_entropy import LigerCrossEntropyFunction
2
- from .fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
3
- from .dyt import LigerDyTFunction
4
- from .geglu import LigerGELUMulFunction
5
- from .group_norm import LigerGroupNormFunction
6
- from .kl_div import LigerKLDivLossFunction
7
- from .layer_norm import LigerLayerNormFunction
8
- from .qwen2vl_mrope import LigerQwen2VLMRopeFunction
9
- from .rms_norm import LigerRMSNormFunction, LigerRMSNorm
10
- from .jsd import LigerJSDFunction
11
- from .rope import LigerRopeFunction
12
- from .swiglu import LigerSiLUMulFunction
13
- from .tvd import LigerTVDLossFunction
14
 
15
- __all__ = [
16
- "LigerCrossEntropyFunction",
17
- "LigerFusedLinearCrossEntropyFunction",
18
- "LigerDyTFunction",
19
- "LigerGELUMulFunction",
20
- "LigerGroupNormFunction",
21
- "LigerKLDivLossFunction",
22
- "LigerLayerNormFunction",
23
- "LigerQwen2VLMRopeFunction",
24
- "LigerRMSNormFunction",
25
- "LigerRMSNorm",
26
- "LigerJSDFunction",
27
- "LigerRopeFunction",
28
- "LigerSiLUMulFunction",
29
- "LigerTVDLossFunction",
30
- ]
 
1
+ from . import layers
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ __all__ = ["layers"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch-ext/liger_kernels/layers.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .rms_norm import LigerRMSNormFunction
3
+
4
+ class LigerRMSNorm(torch.nn.Module):
5
+ """
6
+ RMSNorm module that uses the optimized LigerRMSNormFunction.
7
+
8
+ Args:
9
+ hidden_size (int): The size of the hidden dimension.
10
+ eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
11
+ offset (float, optional): Offset value to shift the weight tensor. Defaults to 0.0.
12
+ casting_mode (str, optional): The casting mode to use. Defaults to "llama".
13
+ in_place (bool, optional): Whether to modify dY in-place to store dX during backward. Defaults to True.
14
+ """
15
+
16
+
17
+ weight: torch.Tensor
18
+ variance_epsilon: float
19
+ offset: float = 0
20
+ casting_mode: str = "llama"
21
+ in_place: bool = True
22
+
23
+ def forward(self, hidden_states):
24
+ """
25
+ Apply RMS normalization to the input tensor.
26
+
27
+ Args:
28
+ hidden_states (torch.Tensor): Input tensor of shape (B, T, H) or (BxT, H)
29
+
30
+ Returns:
31
+ torch.Tensor: Normalized tensor of the same shape as input
32
+ """
33
+ return LigerRMSNormFunction.apply(
34
+ hidden_states,
35
+ self.weight,
36
+ self.variance_epsilon,
37
+ self.offset,
38
+ self.casting_mode,
39
+ self.in_place
40
+ )
41
+
42
+ __all__ = ["LigerRMSNorm"]
torch-ext/liger_kernels/rms_norm.py CHANGED
@@ -362,43 +362,4 @@ class LigerRMSNormFunction(torch.autograd.Function):
362
  ctx.num_warps,
363
  ctx.in_place,
364
  )
365
- return dX, dW, None, None, None, None
366
-
367
-
368
- class LigerRMSNorm(torch.nn.Module):
369
- """
370
- RMSNorm module that uses the optimized LigerRMSNormFunction.
371
-
372
- Args:
373
- hidden_size (int): The size of the hidden dimension.
374
- eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
375
- offset (float, optional): Offset value to shift the weight tensor. Defaults to 0.0.
376
- casting_mode (str, optional): The casting mode to use. Defaults to "llama".
377
- in_place (bool, optional): Whether to modify dY in-place to store dX during backward. Defaults to True.
378
- """
379
-
380
-
381
- weight: torch.Tensor
382
- variance_epsilon: float
383
- offset: float = 0
384
- casting_mode: str = "llama"
385
- in_place: bool = True
386
-
387
- def forward(self, hidden_states):
388
- """
389
- Apply RMS normalization to the input tensor.
390
-
391
- Args:
392
- hidden_states (torch.Tensor): Input tensor of shape (B, T, H) or (BxT, H)
393
-
394
- Returns:
395
- torch.Tensor: Normalized tensor of the same shape as input
396
- """
397
- return LigerRMSNormFunction.apply(
398
- hidden_states,
399
- self.weight,
400
- self.variance_epsilon,
401
- self.offset,
402
- self.casting_mode,
403
- self.in_place
404
- )
 
362
  ctx.num_warps,
363
  ctx.in_place,
364
  )
365
+ return dX, dW, None, None, None, None