import torch
from torch import nn

import triton
import triton.language as tl

from transformers.models.llama.modeling_llama import LlamaRMSNorm


@triton.jit
def rms_norm_fwd_fused(
    X,  # pointer to the input rows
    Y,  # pointer to the output rows
    W,  # pointer to the weight vector
    stride,  # how much to advance a pointer to move one row
    N,  # number of columns (the feature dimension)
    eps,  # epsilon added to the variance for numerical stability
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance normalizes one row of X into Y.
    row = tl.program_id(0)
    Y += row * stride
    X += row * stride

    # First pass: accumulate the mean of squares in float32.
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        # Out-of-bounds lanes load 0 and therefore do not affect the sum.
        x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        _var += x * x
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    # Second pass: normalize and scale by the weights.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(W + cols, mask=mask)
        x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
        x_hat = x * rstd
        y = x_hat * w
        tl.store(Y + cols, y, mask=mask)
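
# For reference, the kernel above computes the same thing as this unfused
# PyTorch sketch (mirroring LlamaRMSNorm), with the reduction done in float32:
#
#     variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
#     x_hat = x.to(torch.float32) * torch.rsqrt(variance + eps)
#     y = weight * x_hat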


class TritonLlamaRMSNorm(nn.Module):
    def __init__(self, weight, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = weight
        self.variance_epsilon = eps

    def forward(self, x):
        y = torch.empty_like(x)
        # Collapse all leading dimensions into rows; the kernel normalizes
        # one row per program instance.
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape

        # Each row must fit into a single block of at most 64KB.
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

        # Heuristic: scale the warp count with the block size, capped at 8.
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)

        # Launch one program per row.
        rms_norm_fwd_fused[(M,)](x_arg, y, self.weight,
                                 x_arg.stride(0), N, self.variance_epsilon,
                                 BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
        return y
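
# Drop-in use, as a sketch; the hidden size, dtype, and tolerances below are
# illustrative and assume a CUDA device:
#
#     ref = LlamaRMSNorm(4096, eps=1e-6).half().cuda()
#     fused = TritonLlamaRMSNorm(ref.weight, ref.variance_epsilon)
#     x = torch.randn(2, 16, 4096, dtype=torch.float16, device="cuda")
#     torch.testing.assert_close(fused(x), ref(x), rtol=1e-3, atol=1e-3)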


def make_quant_norm(model):
    """
    Replace all LlamaRMSNorm modules with TritonLlamaRMSNorm modules
    """
    for name, m in model.named_modules():
        if not isinstance(m, LlamaRMSNorm):
            continue

        norm = TritonLlamaRMSNorm(m.weight, m.variance_epsilon)

        # Re-attach the replacement under the same name on the parent module.
        if '.' in name:
            parent_name = name.rsplit('.', 1)[0]
            child_name = name[len(parent_name) + 1:]
            parent = model.get_submodule(parent_name)
        else:
            parent = model
            child_name = name

        setattr(parent, child_name, norm)
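

# Example usage, as a sketch: load a Llama checkpoint and swap in the fused
# norm. The model name is illustrative; any LlamaForCausalLM checkpoint works.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16
    ).cuda()
    make_quant_norm(model)  # every LlamaRMSNorm is now a TritonLlamaRMSNorm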