diff --git "a/rwkv/model.py" "b/rwkv/model.py" new file mode 100644--- /dev/null +++ "b/rwkv/model.py" @@ -0,0 +1,3049 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +from typing import Optional +import types, gc, os, time, re, math +import torch +import torch.nn as nn +from torch.nn import functional as F + +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True +current_path = os.path.dirname(os.path.abspath(__file__)) + +######################################################################################################## + +if os.environ.get("RWKV_JIT_ON") != "0": + os.environ["RWKV_JIT_ON"] = "1" + MyModule = torch.jit.ScriptModule + MyFunction = torch.jit.script_method + MyStatic = torch.jit.script +else: + MyModule = torch.nn.Module + + def __nop(ob): + return ob + + MyFunction = __nop + MyStatic = __nop + +if os.environ.get("RWKV_CUDA_ON") == "1": + from torch.utils.cpp_extension import load + + try: + load( + name=f"wkv_cuda", + sources=[ + f"{current_path}/cuda/wrapper.cpp", + f"{current_path}/cuda/operators.cu", + f"{current_path}/cuda/gemm_fp16_cublas.cpp", + ], + verbose=True, + extra_ldflags=["cublas.lib" if os.name == "nt" else ""], + extra_cuda_cflags=[ + "--use_fast_math", + "-O3", + "--extra-device-vectorization", + ], + is_python_module=False, + ) + DISABLE_CUBLAS_GEMM = False + except: + print( + "Failed to build cuBLAS matmul, falling back to torch.matmul. Small model with fp16 will overflow." + ) + load( + name=f"wkv_cuda", + sources=[ + f"{current_path}/cuda/wrapper.cpp", + f"{current_path}/cuda/operators.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "--use_fast_math", + "-O3", + "--extra-device-vectorization", + ], + extra_cflags=["-DDISABLE_CUBLAS_GEMM"], + is_python_module=False, + ) + DISABLE_CUBLAS_GEMM = True + + @MyStatic + def cuda_wkv(T: int, C: int, w, u, k, v, aa, bb, pp): + assert 1 * C % min(C, 32) == 0 + assert ( + k.dtype == v.dtype == torch.float16 or k.dtype == v.dtype == torch.float32 + ) + assert w.dtype == u.dtype == aa.dtype == bb.dtype == pp.dtype == torch.float32 + w = w.contiguous() + u = u.contiguous() + k = k.contiguous() + v = v.contiguous() + y = torch.empty( + (T, C), + device=w.device, + memory_format=torch.contiguous_format, + dtype=k.dtype, + ) + torch.ops.rwkv.wkv_forward(1, T, C, w, u, k, v, y, aa, bb, pp) + return y, aa, bb, pp + + @MyStatic + def cuda_mm8_seq(B: int, N: int, M: int, x, w, mx, rx, my, ry): + assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype + assert x.dtype == torch.float32 or x.dtype == torch.float16 + assert w.dtype == torch.uint8 + assert x.shape == (B, N) + assert w.shape == (N, M) + assert rx.shape == mx.shape == (M,) + assert ry.shape == my.shape == (N, 1) + y = torch.empty((B, M), device=w.device, dtype=x.dtype) + torch.ops.rwkv.mm8_seq(B, N, M, x, w, mx, rx, my, ry, y) + return y + + @MyStatic + def cuda_mm8_one(N: int, M: int, x, w, mx, rx, my, ry): + assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype + assert x.dtype == torch.float32 or x.dtype == torch.float16 + assert w.dtype == torch.uint8 + assert x.shape == (N,) + assert w.shape == (N, M) + assert rx.shape == mx.shape == (M,) + assert ry.shape == my.shape == (N, 1) + y = torch.zeros((M,), device=w.device, 
dtype=torch.float32) + torch.ops.rwkv.mm8_one(N, M, x, w, mx, rx, my, ry, y) + return y.to(dtype=x.dtype) + +else: + os.environ["RWKV_CUDA_ON"] = "0" + + +@MyStatic +def torch_mm8_seq(x, w, mx, rx, my, ry): + return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) + + +@MyStatic +def torch_mm8_one(x, w, mx, rx, my, ry): + return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) + + +if os.environ.get("RWKV_CUDA_ON") == "1": + + @MyStatic + def mm8_seq(x, w, mx, rx, my, ry): + if w.device.type == "cuda" and x.dtype == torch.float16: + B, N, M = x.shape[0], w.shape[0], w.shape[1] + return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry) + else: + return torch_mm8_seq(x, w, mx, rx, my, ry) + + @MyStatic + def mm8_one(x, w, mx, rx, my, ry): + if w.device.type == "cuda": + N, M = w.shape[0], w.shape[1] + return cuda_mm8_one(N, M, x, w, mx, rx, my, ry) + else: + return torch_mm8_one(x, w, mx, rx, my, ry) + +else: + + @MyStatic + def mm8_seq(x, w, mx, rx, my, ry): + return torch_mm8_seq(x, w, mx, rx, my, ry) + + @MyStatic + def mm8_one(x, w, mx, rx, my, ry): + return torch_mm8_one(x, w, mx, rx, my, ry) + + +def mm8( + x: torch.Tensor, + w: torch.Tensor, + mx: torch.Tensor, + rx: torch.Tensor, + my: torch.Tensor, + ry: torch.Tensor, +): + if len(x.shape) == 1: + return mm8_one(x, w, mx, rx, my, ry) + return mm8_seq(x, w, mx, rx, my, ry) + + +def matmul( + a, + b, + mx: Optional[torch.Tensor] = None, + rx: Optional[torch.Tensor] = None, + my: Optional[torch.Tensor] = None, + ry: Optional[torch.Tensor] = None, + output_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + if output_dtype is None: + output_dtype = a.dtype + if b.dtype in [torch.float16, torch.bfloat16, torch.float32]: + assert a.dtype == b.dtype + return matmul_float(a, b, output_dtype=output_dtype) + elif b.dtype == torch.uint8: + assert mx is not None + assert rx is not None + assert my is not None + assert ry is not None + return mm8(a, b, mx, rx, my, ry).to(output_dtype) + else: + raise ValueError("Unsupported dtype") + + +if os.environ.get("RWKV_CUDA_ON") == "1" and not DISABLE_CUBLAS_GEMM: + + def matmul_float(a, b, output_dtype: Optional[torch.dtype] = None): + if output_dtype is None: + output_dtype = a.dtype + if a.dtype == b.dtype == torch.float16 and a.device.type == "cuda": + if len(a.shape) == 1: + assert len(b.shape) == 2 + c = torch.empty((b.shape[-1],), dtype=output_dtype, device=a.device) + a = a.unsqueeze(0) + else: + assert len(a.shape) == len(b.shape) + assert len(a.shape) == 2 or len(a.shape) == 3 + # torch.empty((*a.shape[:-1], b.shape[-1])) doesn't work with jit + if len(a.shape) == 2: + c = torch.empty( + (a.shape[0], b.shape[-1]), dtype=output_dtype, device=a.device + ) + else: + c = torch.empty( + (a.shape[0], a.shape[1], b.shape[-1]), + dtype=output_dtype, + device=a.device, + ) + torch.ops.rwkv.gemm_fp16_cublas(a, b, c) + return c + else: + return (a @ b).to(output_dtype) + +else: + + def matmul_float(a, b, output_dtype: Optional[torch.dtype] = None): + return (a @ b).to(output_dtype) + + +if os.environ.get("RWKV_DML_ON") == "1": + import torch_directml + + print("PyTorch with DirectML Enabled") + +if os.environ.get("RWKV_V7_ON") == "1": + + print(f'\n### RWKV-7 "Goose" enabled ###\n') + + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + # torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True + # torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True + 
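+    # Note: with RWKV_V7_ON=1 the RWKV_x070 class defined in this branch replaces the
+    # generic RWKV class at the very end of this file. TorchScript is always enabled
+    # here, HEAD_SIZE is fixed at 64, and the strategy is limited to a single
+    # "cuda/cpu fp16/fp32/bf16" device+dtype pair (no i8 quantization, no layer
+    # streaming on this path).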
torch._C._jit_set_autocast_mode(False) + + MyModule = torch.jit.ScriptModule + MyFunction = torch.jit.script_method + MyStatic = torch.jit.script + from typing import List + + DTYPE = None + DEVICE = None + HEAD_SIZE = 64 + + if os.environ.get("RWKV_CUDA_ON") == "1": + from torch.utils.cpp_extension import load + + load( + name="wkv7s", + sources=[ + f"{current_path}/cuda/rwkv7_op.cpp", + f"{current_path}/cuda/rwkv7.cu", + ], + is_python_module=False, + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-D_N_={HEAD_SIZE}", + ], + ) + + class WKV_7(torch.autograd.Function): + @staticmethod + def forward(ctx, state, r, w, k, v, a, b): + with torch.no_grad(): + T, C = r.size() + H = C // HEAD_SIZE + N = HEAD_SIZE + assert HEAD_SIZE == C // H + assert all(x.dtype == DTYPE for x in [r, w, k, v, a, b]) + assert all(x.is_contiguous() for x in [r, w, k, v, a, b]) + y = torch.empty( + (T, C), + device=DEVICE, + dtype=r.dtype, + requires_grad=False, + memory_format=torch.contiguous_format, + ) + + if DTYPE == torch.float16: + torch.ops.wkv7s.forward_fp16( + 1, T, C, H, state, r, w, k, v, a, b, y + ) + elif DTYPE == torch.bfloat16: + torch.ops.wkv7s.forward_bf16( + 1, T, C, H, state, r, w, k, v, a, b, y + ) + elif DTYPE == torch.float32: + torch.ops.wkv7s.forward_fp32( + 1, T, C, H, state, r, w, k, v, a, b, y + ) + + return y + + def RWKV7_OP(state, r, w, k, v, a, b): + return WKV_7.apply(state, r, w, k, v, a, b) + + ######################################################################################################## + + class RWKV_x070(MyModule): + def __init__(self, model, strategy): + global DTYPE, DEVICE + super().__init__() + self.eval() + args = types.SimpleNamespace() + self.args = args + args.MODEL_NAME = model + + print(f"Loading {model} ({strategy})\n") + + ss = strategy.split(" ") + DEVICE = ss[0] + if ss[1] == "fp16": + DTYPE = torch.half + elif ss[1] == "fp32": + DTYPE = torch.float32 + elif ss[1] == "bf16": + DTYPE = torch.bfloat16 + else: + assert ( + False + ), "currently rwkv7 strategy must be: cuda/cpu fp16/fp32/bf16" + + self.z = torch.load(args.MODEL_NAME + ".pth", map_location=DEVICE) + z = self.z + + self.n_head, self.head_size = z["blocks.0.att.r_k"].shape + args.head_size = self.head_size + args.vocab_size, args.n_embd = z["emb.weight"].shape + + args.n_layer = 0 + keys = list(z.keys()) + for k in keys: + layer_id = int(k.split(".")[1]) if ("blocks." 
in k) else 0 + args.n_layer = max(args.n_layer, layer_id + 1) + if ( + "key.weight" in k + or "value.weight" in k + or "receptance.weight" in k + or "output.weight" in k + or "head.weight" in k + ): + z[k] = z[k].t() + z[k] = z[k].squeeze().to(dtype=DTYPE) + if k.endswith("att.r_k"): + z[k] = z[k].flatten() + self.n_embd = args.n_embd + self.n_layer = args.n_layer + + z["emb.weight"] = F.layer_norm( + z["emb.weight"], + (args.n_embd,), + weight=z["blocks.0.ln0.weight"], + bias=z["blocks.0.ln0.bias"], + ) + torch.cuda.empty_cache() + z["blocks.0.att.v0"] = z["blocks.0.att.a0"] # actually ignored + z["blocks.0.att.v1"] = z["blocks.0.att.a1"] # actually ignored + z["blocks.0.att.v2"] = z["blocks.0.att.a2"] # actually ignored + + def forward(self, idx, state, full_output=False): + if state == None: + state = [None for _ in range(self.args.n_layer * 3)] + for i in range( + self.args.n_layer + ): # state: 0=att_x_prev 1=att_kv 2=ffn_x_prev + state[i * 3 + 0] = torch.zeros( + self.args.n_embd, + dtype=DTYPE, + requires_grad=False, + device=DEVICE, + ) + state[i * 3 + 1] = torch.zeros( + ( + self.args.n_embd // self.args.head_size, + self.args.head_size, + self.args.head_size, + ), + dtype=torch.float, + requires_grad=False, + device=DEVICE, + ) + state[i * 3 + 2] = torch.zeros( + self.args.n_embd, + dtype=DTYPE, + requires_grad=False, + device=DEVICE, + ) + + if type(idx) is list: + if len(idx) > 1: + return self.forward_seq(idx, state, full_output) + else: + return self.forward_one(idx[0], state) + else: + return self.forward_one(idx, state) + + @MyFunction + def forward_one(self, idx: int, state: List[torch.Tensor]): + with torch.no_grad(): + z = self.z + x = z["emb.weight"][idx] + + v_first = torch.empty_like(x) + for i in range(self.n_layer): + bbb = f"blocks.{i}." + att = f"blocks.{i}.att." + ffn = f"blocks.{i}.ffn." + + xx = F.layer_norm( + x, + (self.n_embd,), + weight=z[bbb + "ln1.weight"], + bias=z[bbb + "ln1.bias"], + ) + + xx, state[i * 3 + 0], state[i * 3 + 1], v_first = ( + RWKV_x070_TMix_one( + i, + self.n_head, + self.head_size, + xx, + state[i * 3 + 0], + v_first, + state[i * 3 + 1], + z[att + "x_r"], + z[att + "x_w"], + z[att + "x_k"], + z[att + "x_v"], + z[att + "x_a"], + z[att + "x_g"], + z[att + "w0"], + z[att + "w1"], + z[att + "w2"], + z[att + "a0"], + z[att + "a1"], + z[att + "a2"], + z[att + "v0"], + z[att + "v1"], + z[att + "v2"], + z[att + "g1"], + z[att + "g2"], + z[att + "k_k"], + z[att + "k_a"], + z[att + "r_k"], + z[att + "receptance.weight"], + z[att + "key.weight"], + z[att + "value.weight"], + z[att + "output.weight"], + z[att + "ln_x.weight"], + z[att + "ln_x.bias"], + ) + ) + x = x + xx + + xx = F.layer_norm( + x, + (self.n_embd,), + weight=z[bbb + "ln2.weight"], + bias=z[bbb + "ln2.bias"], + ) + + xx, state[i * 3 + 2] = RWKV_x070_CMix_one( + xx, + state[i * 3 + 2], + z[ffn + "x_k"], + z[ffn + "key.weight"], + z[ffn + "value.weight"], + ) + x = x + xx + + # if math.isnan(torch.min(x).item()): print(idx, i) + + x = F.layer_norm( + x, (self.n_embd,), weight=z["ln_out.weight"], bias=z["ln_out.bias"] + ) + x = x @ z["head.weight"] + return x, state + + @MyFunction + def forward_seq( + self, idx: List[int], state: List[torch.Tensor], full_output: bool = False + ): + with torch.no_grad(): + z = self.z + x = z["emb.weight"][idx] + + v_first = torch.empty_like(x) + for i in range(self.n_layer): + bbb = f"blocks.{i}." + att = f"blocks.{i}.att." + ffn = f"blocks.{i}.ffn." 
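+                # Per layer: x = x + TMix(LN1(x)), then x = x + CMix(LN2(x)). The
+                # RWKV_x070_TMix_* kernels below keep one NxN state matrix S per head
+                # and, for each token, roughly compute (with kk = normalize(k * k_k)
+                # and w the per-channel decay):
+                #     S <- S * w  +  S @ ((-kk)^T (kk * a))  +  v^T k
+                #     y  = S @ r^T
+                # followed by a per-head group_norm, the (r*k*r_k)-weighted v bonus,
+                # the gate g and the output projection.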
+ + xx = F.layer_norm( + x, + (self.n_embd,), + weight=z[bbb + "ln1.weight"], + bias=z[bbb + "ln1.bias"], + ) + + xx, state[i * 3 + 0], state[i * 3 + 1], v_first = ( + RWKV_x070_TMix_seq( + i, + self.n_head, + self.head_size, + xx, + state[i * 3 + 0], + v_first, + state[i * 3 + 1], + z[att + "x_r"], + z[att + "x_w"], + z[att + "x_k"], + z[att + "x_v"], + z[att + "x_a"], + z[att + "x_g"], + z[att + "w0"], + z[att + "w1"], + z[att + "w2"], + z[att + "a0"], + z[att + "a1"], + z[att + "a2"], + z[att + "v0"], + z[att + "v1"], + z[att + "v2"], + z[att + "g1"], + z[att + "g2"], + z[att + "k_k"], + z[att + "k_a"], + z[att + "r_k"], + z[att + "receptance.weight"], + z[att + "key.weight"], + z[att + "value.weight"], + z[att + "output.weight"], + z[att + "ln_x.weight"], + z[att + "ln_x.bias"], + ) + ) + x = x + xx + + xx = F.layer_norm( + x, + (self.n_embd,), + weight=z[bbb + "ln2.weight"], + bias=z[bbb + "ln2.bias"], + ) + + xx, state[i * 3 + 2] = RWKV_x070_CMix_seq( + xx, + state[i * 3 + 2], + z[ffn + "x_k"], + z[ffn + "key.weight"], + z[ffn + "value.weight"], + ) + x = x + xx + + if not full_output: + x = x[-1, :] + x = F.layer_norm( + x, (self.n_embd,), weight=z["ln_out.weight"], bias=z["ln_out.bias"] + ) + x = x @ z["head.weight"] + return x, state + + ######################################################################################################## + + @MyStatic + def RWKV_x070_TMix_one( + layer_id: int, + H: int, + N: int, + x, + x_prev, + v_first, + state, + x_r, + x_w, + x_k, + x_v, + x_a, + x_g, + w0, + w1, + w2, + a0, + a1, + a2, + v0, + v1, + v2, + g1, + g2, + k_k, + k_a, + r_k, + R_, + K_, + V_, + O_, + ln_w, + ln_b, + ): + xx = x_prev - x + xr, xw, xk, xv, xa, xg = ( + x + xx * x_r, + x + xx * x_w, + x + xx * x_k, + x + xx * x_v, + x + xx * x_a, + x + xx * x_g, + ) + + r = xr @ R_ + w = torch.tanh(xw @ w1) @ w2 + k = xk @ K_ + v = xv @ V_ + a = torch.sigmoid(a0 + (xa @ a1) @ a2) + g = torch.sigmoid(xg @ g1) @ g2 + + kk = torch.nn.functional.normalize((k * k_k).view(H, N), dim=-1, p=2.0).view( + H * N + ) + k = k * (1 + (a - 1) * k_a) + if layer_id == 0: + v_first = v + else: + v = v + (v_first - v) * torch.sigmoid(v0 + (xv @ v1) @ v2) + w = torch.exp( + -0.606531 * torch.sigmoid((w0 + w).float()) + ) # 0.606531 = exp(-0.5) + + vk = v.view(H, N, 1) @ k.view(H, 1, N) + ab = (-kk).view(H, N, 1) @ (kk * a).view(H, 1, N) + state = state * w.view(H, 1, N) + state @ ab.float() + vk.float() + xx = state.to(dtype=x.dtype) @ r.view(H, N, 1) + + xx = torch.nn.functional.group_norm( + xx.view(1, H * N), num_groups=H, weight=ln_w, bias=ln_b, eps=64e-5 + ).view(H * N) + xx = xx + ( + (r * k * r_k).view(H, N).sum(dim=-1, keepdim=True) * v.view(H, N) + ).view(H * N) + return (xx * g) @ O_, x, state, v_first + + if os.environ.get("RWKV_CUDA_ON") == "1": + + @MyStatic + def RWKV_x070_TMix_seq( + layer_id: int, + H: int, + N: int, + x, + x_prev, + v_first, + state, + x_r, + x_w, + x_k, + x_v, + x_a, + x_g, + w0, + w1, + w2, + a0, + a1, + a2, + v0, + v1, + v2, + g1, + g2, + k_k, + k_a, + r_k, + R_, + K_, + V_, + O_, + ln_w, + ln_b, + ): + T = x.shape[0] + xx = torch.cat((x_prev.unsqueeze(0), x[:-1, :])) - x + xr, xw, xk, xv, xa, xg = ( + x + xx * x_r, + x + xx * x_w, + x + xx * x_k, + x + xx * x_v, + x + xx * x_a, + x + xx * x_g, + ) + + r = xr @ R_ + w = torch.tanh(xw @ w1) @ w2 + k = xk @ K_ + v = xv @ V_ + a = torch.sigmoid(a0 + (xa @ a1) @ a2) + g = torch.sigmoid(xg @ g1) @ g2 + + kk = torch.nn.functional.normalize( + (k * k_k).view(T, H, N), dim=-1, p=2.0 + ).view(T, H * N) + k = k * (1 + (a - 
1) * k_a) + if layer_id == 0: + v_first = v + else: + v = v + (v_first - v) * torch.sigmoid(v0 + (xv @ v1) @ v2) + + w = -torch.nn.functional.softplus(-(w0 + w)) - 0.5 + xx = RWKV7_OP(state, r, w, k, v, -kk, kk * a) + + xx = torch.nn.functional.group_norm( + xx.view(T, H * N), num_groups=H, weight=ln_w, bias=ln_b, eps=64e-5 + ).view(T, H * N) + xx = xx + ( + (r * k * r_k).view(T, H, N).sum(dim=-1, keepdim=True) * v.view(T, H, N) + ).view(T, H * N) + return (xx * g) @ O_, x[-1, :], state, v_first + + else: + + @MyStatic + def RWKV_x070_TMix_seq( + layer_id: int, + H: int, + N: int, + x, + x_prev, + v_first, + state, + x_r, + x_w, + x_k, + x_v, + x_a, + x_g, + w0, + w1, + w2, + a0, + a1, + a2, + v0, + v1, + v2, + g1, + g2, + k_k, + k_a, + r_k, + R_, + K_, + V_, + O_, + ln_w, + ln_b, + ): + T = x.shape[0] + xx = torch.cat((x_prev.unsqueeze(0), x[:-1, :])) - x + xr, xw, xk, xv, xa, xg = ( + x + xx * x_r, + x + xx * x_w, + x + xx * x_k, + x + xx * x_v, + x + xx * x_a, + x + xx * x_g, + ) + + r = xr @ R_ + w = torch.tanh(xw @ w1) @ w2 + k = xk @ K_ + v = xv @ V_ + a = torch.sigmoid(a0 + (xa @ a1) @ a2) + g = torch.sigmoid(xg @ g1) @ g2 + + kk = torch.nn.functional.normalize( + (k * k_k).view(T, H, N), dim=-1, p=2.0 + ).view(T, H * N) + k = k * (1 + (a - 1) * k_a) + if layer_id == 0: + v_first = v + else: + v = v + (v_first - v) * torch.sigmoid(v0 + (xv @ v1) @ v2) + + w = torch.exp( + -0.606531 * torch.sigmoid((w0 + w).float()) + ) # 0.606531 = exp(-0.5) + for t in range(T): + r_, w_, k_, v_, kk_, a_ = r[t], w[t], k[t], v[t], kk[t], a[t] + vk = v_.view(H, N, 1) @ k_.view(H, 1, N) + ab = (-kk_).view(H, N, 1) @ (kk_ * a_).view(H, 1, N) + state = state * w_.view(H, 1, N) + state @ ab.float() + vk.float() + xx[t] = (state.to(dtype=x.dtype) @ r_.view(H, N, 1)).view(H * N) + + xx = torch.nn.functional.group_norm( + xx.view(T, H * N), num_groups=H, weight=ln_w, bias=ln_b, eps=64e-5 + ).view(T, H * N) + xx = xx + ( + (r * k * r_k).view(T, H, N).sum(dim=-1, keepdim=True) * v.view(T, H, N) + ).view(T, H * N) + return (xx * g) @ O_, x[-1, :], state, v_first + + ######################################################################################################## + + @MyStatic + def RWKV_x070_CMix_one(x, x_prev, x_k, K_, V_): + xx = x_prev - x + k = x + xx * x_k + k = torch.relu(k @ K_) ** 2 + return k @ V_, x + + @MyStatic + def RWKV_x070_CMix_seq(x, x_prev, x_k, K_, V_): + xx = torch.cat((x_prev.unsqueeze(0), x[:-1, :])) - x + k = x + xx * x_k + k = torch.relu(k @ K_) ** 2 + return k @ V_, x[-1, :] + + +######################################################################################################## + + +class RWKV(MyModule): + def __init__(self, model, strategy, verbose=True, convert_and_save_and_exit=None): + super().__init__() + if verbose: + prxxx = lambda *args, **kwargs: print(*args, **kwargs) + else: + prxxx = lambda *args, **kwargs: None + + STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$" + if not re.match(STRATEGY_REGEX, strategy): + raise ValueError( + "Invalid strategy. Please read https://pypi.org/project/rwkv/" + ) + + strategy = ("->".join([x.strip() for x in strategy.split("->")])).replace( + "->", " -> " + ) + self.args = types.SimpleNamespace() + args = self.args + args.MODEL_NAME = model + args.strategy_string = strategy + + # Rescale for fp16 mode: set x = x/2 every X layer (to avoid fp16 overflow) + try: + self.RESCALE_LAYER = int( + os.environ["RWKV_RESCALE_LAYER"] + ) # !!! 
NOTE: SEEMS YOU SHOULD SET IT TO 999 (disable) FOR RWKV-MUSIC MODELS !!! + except: + self.RESCALE_LAYER = 6 if "fp16" in strategy else 0 + prxxx( + f'RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]} RWKV_CUDA_ON {os.environ["RWKV_CUDA_ON"]} RESCALE_LAYER {self.RESCALE_LAYER}\n' + ) + + args.MODEL_NAME = args.MODEL_NAME.strip() + if not args.MODEL_NAME.endswith(".pth"): + args.MODEL_NAME += ".pth" + prxxx(f"Loading {args.MODEL_NAME} ...") + with torch.no_grad(): + self.w = torch.load( + args.MODEL_NAME, map_location="cpu" + ) # load model to CPU first + gc.collect() + w = self.w + + ALREADY_CONVERTED = False + if "_strategy" in w: + ALREADY_CONVERTED = True + assert ( + convert_and_save_and_exit == None + ) # you should only convert a raw model + prxxx( + f"Converted model: strategy {w['_strategy']}, version {w['_version']}\n" + ) + assert ( + w["_strategy"] == args.strategy_string + ) # if you are using a new strategy, re-convert the model + assert ( + float(w["_version"]) >= 0.7 + ) # sometimes you should re-convert using latest convert_model.py + assert ( + w["_rescale_layer"] == self.RESCALE_LAYER + ) # must use same RESCALE_LAYER to avoid mistakes + del w["_strategy"] + del w["_version"] + del w["_rescale_layer"] + + args.n_embd = w["emb.weight"].shape[1] + args.n_att = w["blocks.0.att.key.weight"].shape[ + 0 + ] # note: transposed matrix + args.n_ffn = w["blocks.0.ffn.key.weight"].shape[ + 0 + ] # note: transposed matrix + args.n_layer = 0 + keys = list(w.keys()) + self.version = 4 + for x in keys: + layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0 + args.n_layer = max(args.n_layer, layer_id + 1) + if "ln_x" in x: + self.version = max(5, self.version) + if "gate.weight" in x: + self.version = max(5.1, self.version) + if int(self.version) == 5 and "att.time_decay" in x: + args.n_head = w[x].shape[0] + if len(w[x].shape) > 1: + if w[x].shape[1] > 1: + self.version = max(5.2, self.version) + if "time_maa" in x: + self.version = max(6, self.version) + if int(self.version) == 6 and "time_faaaa" in x: + args.n_head = w[x].shape[0] + prxxx(f"Model detected: v{self.version:.1f}") + + ####################### Compute strategy + + s = [x.strip().split(" ") for x in strategy.split("->")] + plan = [0] * len(s) + stream_i = -1 + stream_count = 0 + to_allocate = args.n_layer + 1 + allocated = 0 + free_slots = 0 + for i in range(len(s)): + si = s[i] + si1 = si[1] + if si1.startswith("fp32"): + si[1] = [torch.float] + elif si1.startswith("fp16"): + si[1] = [torch.float16] + elif si1.startswith("bf16"): + si[1] = [torch.bfloat16] + if si1.endswith("i8"): + si[1] += [torch.uint8] + else: + si[1] += [si[1][0]] + if len(si) > 2: + ss = si[2] + assert ss.startswith("*") + if ss.endswith("+"): + plan[i] = int(ss[1:-1]) + stream_i = i + else: + plan[i] = int(ss[1:]) + allocated += plan[i] + if allocated >= to_allocate: + plan[i] += to_allocate - allocated + break + else: + free_slots += 1 + if stream_i < 0: + if free_slots > 0 and to_allocate > allocated: + for i in range(len(s)): + if plan[i] == 0: + plan[i] = (to_allocate - allocated) // free_slots + allocated += plan[i] + free_slots -= 1 + if to_allocate > allocated: + plan[len(s) - 1] += to_allocate - allocated + else: + if to_allocate > allocated: + stream_count = to_allocate - allocated + plan[stream_i] += stream_count + prxxx(f"Strategy: (total {args.n_layer}+1={args.n_layer+1} layers)") + for i in range(len(s)): + ss = s[i] + if i != stream_i: + prxxx( + f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]} layers' + ) + else: + prxxx( + 
f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]-stream_count} layers, stream {stream_count} layers' + ) + plan[i] += 0 if i == 0 else plan[i - 1] + self.strategy = [None] * (args.n_layer + 1) + strategy = self.strategy + for n in range(args.n_layer + 1): + for i in range(len(s)): + if n < plan[i]: + strategy[n] = types.SimpleNamespace() + strategy[n].device = s[i][0] + strategy[n].atype = s[i][1][0] + strategy[n].wtype = s[i][1][1] + strategy[n].stream = False + if strategy[n].device == "dml": + strategy[n].device = torch_directml.device() + if i == stream_i and n >= (plan[i] - stream_count): + strategy[n].stream = True + break + prxxx( + f"{n}-{strategy[n].device}-{str(strategy[n].atype).replace('torch.','')}-{str(strategy[n].wtype).replace('torch.','')}{'-stream' if strategy[n].stream else ''}", + end=" ", + ) + prxxx() + + ####################### Load weights to self.w + + if not ALREADY_CONVERTED: + try: # precompute embedding + w["emb.weight"] = F.layer_norm( + w["emb.weight"], + (args.n_embd,), + weight=w["blocks.0.ln0.weight"], + bias=w["blocks.0.ln0.bias"], + ) + except: + w["emb.weight"] = F.layer_norm( + w["emb.weight"].float(), + (args.n_embd,), + weight=w["blocks.0.ln0.weight"].float(), + bias=w["blocks.0.ln0.bias"].float(), + ) + del w["blocks.0.ln0.weight"] + del w["blocks.0.ln0.bias"] + + print_need_newline = False + + REAL_TIME_FIRST = False + args.time_state = False + for x in list(w.keys()): + if ".time_faaaa" in x: + REAL_TIME_FIRST = True + if ".time_state" in x: + args.time_state = True + if REAL_TIME_FIRST: + w = { + ( + k.replace(".time_faaaa", ".time_first") + if ".time_faaaa" in k + else k + ): v + for k, v in w.items() + } + self.w = w + + keys = list(w.keys()) + for x in keys: + w[x].requires_grad = False + layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0 + if ("ln_out." in x) or ("head." 
in x): + layer_id = args.n_layer + dd = strategy[layer_id] + DEVICE = dd.device + ATYPE = dd.atype + WTYPE = dd.wtype + + if not ALREADY_CONVERTED: + if self.RESCALE_LAYER > 0: + if "att.output.weight" in x: + w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER)) + if "ffn.value.weight" in x: + w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER)) + + if ".time_" in x: + w[x] = w[x].squeeze() + if ( + "key.weight" in x + or "value.weight" in x + or "receptance.weight" in x + or "gate.weight" in x + or "output.weight" in x + or "head.weight" in x + ): + w[x] = w[x].t() + + if ".time_decay" in x and "_w" not in x: # need fp32 for this + if self.version == 4: + w[x] = -torch.exp(w[x].float()) + elif int(self.version) == 5: + w[x] = torch.exp(-torch.exp(w[x].float())).reshape(-1, 1, 1) + if self.version == 5.2: + w[x] = w[x].reshape(args.n_head, -1, 1) + elif self.version == 6.0: + w[x] = w[x].float().reshape(args.n_head, -1, 1) + elif ".time_first" in x: # need fp32 for this + if self.version == 4: + w[x] = w[x].float() + elif int(self.version) in [5, 6]: + if REAL_TIME_FIRST: + w[x] = w[x].float().reshape(-1, 1, 1) + else: + w[x] = torch.exp(w[x].float()).reshape(-1, 1, 1) + if self.version in [5.2, 6.0]: + w[x] = w[x].reshape(args.n_head, -1, 1) + elif ".ln_x" in x: # need fp32 for group_norm + w[x] = w[x].float() + else: + if ( + (len(w[x].shape) == 2) + and ("emb" not in x) + and ("_w1" not in x) + and ("_w2" not in x) + ): + if WTYPE != torch.uint8: + w[x] = w[x].to(dtype=WTYPE) + else: + w[x] = w[x].float() + + if w[x].shape[0] > w[x].shape[1]: + w[x + "_my"] = torch.amin(w[x], dim=1).unsqueeze(1) + w[x] = w[x] - w[x + "_my"] + w[x + "_mx"] = torch.amin(w[x], dim=0) + w[x] = w[x] - w[x + "_mx"] + w[x + "_rx"] = torch.amax(w[x], dim=0) + w[x] = w[x] / w[x + "_rx"] + w[x + "_ry"] = torch.amax(w[x], dim=1).unsqueeze(1) + w[x] = w[x] / w[x + "_ry"] + else: + w[x + "_mx"] = torch.amin(w[x], dim=0) + w[x] = w[x] - w[x + "_mx"] + w[x + "_my"] = torch.amin(w[x], dim=1).unsqueeze(1) + w[x] = w[x] - w[x + "_my"] + w[x + "_rx"] = torch.amax(w[x], dim=0) + w[x] = w[x] / w[x + "_rx"] + w[x + "_ry"] = torch.amax(w[x], dim=1).unsqueeze(1) + w[x] = w[x] / w[x + "_ry"] + + w[x] = torch.clip( + torch.floor(w[x] * 256), min=0, max=255 + ).to(dtype=torch.uint8) + w[x + "_mx"] = w[x + "_mx"].to(dtype=ATYPE).contiguous() + w[x + "_rx"] = ( + (w[x + "_rx"] / 16).to(dtype=ATYPE).contiguous() + ) + w[x + "_my"] = w[x + "_my"].to(dtype=ATYPE).contiguous() + w[x + "_ry"] = ( + (w[x + "_ry"] / 16).to(dtype=ATYPE).contiguous() + ) + else: + w[x] = w[x].to(dtype=ATYPE) + + if convert_and_save_and_exit == None: + if "emb." in x: + w[x] = w[x].contiguous() + elif (dd.stream) and ( + x.endswith("key.weight") + or x.endswith("value.weight") + or x.endswith("receptance.weight") + or x.endswith("output.weight") + ): + try: + w[x] = ( + w[x].contiguous().pin_memory() + ) # if you see "CUDA error: out of memory" here, that's out of CPU RAM, not VRAM. Get more RAM :) + except: + print( + "Note: You are running out of RAM. Get more CPU RAM. Now this will run much slower." 
+ ) + elif DEVICE != "cpu": + w[x] = w[x].to(device=DEVICE).contiguous() + + if (dd.stream) or (DEVICE != "cpu"): + try: + w[x + "_mx"] = w[x + "_mx"].to(device=DEVICE).contiguous() + w[x + "_rx"] = w[x + "_rx"].to(device=DEVICE).contiguous() + w[x + "_my"] = w[x + "_my"].to(device=DEVICE).contiguous() + w[x + "_ry"] = w[x + "_ry"].to(device=DEVICE).contiguous() + except: + pass + + if "ffn.value.weight" in x: + gc.collect() + if "cuda" in args.strategy_string: + torch.cuda.empty_cache() + + shape = [i for i in w[x].shape if i != 1] + if len(shape) > 2: + shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} {str(shape[2]).rjust(5)}" + elif len(shape) > 1: + shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} " + else: + shape = f" {str(shape[0]).rjust(5)} " + if layer_id == 0 or layer_id >= args.n_layer - 1: + if print_need_newline: + prxxx("\n", end="") + print_need_newline = False + dt = str(w[x].dtype).replace("torch.", "") + dt = ( + dt.replace("float32", "f32") + .replace("bfloat16", "bf16") + .replace("float16", "f16") + .replace("uint8", "i8") + ) + prxxx( + x.ljust(32), + dt.rjust(4), + str(w[x].device).rjust(8), + shape, + " (pinned)" if w[x].is_pinned() else "", + ) + else: + print_need_newline = True + prxxx(".", end="", flush=True) + + if convert_and_save_and_exit: + w["_strategy"] = args.strategy_string + w["_rescale_layer"] = self.RESCALE_LAYER + w["_version"] = "0.7" + if not convert_and_save_and_exit.endswith(".pth"): + convert_and_save_and_exit += ".pth" + prxxx(f"Saving to {convert_and_save_and_exit}...") + torch.save(w, convert_and_save_and_exit) + prxxx(f"Converted and saved. Now this will exit.") + exit(0) + + if self.version == 5.2 and os.environ["RWKV_CUDA_ON"] == "1": + HEAD_SIZE = args.n_att // args.n_head + rwkv5 = load( + name="rwkv5", + sources=[ + f"{current_path}/cuda/rwkv5_op.cpp", + f"{current_path}/cuda/rwkv5.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--use_fast_math", + "-O3", + "-Xptxas -O3" if os.name != "nt" else "", + "--extra-device-vectorization", + f"-D_N_={HEAD_SIZE}", + ], + ) + + class RWKV_5(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, H, state, r, k, v, w, u): + with torch.no_grad(): + assert HEAD_SIZE == C // H + ctx.B = B + ctx.T = T + ctx.C = C + ctx.H = H + assert state.dtype == torch.float32 + assert w.dtype == torch.float32 + assert r.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + assert w.is_contiguous() + assert u.is_contiguous() + assert state.is_contiguous() + + y = torch.empty( + (B, T, C), + device=w.device, + dtype=r.dtype, + memory_format=torch.contiguous_format, + ) + if r.dtype == torch.bfloat16: + rwkv5.forward_bf16(B, T, C, H, state, r, k, v, w, u, y) + elif r.dtype == torch.float16: + rwkv5.forward_fp16(B, T, C, H, state, r, k, v, w, u, y) + elif r.dtype == torch.float32: + rwkv5.forward_fp32(B, T, C, H, state, r, k, v, w, u, y) + return y, state + + self.RWKV_5 = RWKV_5 + + if self.version == 6.0 and os.environ["RWKV_CUDA_ON"] == "1": + HEAD_SIZE = args.n_att // args.n_head + rwkv6 = load( + name="rwkv6", + sources=[ + f"{current_path}/cuda/rwkv6_op.cpp", + f"{current_path}/cuda/rwkv6.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--use_fast_math", + "-O3", + "-Xptxas -O3" if os.name != "nt" else "", + "--extra-device-vectorization", + f"-D_N_={HEAD_SIZE}", + f"-D_T_={4096}", + ], + ) + + class RWKV_6(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, H, state, r, k, v, w, u): + with 
torch.no_grad(): + assert HEAD_SIZE == C // H + ctx.B = B + ctx.T = T + ctx.C = C + ctx.H = H + assert state.dtype == torch.float32 + assert w.dtype == torch.float32 + assert r.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + assert w.is_contiguous() + assert u.is_contiguous() + eew = torch.exp(-torch.exp(w.float())).contiguous() + + y = torch.empty( + (B, T, C), + device=w.device, + dtype=r.dtype, + memory_format=torch.contiguous_format, + ) + if r.dtype == torch.bfloat16: + rwkv6.forward_bf16( + B, T, C, H, state, r, k, v, eew, u, y + ) + elif r.dtype == torch.float16: + rwkv6.forward_fp16( + B, T, C, H, state, r, k, v, eew, u, y + ) + elif r.dtype == torch.float32: + rwkv6.forward_fp32( + B, T, C, H, state, r, k, v, eew, u, y + ) + return y, state + + self.RWKV_6 = RWKV_6 + + gc.collect() + if "cuda" in args.strategy_string: + torch.cuda.empty_cache() + + def RUN_RWKV_5(self, B, T, C, H, state, r, k, v, w, u): + return self.RWKV_5.apply(B, T, C, H, state, r, k, v, w, u) + + def RUN_RWKV_6(self, B, T, C, H, state, r, k, v, w, u): + return self.RWKV_6.apply(B, T, C, H, state, r, k, v, w, u) + + ######################################################################################################## + + @MyFunction + def ffn_one( + self, + x, + sx, + ln_w, + ln_b, + k_mix, + r_mix, + kw, + vw, + rw, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) + vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2 + out = r * matmul(vx, vw, vmx, vrx, vmy, vry) + return x + out, xx + + @MyFunction + def ffn_seq( + self, + x, + sx, + ln_w, + ln_b, + k_mix, + r_mix, + kw, + vw, + rw, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) + vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2 + out = r * matmul(vx, vw, vmx, vrx, vmy, vry) + return x + out, xx[-1, :] + + @MyFunction + def ffn_one_v6( + self, + x, + sx, + ln_w, + ln_b, + k_maa, + r_maa, + kw, + vw, + rw, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = sx - xx + kx = xx + sx * k_maa + rx = xx + sx * r_maa + + r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) + vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2 + out = r * matmul(vx, vw, vmx, vrx, vmy, vry) + return x + out, xx + + @MyFunction + def ffn_seq_v6( + self, + x, + sx, + ln_w, + ln_b, + k_maa, + r_maa, + kw, + vw, + rw, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + sx = sx - xx + kx = xx + sx * k_maa + rx = xx + sx * r_maa + + r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) + vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2 + out = r * matmul(vx, vw, vmx, vrx, vmy, vry) + return x + out, xx[-1, :] + + ######################################################################################################## + + @MyFunction + def att_one( + self, + x, + sx, + aa, + bb, 
+ pp, + ln_w, + ln_b, + k_mix, + v_mix, + r_mix, + t_decay, + t_first, + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) + + ww = t_first + k + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, k) + e1 = torch.exp(ww - p) + e2 = torch.exp(k - p) + + out = matmul(r * wkv, ow, omx, orx, omy, ory) + return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p + + @MyFunction + def att_seq( + self, + x, + sx, + aa, + bb, + pp, + ln_w, + ln_b, + k_mix, + v_mix, + r_mix, + t_decay, + t_first, + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) + + T = x.shape[0] + for t in range(T): + kk = k[t] + vv = v[t] + ww = t_first + kk + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, kk) + e1 = torch.exp(ww - p) + e2 = torch.exp(kk - p) + aa = e1 * aa + e2 * vv + bb = e1 * bb + e2 + pp = p + out = matmul(r * sx, ow, omx, orx, omy, ory) + return x + out, xx[-1, :], aa, bb, pp + + ######################################################################################################## + + @MyFunction + def att_one_v5( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + t_decay, + t_first, + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + H = t_decay.shape[0] + N = x.shape[-1] // H + + r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, N) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, N, 1) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, N) + + a = matmul(k, v) + out = r @ (t_first * a + s) + s = a + t_decay * s + + out = out.flatten() + out = F.group_norm( + out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5 + ).squeeze(0) + out = out.to(dtype=x.dtype) + out = matmul(out, ow, omx, orx, omy, ory) + + return x + out, xx, s + + @MyFunction + def att_seq_v5( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + t_decay, + t_first, + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ): + 
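+        # Sequence-mode v5 attention: the padding/tile trick below builds, per head, a
+        # T x T lower-triangular matrix of decay powers with time_first (u) on the
+        # diagonal, so the WKV of a whole chunk reduces to batched matmuls,
+        #     out = ((r @ k) * w) @ v + (r @ s) * wb,    s' = ws * s + (k * wk) @ v,
+        # where wb/wk hold per-position decay factors and ws is the decay raised to the
+        # T-th power, folding the incoming state s in and producing the outgoing state.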
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + H = t_decay.shape[0] + N = x.shape[-1] // H + T = x.shape[0] + + w = t_decay.reshape(-1, 1) + u = t_first.reshape(-1, 1) + ws = w.pow(T).reshape(H, 1, 1) + ind = torch.arange(T - 1, -1, -1, device=w.device).unsqueeze(0).repeat(H, 1) + w = w.repeat(1, T).pow(ind) + wk = w.reshape(H, 1, T) + wb = wk.transpose(-2, -1).flip(1) + w = torch.cat([w[:, 1:], u], dim=1) + w = F.pad(w, (0, T)) + w = torch.tile(w, [T]) + w = w[:, :-T].reshape(-1, T, 2 * T - 1) + w = w[:, :, T - 1 :].reshape(H, T, T) + + r = ( + matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) + .view(T, H, N) + .transpose(0, 1) + ) + k = ( + matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) + .view(T, H, N) + .permute(1, 2, 0) + ) + v = ( + matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) + .view(T, H, N) + .transpose(0, 1) + ) + + out = ((r @ k) * w) @ v + (r @ s) * wb + s = ws * s + (k * wk) @ v + + out = out.transpose(0, 1).contiguous().reshape(T, H * N) + out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5) + out = out.to(dtype=x.dtype) + out = matmul(out, ow, omx, orx, omy, ory) + + return x + out, xx[-1, :], s + + ######################################################################################################## + + @MyFunction + def att_one_v5_1( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + g_mix, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + gx = xx * g_mix + sx * (1 - g_mix) + + H = t_decay.shape[0] + N = x.shape[-1] // H + + r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, N) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, N, 1) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, N) + g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) + + a = matmul(k, v) + out = r @ (t_first * a + s) + s = a + t_decay * s + + out = out.flatten() + out = F.group_norm( + out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5 + ).squeeze(0) + out = out.to(dtype=x.dtype) * g + out = matmul(out, ow, omx, orx, omy, ory) + + return x + out, xx, s + + @MyFunction + def att_seq_v5_1( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + g_mix, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + gx = xx * g_mix + sx * (1 - g_mix) + + H = t_decay.shape[0] + N = x.shape[-1] // H + T = x.shape[0] + + w = t_decay.reshape(-1, 1) + u = t_first.reshape(-1, 1) + ws = w.pow(T).reshape(H, 1, 1) + ind = torch.arange(T - 1, -1, -1, device=w.device).unsqueeze(0).repeat(H, 1) + w = w.repeat(1, T).pow(ind) + wk = w.reshape(H, 1, T) + wb 
= wk.transpose(-2, -1).flip(1) + w = torch.cat([w[:, 1:], u], dim=1) + w = F.pad(w, (0, T)) + w = torch.tile(w, [T]) + w = w[:, :-T].reshape(-1, T, 2 * T - 1) + w = w[:, :, T - 1 :].reshape(H, T, T) + + r = ( + matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) + .view(T, H, N) + .transpose(0, 1) + ) + k = ( + matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) + .view(T, H, N) + .permute(1, 2, 0) + ) + v = ( + matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) + .view(T, H, N) + .transpose(0, 1) + ) + g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) + + out = ((r @ k) * w) @ v + (r @ s) * wb + s = ws * s + (k * wk) @ v + + out = out.transpose(0, 1).contiguous().reshape(T, H * N) + out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5) + out = out.to(dtype=x.dtype) * g + out = matmul(out, ow, omx, orx, omy, ory) + + return x + out, xx[-1, :], s + + ######################################################################################################## + + @MyFunction + def att_seq_v5_2( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + g_mix, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + gx = xx * g_mix + sx * (1 - g_mix) + + H = t_decay.shape[0] + N = x.shape[-1] // H + T = x.shape[0] + + r = ( + matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) + .view(T, H, N) + .transpose(0, 1) + ) + k = ( + matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) + .view(T, H, N) + .permute(1, 2, 0) + ) + v = ( + matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) + .view(T, H, N) + .transpose(0, 1) + ) + g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) + + out = torch.empty((T, H, N), dtype=r.dtype, device=r.device) + for t in range(T): + rt = r[:, t : t + 1, :] + kt = k[:, :, t : t + 1] + vt = v[:, t : t + 1, :] + at = matmul(kt, vt) + out[t] = (rt @ (t_first * at + s)).squeeze(1) + s = at + t_decay * s + + out = out.reshape(T, H * N) + out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5) + out = out.to(dtype=x.dtype) * g + out = matmul(out, ow, omx, orx, omy, ory) + + return x + out, xx[-1, :], s + + ######################################################################################################## + + @MyFunction + def att_one_v6_0( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + x_maa, + w_maa, + k_maa, + v_maa, + r_maa, + g_maa, + tm_w1, + tm_w2, + td_w1, + td_w2, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + + sx = sx - xx + xxx = xx + sx * x_maa + xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1) + xxx = torch.bmm(xxx, tm_w2).view(5, -1) + mw, mk, mv, mr, mg = xxx.unbind(dim=0) + + wx = xx + sx * (w_maa + mw) + kx = xx + sx * (k_maa + mk) + vx = xx + sx * (v_maa + mv) + rx = xx + sx * (r_maa + mr) + gx = xx + sx * (g_maa + mg) + + H = t_decay.shape[0] + N = x.shape[-1] // H + + r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, N) + k = 
matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, N, 1) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, N) + g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) + + w = t_decay + (torch.tanh(wx @ td_w1) @ td_w2).float().view(H, N, 1) + w = torch.exp(-torch.exp(w.float())) + + a = matmul(k, v) + out = r @ (t_first * a + s) + s = a + w * s + + out = out.flatten() + out = F.group_norm( + out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5 + ).squeeze(0) + out = out.to(dtype=x.dtype) * g + out = matmul(out, ow, omx, orx, omy, ory) + + return x + out, xx, s + + @MyFunction + def att_seq_v6_0( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + x_maa, + w_maa, + k_maa, + v_maa, + r_maa, + g_maa, + tm_w1, + tm_w2, + td_w1, + td_w2, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ): + H = t_decay.shape[0] + N = x.shape[-1] // H + T = x.shape[0] + + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) - xx + xxx = xx + sx * x_maa + xxx = torch.tanh(xxx @ tm_w1).view(T, 5, -1).transpose(0, 1) + xxx = torch.bmm(xxx, tm_w2).view(5, T, -1) + mw, mk, mv, mr, mg = xxx.unbind(dim=0) + + wx = xx + sx * (w_maa + mw) + kx = xx + sx * (k_maa + mk) + vx = xx + sx * (v_maa + mv) + rx = xx + sx * (r_maa + mr) + gx = xx + sx * (g_maa + mg) + + r = ( + matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) + .view(T, H, N) + .transpose(0, 1) + ) + k = ( + matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) + .view(T, H, N) + .permute(1, 2, 0) + ) + v = ( + matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) + .view(T, H, N) + .transpose(0, 1) + ) + g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) + + w = t_decay.view(1, H, N, 1) + (torch.tanh(wx @ td_w1) @ td_w2).float().view( + T, H, N, 1 + ) + w = torch.exp(-torch.exp(w.float())) + out = torch.empty((T, H, N), dtype=r.dtype, device=r.device) + for t in range(T): + rt = r[:, t : t + 1, :] + kt = k[:, :, t : t + 1] + vt = v[:, t : t + 1, :] + at = matmul(kt, vt) + out[t] = (rt @ (t_first * at + s)).squeeze(1) + s = at + w[t] * s + + out = out.reshape(T, H * N) + out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5) + out = out.to(dtype=x.dtype) * g + out = matmul(out, ow, omx, orx, omy, ory) + + return x + out, xx[-1, :], s + + ######################################################################################################## + + if os.environ["RWKV_CUDA_ON"] == "1": + + @MyFunction + def cuda_att_seq( + self, + x, + sx, + aa, + bb, + pp, + ln_w, + ln_b, + k_mix, + v_mix, + r_mix, + t_decay, + t_first, + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ): + T, C = x.shape + xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) + y, aa, bb, pp = cuda_wkv(T, C, t_decay, t_first, k, v, aa, bb, pp) + + out = matmul(r * y.to(x.dtype), ow, omx, orx, omy, ory) + return x + out, xx[-1, :], aa, bb, pp + + @MyFunction + def 
v5_2_before( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + g_mix, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + gx = xx * g_mix + sx * (1 - g_mix) + + r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) + g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) + + return r, k, v, g, xx[-1, :], s.transpose(-1, -2).contiguous() + + @MyFunction + def v5_2_after( + self, t_decay, out, s, x, xxx, g, lx_w, lx_b, ow, omx, orx, omy, ory + ): + H = t_decay.shape[0] + N = x.shape[-1] // H + T = x.shape[0] + + s = s.transpose(-1, -2) + out = out.reshape(T, H * N) + out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5) + out = out.to(dtype=x.dtype) * g + out = matmul(out, ow, omx, orx, omy, ory) + + return x + out, xxx, s + + def cuda_att_seq_v5_2( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + g_mix, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ): + H = t_decay.shape[0] + N = x.shape[-1] // H + T = x.shape[0] + + r, k, v, g, xxx, ss = self.v5_2_before( + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + g_mix, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ) + + out, s = self.RUN_RWKV_5( + 1, T, self.args.n_att, H, ss, r, k, v, w=t_decay, u=t_first + ) + + return self.v5_2_after( + t_decay, out, s, x, xxx, g, lx_w, lx_b, ow, omx, orx, omy, ory + ) + + @MyFunction + def v6_0_before( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + x_maa, + w_maa, + k_maa, + v_maa, + r_maa, + g_maa, + tm_w1, + tm_w2, + td_w1, + td_w2, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ): + H = t_decay.shape[0] + N = x.shape[-1] // H + T = x.shape[0] + + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) - xx + xxx = xx + sx * x_maa + xxx = torch.tanh(xxx @ tm_w1).view(T, 5, -1).transpose(0, 1) + xxx = torch.bmm(xxx, tm_w2).view(5, T, -1) + mw, mk, mv, mr, mg = xxx.unbind(dim=0) + + wx = xx + sx * (w_maa + mw) + kx = xx + sx * (k_maa + mk) + vx = xx + sx * (v_maa + mv) + rx = xx + sx * (r_maa + mr) + gx = xx + sx * (g_maa + mg) + + r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) + g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) + + w = t_decay.view(1, H, N, 1) + ( + torch.tanh(wx @ td_w1) @ td_w2 + ).float().view(T, H, N, 1) + + return r, k, v, g, w, xx[-1, :], s.transpose(-1, -2).contiguous() + + def cuda_att_seq_v6_0( + self, + x, + sx, + s, 
+ ln_w, + ln_b, + lx_w, + lx_b, + x_maa, + w_maa, + k_maa, + v_maa, + r_maa, + g_maa, + tm_w1, + tm_w2, + td_w1, + td_w2, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ): + H = t_decay.shape[0] + N = x.shape[-1] // H + T = x.shape[0] + + r, k, v, g, w, xxx, ss = self.v6_0_before( + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + x_maa, + w_maa, + k_maa, + v_maa, + r_maa, + g_maa, + tm_w1, + tm_w2, + td_w1, + td_w2, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ) + + out, s = self.RUN_RWKV_6( + 1, T, self.args.n_att, H, ss, r, k, v, w=w, u=t_first + ) + return self.v5_2_after( + t_decay, out, s, x, xxx, g, lx_w, lx_b, ow, omx, orx, omy, ory + ) + + ######################################################################################################## + + def forward(self, tokens, state, full_output=False): + with torch.no_grad(): + w = self.w + args = self.args + + if state == None: + if self.version == 4: + state = [None] * args.n_layer * 5 + for i in range( + args.n_layer + ): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx + dd = self.strategy[i] + dev = dd.device + atype = dd.atype + state[i * 5 + 0] = torch.zeros( + args.n_embd, dtype=atype, requires_grad=False, device=dev + ).contiguous() + state[i * 5 + 1] = torch.zeros( + args.n_att, + dtype=torch.float, + requires_grad=False, + device=dev, + ).contiguous() + state[i * 5 + 2] = torch.zeros( + args.n_att, + dtype=torch.float, + requires_grad=False, + device=dev, + ).contiguous() + state[i * 5 + 3] = ( + torch.zeros( + args.n_att, + dtype=torch.float, + requires_grad=False, + device=dev, + ).contiguous() + - 1e30 + ) + state[i * 5 + 4] = torch.zeros( + args.n_embd, dtype=atype, requires_grad=False, device=dev + ).contiguous() + elif int(self.version) in [5, 6]: + state = [None] * args.n_layer * 3 + for i in range(args.n_layer): # state: 0=att_xx 1=att_kv 2=ffn_xx + dd = self.strategy[i] + dev = dd.device + atype = dd.atype + state[i * 3 + 0] = torch.zeros( + args.n_embd, dtype=atype, requires_grad=False, device=dev + ).contiguous() + if args.time_state: + state[i * 3 + 1] = ( + w[f"blocks.{i}.att.time_state"] + .transpose(1, 2) + .to(dtype=torch.float, device=dev) + .requires_grad_(False) + .contiguous() + ) + else: + state[i * 3 + 1] = torch.zeros( + ( + args.n_head, + args.n_att // args.n_head, + args.n_att // args.n_head, + ), + dtype=torch.float, + requires_grad=False, + device=dev, + ).contiguous() + state[i * 3 + 2] = torch.zeros( + args.n_embd, dtype=atype, requires_grad=False, device=dev + ).contiguous() + + seq_mode = len(tokens) > 1 + + x = w["emb.weight"][tokens if seq_mode else tokens[0]] + + for i in range(args.n_layer): + bbb = f"blocks.{i}." + att = f"blocks.{i}.att." + ffn = f"blocks.{i}.ffn." 
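+                # Each layer runs on the device/dtype picked by its strategy slot.
+                # Weights of "stream" layers stay pinned in CPU RAM and are copied to
+                # the GPU (non_blocking) just for this layer, then dropped. With
+                # RESCALE_LAYER > 0 (default 6 for fp16), att.output.weight and
+                # ffn.value.weight were pre-divided by 2**(layer_id // RESCALE_LAYER)
+                # at load time, and x is halved after every RESCALE_LAYER layers below
+                # to keep fp16 activations from overflowing.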
+ dd = self.strategy[i] + dev = dd.device + atype = dd.atype + wtype = dd.wtype + if seq_mode: + cuda_applicable = os.environ[ + "RWKV_CUDA_ON" + ] == "1" and "cuda" in str(dev) + if cuda_applicable: + ATT = self.cuda_att_seq + else: + ATT = self.att_seq + if self.version == 5: + ATT = self.att_seq_v5 + elif self.version == 5.1: + ATT = self.att_seq_v5_1 + elif self.version == 5.2: + ATT = self.att_seq_v5_2 + if cuda_applicable: + ATT = self.cuda_att_seq_v5_2 + elif self.version == 6.0: + ATT = self.att_seq_v6_0 + if cuda_applicable: + ATT = self.cuda_att_seq_v6_0 + FFN = self.ffn_seq + if self.version >= 6.0: + FFN = self.ffn_seq_v6 + else: + ATT = self.att_one + if self.version == 5: + ATT = self.att_one_v5 + elif self.version == 5.1: + ATT = self.att_one_v5_1 + elif self.version == 5.2: + ATT = self.att_one_v5_1 # same as v5.1 + elif self.version == 6.0: + ATT = self.att_one_v6_0 + FFN = self.ffn_one + if self.version >= 6.0: + FFN = self.ffn_one_v6 + + x = x.to(dtype=atype, device=dev) + + kw = w[f"{att}key.weight"] + vw = w[f"{att}value.weight"] + rw = w[f"{att}receptance.weight"] + ow = w[f"{att}output.weight"] + if dd.stream: + kw = kw.to(device=dev, non_blocking=True) + vw = vw.to(device=dev, non_blocking=True) + rw = rw.to(device=dev, non_blocking=True) + ow = ow.to(device=dev, non_blocking=True) + kmx = w[f"{att}key.weight_mx"] if wtype == torch.uint8 else x + krx = w[f"{att}key.weight_rx"] if wtype == torch.uint8 else x + kmy = w[f"{att}key.weight_my"] if wtype == torch.uint8 else x + kry = w[f"{att}key.weight_ry"] if wtype == torch.uint8 else x + vmx = w[f"{att}value.weight_mx"] if wtype == torch.uint8 else x + vrx = w[f"{att}value.weight_rx"] if wtype == torch.uint8 else x + vmy = w[f"{att}value.weight_my"] if wtype == torch.uint8 else x + vry = w[f"{att}value.weight_ry"] if wtype == torch.uint8 else x + rmx = w[f"{att}receptance.weight_mx"] if wtype == torch.uint8 else x + rrx = w[f"{att}receptance.weight_rx"] if wtype == torch.uint8 else x + rmy = w[f"{att}receptance.weight_my"] if wtype == torch.uint8 else x + rry = w[f"{att}receptance.weight_ry"] if wtype == torch.uint8 else x + omx = w[f"{att}output.weight_mx"] if wtype == torch.uint8 else x + orx = w[f"{att}output.weight_rx"] if wtype == torch.uint8 else x + omy = w[f"{att}output.weight_my"] if wtype == torch.uint8 else x + ory = w[f"{att}output.weight_ry"] if wtype == torch.uint8 else x + if self.version in [5.1, 5.2, 6.0]: + gw = w[f"{att}gate.weight"] + if dd.stream: + gw = gw.to(device=dev, non_blocking=True) + gmx = w[f"{att}gate.weight_mx"] if wtype == torch.uint8 else x + grx = w[f"{att}gate.weight_rx"] if wtype == torch.uint8 else x + gmy = w[f"{att}gate.weight_my"] if wtype == torch.uint8 else x + gry = w[f"{att}gate.weight_ry"] if wtype == torch.uint8 else x + if self.version == 4: + ( + x, + state[i * 5 + 0], + state[i * 5 + 1], + state[i * 5 + 2], + state[i * 5 + 3], + ) = ATT( + x, + state[i * 5 + 0], + state[i * 5 + 1], + state[i * 5 + 2], + state[i * 5 + 3], + w[f"{bbb}ln1.weight"], + w[f"{bbb}ln1.bias"], + w[f"{att}time_mix_k"], + w[f"{att}time_mix_v"], + w[f"{att}time_mix_r"], + w[f"{att}time_decay"], + w[f"{att}time_first"], + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ) + elif self.version == 5: + x, state[i * 3 + 0], state[i * 3 + 1] = ATT( + x, + state[i * 3 + 0], + state[i * 3 + 1], + w[f"{bbb}ln1.weight"], + w[f"{bbb}ln1.bias"], + w[f"{att}ln_x.weight"], + w[f"{att}ln_x.bias"], + w[f"{att}time_mix_k"], + 
w[f"{att}time_mix_v"], + w[f"{att}time_mix_r"], + w[f"{att}time_decay"], + w[f"{att}time_first"], + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ) + elif self.version in [5.1, 5.2]: + x, state[i * 3 + 0], state[i * 3 + 1] = ATT( + x, + state[i * 3 + 0], + state[i * 3 + 1], + w[f"{bbb}ln1.weight"], + w[f"{bbb}ln1.bias"], + w[f"{att}ln_x.weight"], + w[f"{att}ln_x.bias"], + w[f"{att}time_mix_k"], + w[f"{att}time_mix_v"], + w[f"{att}time_mix_r"], + w[f"{att}time_mix_g"], + w[f"{att}time_decay"], + w[f"{att}time_first"], + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ) + elif self.version == 6.0: + x, state[i * 3 + 0], state[i * 3 + 1] = ATT( + x, + state[i * 3 + 0], + state[i * 3 + 1], + w[f"{bbb}ln1.weight"], + w[f"{bbb}ln1.bias"], + w[f"{att}ln_x.weight"], + w[f"{att}ln_x.bias"], + w[f"{att}time_maa_x"], + w[f"{att}time_maa_w"], + w[f"{att}time_maa_k"], + w[f"{att}time_maa_v"], + w[f"{att}time_maa_r"], + w[f"{att}time_maa_g"], + w[f"{att}time_maa_w1"], + w[f"{att}time_maa_w2"], + w[f"{att}time_decay_w1"], + w[f"{att}time_decay_w2"], + w[f"{att}time_decay"], + w[f"{att}time_first"], + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ) + if dd.stream: + del kw, vw, rw, ow + if self.version in [5.1, 5.2, 6.0]: + del gw + + kw = w[f"{ffn}key.weight"] + vw = w[f"{ffn}value.weight"] + rw = w[f"{ffn}receptance.weight"] + if dd.stream: + kw = kw.to(device=dev, non_blocking=True) + vw = vw.to(device=dev, non_blocking=True) + rw = rw.to(device=dev, non_blocking=True) + kmx = w[f"{ffn}key.weight_mx"] if wtype == torch.uint8 else x + krx = w[f"{ffn}key.weight_rx"] if wtype == torch.uint8 else x + kmy = w[f"{ffn}key.weight_my"] if wtype == torch.uint8 else x + kry = w[f"{ffn}key.weight_ry"] if wtype == torch.uint8 else x + vmx = w[f"{ffn}value.weight_mx"] if wtype == torch.uint8 else x + vrx = w[f"{ffn}value.weight_rx"] if wtype == torch.uint8 else x + vmy = w[f"{ffn}value.weight_my"] if wtype == torch.uint8 else x + vry = w[f"{ffn}value.weight_ry"] if wtype == torch.uint8 else x + rmx = w[f"{ffn}receptance.weight_mx"] if wtype == torch.uint8 else x + rrx = w[f"{ffn}receptance.weight_rx"] if wtype == torch.uint8 else x + rmy = w[f"{ffn}receptance.weight_my"] if wtype == torch.uint8 else x + rry = w[f"{ffn}receptance.weight_ry"] if wtype == torch.uint8 else x + if self.version == 4: + offset = i * 5 + 4 + elif int(self.version) in [5, 6]: + offset = i * 3 + 2 + if self.version < 6.0: + x, state[offset] = FFN( + x, + state[offset], + w[f"{bbb}ln2.weight"], + w[f"{bbb}ln2.bias"], + w[f"{ffn}time_mix_k"], + w[f"{ffn}time_mix_r"], + kw, + vw, + rw, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + ) + else: + x, state[offset] = FFN( + x, + state[offset], + w[f"{bbb}ln2.weight"], + w[f"{bbb}ln2.bias"], + w[f"{ffn}time_maa_k"], + w[f"{ffn}time_maa_r"], + kw, + vw, + rw, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + ) + if dd.stream: + del kw, vw, rw + + if self.RESCALE_LAYER > 0: + if (i + 1) % self.RESCALE_LAYER == 0: + x = x / 2 + + dd = self.strategy[args.n_layer] + x = x[-1, :] if (seq_mode and (not full_output)) else x + x = x.to(dtype=dd.atype, device=dd.device) + + x = F.layer_norm( + x, 
(args.n_embd,), weight=w["ln_out.weight"], bias=w["ln_out.bias"] + ) + if w["head.weight"].dtype != torch.uint8: + x = x @ w["head.weight"] + else: + if seq_mode and full_output: + x = mm8_seq( + x, + w["head.weight"], + w["head.weight_mx"], + w["head.weight_rx"], + w["head.weight_my"], + w["head.weight_ry"], + ) + else: + x = mm8_one( + x, + w["head.weight"], + w["head.weight_mx"], + w["head.weight_rx"], + w["head.weight_my"], + w["head.weight_ry"], + ) + + return x.float(), state + + +if os.environ.get("RWKV_V7_ON") == "1": + RWKV = RWKV_x070
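+
+########################################################################################################
+# Illustrative usage sketch (added comment, not part of the upstream file). It assumes the usual
+# RWKV(model=..., strategy=...) constructor of this package and relies on the forward(tokens, state,
+# full_output=False) signature defined above; the model path and token ids below are placeholders.
+#
+#   from rwkv.model import RWKV
+#   model = RWKV(model="/path/to/RWKV-model.pth", strategy="cpu fp32")
+#   state = None                                              # forward() builds a fresh per-layer
+#                                                             # state when it is passed None
+#   logits, state = model.forward([187, 510, 1563], state)    # prefill with a short token sequence
+#   logits, state = model.forward([421], state)               # then decode one token at a time
+#
+# logits is a float tensor over the vocabulary (only the last position of a multi-token input unless
+# full_output=True), and state carries the recurrent per-layer state between calls.
+########################################################################################################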